From ca07c081744c559858a4b460fd9243e51b05e5b3 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 9 Jan 2026 13:59:57 -0800 Subject: [PATCH 01/81] first commit (unclean) --- cpp/include/cuvs/cluster/kmeans.hpp | 143 +++++- cpp/src/cluster/detail/kmeans_batched.cuh | 536 ++++++++++++++++++++++ cpp/src/cluster/kmeans_fit_double.cu | 44 +- cpp/src/cluster/kmeans_fit_float.cu | 44 +- 4 files changed, 764 insertions(+), 3 deletions(-) create mode 100644 cpp/src/cluster/detail/kmeans_batched.cuh diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index a8aa6b9807..3b827630ab 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -134,6 +134,147 @@ struct balanced_params : base_params { * @{ */ +/** + * @defgroup kmeans_batched Batched k-means for out-of-core / host data + * @{ + */ + +/** + * @brief Find clusters with k-means algorithm using batched processing. + * + * This version supports out-of-core computation where the dataset resides + * on the host. Data is processed in batches, with partial sums accumulated + * across batches and centroids finalized at the end of each iteration. + * This is mathematically equivalent to standard kmeans. + * + * @code{.cpp} + * #include + * #include + * using namespace cuvs::cluster; + * ... + * raft::resources handle; + * cuvs::cluster::kmeans::params params; + * int n_features = 15; + * float inertia; + * int n_iter; + * + * // Data on host + * std::vector h_X(n_samples * n_features); + * auto X = raft::make_host_matrix_view(h_X.data(), n_samples, n_features); + * + * // Centroids on device + * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); + * + * kmeans::fit_batched(handle, + * params, + * X, + * 100000, // batch_size + * std::nullopt, + * centroids.view(), + * raft::make_host_scalar_view(&inertia), + * raft::make_host_scalar_view(&n_iter)); + * @endcode + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances on HOST memory. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch. + * @param[in] sample_weight Optional weights for each observation in X (on host). + * [len = n_samples] + * @param[inout] centroids [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers. + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. + */ +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Find clusters with k-means algorithm using batched processing. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances on HOST memory. + * [dim = n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch. + * @param[in] sample_weight Optional weights for each observation in X (on host). + * @param[inout] centroids Cluster centers on device. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances to nearest centroid. + * @param[out] n_iter Number of iterations run. + */ +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Find clusters with k-means algorithm using batched processing. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances on HOST memory. + * [dim = n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch. + * @param[in] sample_weight Optional weights for each observation in X (on host). + * @param[inout] centroids Cluster centers on device. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances to nearest centroid. + * @param[out] n_iter Number of iterations run. + */ +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Find clusters with k-means algorithm using batched processing. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances on HOST memory. + * [dim = n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch. + * @param[in] sample_weight Optional weights for each observation in X (on host). + * @param[inout] centroids Cluster centers on device. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances to nearest centroid. + * @param[out] n_iter Number of iterations run. + */ +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @} + */ + /** * @brief Find clusters with k-means algorithm. * Initial centroids are chosen with k-means++ algorithm. Empty diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh new file mode 100644 index 0000000000..63e36f2f5b --- /dev/null +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -0,0 +1,536 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include "kmeans_common.cuh" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace cuvs::cluster::kmeans::batched::detail { + +/** + * @brief Sample data from host to device for initialization + * + * Samples `n_samples_to_gather` rows from host data and copies to device. + * Uses uniform strided sampling for simplicity and cache-friendliness. + */ +template +void prepare_init_sample(raft::resources const& handle, + raft::host_matrix_view X, + raft::device_matrix_view X_sample, + uint64_t seed) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_samples_out = X_sample.extent(0); + + // Use strided sampling for cache-friendliness + // For truly random, could use std::shuffle on indices first + std::mt19937 gen(seed); + std::vector indices(n_samples); + std::iota(indices.begin(), indices.end(), 0); + std::shuffle(indices.begin(), indices.end(), gen); + + std::vector host_sample(n_samples_out * n_features); + +#pragma omp parallel for + for (IndexT i = 0; i < static_cast(n_samples_out); i++) { + IndexT src_idx = indices[i]; + std::memcpy(host_sample.data() + i * n_features, + X.data_handle() + src_idx * n_features, + n_features * sizeof(DataT)); + } + + raft::copy(X_sample.data_handle(), host_sample.data(), host_sample.size(), stream); +} + +/** + * @brief Initialize centroids using k-means++ on a sample of the host data + */ +template +void init_centroids_from_host_sample(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + raft::device_matrix_view centroids, + rmm::device_uvector& workspace) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + + // Sample size for initialization: at least 3 * n_clusters, but not more than n_samples + size_t init_sample_size = std::min(static_cast(n_samples), + std::max(static_cast(3 * n_clusters), + static_cast(10000))); + + RAFT_LOG_DEBUG("KMeans batched: sampling %zu points for initialization", init_sample_size); + + // Sample data from host to device + auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); + prepare_init_sample(handle, X, init_sample.view(), params.rng_state.seed); + + // Run k-means++ on the sample + if (params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { + cuvs::cluster::kmeans::detail::kmeansPlusPlus( + handle, + params, + raft::make_device_matrix_view( + init_sample.data_handle(), init_sample_size, n_features), + centroids, + workspace); + } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { + // Just use the first n_clusters samples + raft::copy(centroids.data_handle(), + init_sample.data_handle(), + n_clusters * n_features, + stream); + } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { + // Centroids already provided, nothing to do + } else { + RAFT_FAIL("Unknown initialization method"); + } +} + +/** + * @brief Accumulate partial centroid sums and counts from a batch + * + * This function adds the partial sums from a batch to the running accumulators. + * It does NOT divide - that happens once at the end of all batches. + */ +template +void accumulate_batch_centroids( + raft::resources const& handle, + raft::device_matrix_view batch_data, + raft::device_vector_view, IndexT> minClusterAndDistance, + raft::device_vector_view sample_weights, + raft::device_matrix_view centroid_sums, + raft::device_vector_view cluster_counts, + rmm::device_uvector& workspace) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = batch_data.extent(0); + auto n_features = batch_data.extent(1); + auto n_clusters = centroid_sums.extent(0); + + // Temporary buffers for this batch's partial results + auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto batch_counts = raft::make_device_vector(handle, n_clusters); + + // Zero the batch temporaries + thrust::fill(raft::resource::get_thrust_policy(handle), + batch_sums.data_handle(), + batch_sums.data_handle() + batch_sums.size(), + DataT{0}); + thrust::fill(raft::resource::get_thrust_policy(handle), + batch_counts.data_handle(), + batch_counts.data_handle() + batch_counts.size(), + DataT{0}); + + // Extract cluster labels from KeyValuePair + cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; + thrust::transform_iterator, + const raft::KeyValuePair*> + labels_itr(minClusterAndDistance.data_handle(), conversion_op); + + workspace.resize(n_samples, stream); + + // Compute weighted sum of samples per cluster for this batch + raft::linalg::reduce_rows_by_key(const_cast(batch_data.data_handle()), + batch_data.extent(1), + labels_itr, + sample_weights.data_handle(), + workspace.data(), + batch_data.extent(0), + batch_data.extent(1), + n_clusters, + batch_sums.data_handle(), + stream); + + // Compute sum of weights per cluster for this batch + raft::linalg::reduce_cols_by_key(sample_weights.data_handle(), + labels_itr, + batch_counts.data_handle(), + static_cast(1), + static_cast(n_samples), + static_cast(n_clusters), + stream); + + // Add batch results to running accumulators + raft::linalg::add(centroid_sums.data_handle(), + centroid_sums.data_handle(), + batch_sums.data_handle(), + centroid_sums.size(), + stream); + + raft::linalg::add(cluster_counts.data_handle(), + cluster_counts.data_handle(), + batch_counts.data_handle(), + cluster_counts.size(), + stream); +} + +/** + * @brief Finalize centroids by dividing accumulated sums by counts + */ +template +void finalize_centroids(raft::resources const& handle, + raft::device_matrix_view centroid_sums, + raft::device_vector_view cluster_counts, + raft::device_matrix_view old_centroids, + raft::device_matrix_view new_centroids) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_clusters = new_centroids.extent(0); + auto n_features = new_centroids.extent(1); + + // Copy sums to new_centroids first + raft::copy( + new_centroids.data_handle(), centroid_sums.data_handle(), centroid_sums.size(), stream); + + // Divide by counts: new_centroids[i] = centroid_sums[i] / cluster_counts[i] + // When count is 0, set to 0 (will be fixed below) + raft::linalg::matrix_vector_op( + handle, + raft::make_const_mdspan(new_centroids), + cluster_counts, + new_centroids, + raft::div_checkzero_op{}); + + // Copy old centroids to new centroids where cluster_counts[i] == 0 + cub::ArgIndexInputIterator itr_wt(cluster_counts.data_handle()); + raft::matrix::gather_if( + old_centroids.data_handle(), + static_cast(old_centroids.extent(1)), + static_cast(old_centroids.extent(0)), + itr_wt, + itr_wt, + static_cast(cluster_counts.size()), + new_centroids.data_handle(), + [=] __device__(raft::KeyValuePair map) { + return map.value == DataT{0}; // predicate: copy when count is 0 + }, + raft::key_op{}, + stream); +} + +/** + * @brief Main fit function for batched k-means with host data + * + * @tparam DataT Data type (float, double) + * @tparam IndexT Index type (int, int64_t) + * + * @param[in] handle RAFT resources handle + * @param[in] params K-means parameters + * @param[in] X Input data on HOST [n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch + * @param[in] sample_weight Optional weights per sample (on host) + * @param[inout] centroids Initial/output cluster centers [n_clusters x n_features] + * @param[out] inertia Sum of squared distances to nearest centroid + * @param[out] n_iter Number of iterations run + */ +template +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + IndexT batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + auto metric = params.metric; + + RAFT_EXPECTS(batch_size > 0, "batch_size must be positive"); + RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); + RAFT_EXPECTS(static_cast(centroids.extent(0)) == n_clusters, + "centroids.extent(0) must equal n_clusters"); + RAFT_EXPECTS(centroids.extent(1) == n_features, + "centroids.extent(1) must equal n_features"); + + raft::default_logger().set_level(params.verbosity); + + RAFT_LOG_DEBUG( + "KMeans batched fit: n_samples=%zu, n_features=%zu, n_clusters=%d, batch_size=%zu", + static_cast(n_samples), + static_cast(n_features), + n_clusters, + static_cast(batch_size)); + + rmm::device_uvector workspace(0, stream); + + // Initialize centroids from a sample of host data + if (params.init != cuvs::cluster::kmeans::params::InitMethod::Array) { + init_centroids_from_host_sample(handle, params, X, centroids, workspace); + } + + // Allocate device buffers + // Batch buffer for data + auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); + // Batch buffer for weights + auto batch_weights = raft::make_device_vector(handle, batch_size); + // Cluster assignment for batch + auto minClusterAndDistance = + raft::make_device_vector, IndexT>(handle, batch_size); + // L2 norms of batch data + auto L2NormBatch = raft::make_device_vector(handle, batch_size); + // Temporary buffer for distance computation + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + + // Accumulators for centroid computation (persist across batches within an iteration) + auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto cluster_counts = raft::make_device_vector(handle, n_clusters); + auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); + + // Host buffer for batch data (pinned memory for faster H2D transfer) + std::vector host_batch_buffer(batch_size * n_features); + std::vector host_weight_buffer(batch_size); + + // Cluster cost for convergence check + rmm::device_scalar clusterCostD(stream); + DataT priorClusteringCost = 0; + + // Main iteration loop + for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { + RAFT_LOG_DEBUG("KMeans batched: Iteration %d", n_iter[0]); + + // Zero accumulators at start of each iteration + thrust::fill(raft::resource::get_thrust_policy(handle), + centroid_sums.data_handle(), + centroid_sums.data_handle() + centroid_sums.size(), + DataT{0}); + thrust::fill(raft::resource::get_thrust_policy(handle), + cluster_counts.data_handle(), + cluster_counts.data_handle() + cluster_counts.size(), + DataT{0}); + + DataT total_cost = 0; + + // Process all data in batches + for (IndexT offset = 0; offset < n_samples; offset += batch_size) { + IndexT current_batch_size = std::min(batch_size, n_samples - offset); + + // Copy batch data from host to device + raft::copy(batch_data.data_handle(), + X.data_handle() + offset * n_features, + current_batch_size * n_features, + stream); + + // Copy or set weights for this batch + if (sample_weight) { + raft::copy(batch_weights.data_handle(), + sample_weight->data_handle() + offset, + current_batch_size, + stream); + } else { + thrust::fill(raft::resource::get_thrust_policy(handle), + batch_weights.data_handle(), + batch_weights.data_handle() + current_batch_size, + DataT{1}); + } + + // Create views for current batch size + auto batch_data_view = raft::make_device_matrix_view( + batch_data.data_handle(), current_batch_size, n_features); + auto batch_weights_view = raft::make_device_vector_view( + batch_weights.data_handle(), current_batch_size); + auto minClusterAndDistance_view = + raft::make_device_vector_view, IndexT>( + minClusterAndDistance.data_handle(), current_batch_size); + auto L2NormBatch_view = raft::make_device_vector_view( + L2NormBatch.data_handle(), current_batch_size); + + // Compute L2 norms for batch if needed + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm( + L2NormBatch.data_handle(), + batch_data.data_handle(), + n_features, + current_batch_size, + stream); + } + + // Find nearest centroid for each sample in batch + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features); + auto L2NormBatch_const = raft::make_device_vector_view( + L2NormBatch.data_handle(), current_batch_size); + + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + batch_data_view, + centroids_const, + minClusterAndDistance_view, + L2NormBatch_const, + L2NormBuf_OR_DistBuf, + metric, + params.batch_samples, + params.batch_centroids, + workspace); + + // Accumulate partial sums for this batch + auto minClusterAndDistance_const = + raft::make_device_vector_view, IndexT>( + minClusterAndDistance.data_handle(), current_batch_size); + + accumulate_batch_centroids(handle, + batch_data_view, + minClusterAndDistance_const, + batch_weights_view, + centroid_sums.view(), + cluster_counts.view(), + workspace); + + // Accumulate cluster cost if checking convergence + if (params.inertia_check) { + cuvs::cluster::kmeans::detail::computeClusterCost( + handle, + minClusterAndDistance_view, + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + raft::value_op{}, + raft::add_op{}); + DataT batch_cost = clusterCostD.value(stream); + total_cost += batch_cost; + } + } // end batch loop + + // Finalize centroids: divide sums by counts + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features); + auto centroid_sums_const = raft::make_device_matrix_view( + centroid_sums.data_handle(), n_clusters, n_features); + auto cluster_counts_const = raft::make_device_vector_view( + cluster_counts.data_handle(), n_clusters); + + finalize_centroids( + handle, centroid_sums_const, cluster_counts_const, centroids_const, new_centroids.view()); + + // Compute squared norm of change in centroids + auto sqrdNorm = raft::make_device_scalar(handle, DataT{0}); + raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), + new_centroids.size(), + raft::sqdiff_op{}, + stream, + centroids.data_handle(), + new_centroids.data_handle()); + + DataT sqrdNormError = 0; + raft::copy(&sqrdNormError, sqrdNorm.data_handle(), 1, stream); + + // Update centroids + raft::copy(centroids.data_handle(), new_centroids.data_handle(), new_centroids.size(), stream); + + // Check convergence + bool done = false; + if (params.inertia_check) { + if (n_iter[0] > 1) { + DataT delta = total_cost / priorClusteringCost; + if (delta > 1 - params.tol) done = true; + } + priorClusteringCost = total_cost; + } + + raft::resource::sync_stream(handle, stream); + if (sqrdNormError < params.tol) done = true; + + if (done) { + RAFT_LOG_DEBUG("KMeans batched: Converged after %d iterations", n_iter[0]); + break; + } + } // end iteration loop + + // Compute final inertia by processing all data once more + inertia[0] = 0; + for (IndexT offset = 0; offset < n_samples; offset += batch_size) { + IndexT current_batch_size = std::min(batch_size, n_samples - offset); + + raft::copy(batch_data.data_handle(), + X.data_handle() + offset * n_features, + current_batch_size * n_features, + stream); + + auto batch_data_view = raft::make_device_matrix_view( + batch_data.data_handle(), current_batch_size, n_features); + auto minClusterAndDistance_view = + raft::make_device_vector_view, IndexT>( + minClusterAndDistance.data_handle(), current_batch_size); + + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm( + L2NormBatch.data_handle(), batch_data.data_handle(), n_features, current_batch_size, stream); + } + + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features); + auto L2NormBatch_const = raft::make_device_vector_view( + L2NormBatch.data_handle(), current_batch_size); + + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + batch_data_view, + centroids_const, + minClusterAndDistance_view, + L2NormBatch_const, + L2NormBuf_OR_DistBuf, + metric, + params.batch_samples, + params.batch_centroids, + workspace); + + cuvs::cluster::kmeans::detail::computeClusterCost( + handle, + minClusterAndDistance_view, + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + raft::value_op{}, + raft::add_op{}); + + inertia[0] += clusterCostD.value(stream); + } + + RAFT_LOG_DEBUG("KMeans batched: Completed with inertia=%f", inertia[0]); +} + +} // namespace cuvs::cluster::kmeans::batched::detail + diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/cluster/kmeans_fit_double.cu index 43f457a29a..0962c87890 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/cluster/kmeans_fit_double.cu @@ -1,8 +1,9 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ +#include "detail/kmeans_batched.cuh" #include "kmeans.cuh" #include "kmeans_impl.cuh" #include @@ -30,14 +31,55 @@ namespace cuvs::cluster::kmeans { raft::host_scalar_view inertia, \ raft::host_scalar_view n_iter); +#define INSTANTIATE_FIT_BATCHED(DataT, IndexT) \ + template void batched::detail::fit( \ + raft::resources const& handle, \ + const kmeans::params& params, \ + raft::host_matrix_view X, \ + IndexT batch_size, \ + std::optional> sample_weight, \ + raft::device_matrix_view centroids, \ + raft::host_scalar_view inertia, \ + raft::host_scalar_view n_iter); + INSTANTIATE_FIT_MAIN(double, int) INSTANTIATE_FIT_MAIN(double, int64_t) INSTANTIATE_FIT(double, int) INSTANTIATE_FIT(double, int64_t) +INSTANTIATE_FIT_BATCHED(double, int) +INSTANTIATE_FIT_BATCHED(double, int64_t) + #undef INSTANTIATE_FIT_MAIN #undef INSTANTIATE_FIT +#undef INSTANTIATE_FIT_BATCHED + +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); +} + +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); +} void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/cluster/kmeans_fit_float.cu index 5624151943..39cc074d9e 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/cluster/kmeans_fit_float.cu @@ -1,8 +1,9 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ +#include "detail/kmeans_batched.cuh" #include "kmeans.cuh" #include "kmeans_impl.cuh" #include @@ -30,14 +31,55 @@ namespace cuvs::cluster::kmeans { raft::host_scalar_view inertia, \ raft::host_scalar_view n_iter); +#define INSTANTIATE_FIT_BATCHED(DataT, IndexT) \ + template void batched::detail::fit( \ + raft::resources const& handle, \ + const kmeans::params& params, \ + raft::host_matrix_view X, \ + IndexT batch_size, \ + std::optional> sample_weight, \ + raft::device_matrix_view centroids, \ + raft::host_scalar_view inertia, \ + raft::host_scalar_view n_iter); + INSTANTIATE_FIT_MAIN(float, int) INSTANTIATE_FIT_MAIN(float, int64_t) INSTANTIATE_FIT(float, int) INSTANTIATE_FIT(float, int64_t) +INSTANTIATE_FIT_BATCHED(float, int) +INSTANTIATE_FIT_BATCHED(float, int64_t) + #undef INSTANTIATE_FIT_MAIN #undef INSTANTIATE_FIT +#undef INSTANTIATE_FIT_BATCHED + +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); +} + +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); +} void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, From f1a19dfd29ea1f5ee32c47814c580628fc3cc2ec Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 9 Jan 2026 17:59:35 -0800 Subject: [PATCH 02/81] style --- cpp/src/cluster/detail/kmeans_batched.cuh | 140 +++++++++++----------- 1 file changed, 69 insertions(+), 71 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 63e36f2f5b..1999d80c69 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -47,29 +47,27 @@ namespace cuvs::cluster::kmeans::batched::detail { * Samples `n_samples_to_gather` rows from host data and copies to device. * Uses uniform strided sampling for simplicity and cache-friendliness. */ -template +template void prepare_init_sample(raft::resources const& handle, - raft::host_matrix_view X, - raft::device_matrix_view X_sample, - uint64_t seed) + raft::host_matrix_view X, + raft::device_matrix_view X_sample, + uint64_t seed) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_samples_out = X_sample.extent(0); - // Use strided sampling for cache-friendliness - // For truly random, could use std::shuffle on indices first std::mt19937 gen(seed); - std::vector indices(n_samples); + std::vector indices(n_samples); std::iota(indices.begin(), indices.end(), 0); std::shuffle(indices.begin(), indices.end(), gen); std::vector host_sample(n_samples_out * n_features); #pragma omp parallel for - for (IndexT i = 0; i < static_cast(n_samples_out); i++) { - IndexT src_idx = indices[i]; + for (IdxT i = 0; i < static_cast(n_samples_out); i++) { + IdxT src_idx = indices[i]; std::memcpy(host_sample.data() + i * n_features, X.data_handle() + src_idx * n_features, n_features * sizeof(DataT)); @@ -81,11 +79,11 @@ void prepare_init_sample(raft::resources const& handle, /** * @brief Initialize centroids using k-means++ on a sample of the host data */ -template +template void init_centroids_from_host_sample(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - raft::device_matrix_view centroids, + raft::host_matrix_view X, + raft::device_matrix_view centroids, rmm::device_uvector& workspace) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); @@ -101,15 +99,15 @@ void init_centroids_from_host_sample(raft::resources const& handle, RAFT_LOG_DEBUG("KMeans batched: sampling %zu points for initialization", init_sample_size); // Sample data from host to device - auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); + auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); prepare_init_sample(handle, X, init_sample.view(), params.rng_state.seed); // Run k-means++ on the sample if (params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { - cuvs::cluster::kmeans::detail::kmeansPlusPlus( + cuvs::cluster::kmeans::detail::kmeansPlusPlus( handle, params, - raft::make_device_matrix_view( + raft::make_device_matrix_view( init_sample.data_handle(), init_sample_size, n_features), centroids, workspace); @@ -132,14 +130,14 @@ void init_centroids_from_host_sample(raft::resources const& handle, * This function adds the partial sums from a batch to the running accumulators. * It does NOT divide - that happens once at the end of all batches. */ -template +template void accumulate_batch_centroids( raft::resources const& handle, - raft::device_matrix_view batch_data, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroid_sums, - raft::device_vector_view cluster_counts, + raft::device_matrix_view batch_data, + raft::device_vector_view, IdxT> minClusterAndDistance, + raft::device_vector_view sample_weights, + raft::device_matrix_view centroid_sums, + raft::device_vector_view cluster_counts, rmm::device_uvector& workspace) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); @@ -148,8 +146,8 @@ void accumulate_batch_centroids( auto n_clusters = centroid_sums.extent(0); // Temporary buffers for this batch's partial results - auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); - auto batch_counts = raft::make_device_vector(handle, n_clusters); + auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto batch_counts = raft::make_device_vector(handle, n_clusters); // Zero the batch temporaries thrust::fill(raft::resource::get_thrust_policy(handle), @@ -162,9 +160,9 @@ void accumulate_batch_centroids( DataT{0}); // Extract cluster labels from KeyValuePair - cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; - thrust::transform_iterator, - const raft::KeyValuePair*> + cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; + thrust::transform_iterator, + const raft::KeyValuePair*> labels_itr(minClusterAndDistance.data_handle(), conversion_op); workspace.resize(n_samples, stream); @@ -185,9 +183,9 @@ void accumulate_batch_centroids( raft::linalg::reduce_cols_by_key(sample_weights.data_handle(), labels_itr, batch_counts.data_handle(), - static_cast(1), - static_cast(n_samples), - static_cast(n_clusters), + static_cast(1), + static_cast(n_samples), + static_cast(n_clusters), stream); // Add batch results to running accumulators @@ -207,12 +205,12 @@ void accumulate_batch_centroids( /** * @brief Finalize centroids by dividing accumulated sums by counts */ -template +template void finalize_centroids(raft::resources const& handle, - raft::device_matrix_view centroid_sums, - raft::device_vector_view cluster_counts, - raft::device_matrix_view old_centroids, - raft::device_matrix_view new_centroids) + raft::device_matrix_view centroid_sums, + raft::device_vector_view cluster_counts, + raft::device_matrix_view old_centroids, + raft::device_matrix_view new_centroids) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_clusters = new_centroids.extent(0); @@ -252,7 +250,7 @@ void finalize_centroids(raft::resources const& handle, * @brief Main fit function for batched k-means with host data * * @tparam DataT Data type (float, double) - * @tparam IndexT Index type (int, int64_t) + * @tparam IdxT Index type (int, int64_t) * * @param[in] handle RAFT resources handle * @param[in] params K-means parameters @@ -263,15 +261,15 @@ void finalize_centroids(raft::resources const& handle, * @param[out] inertia Sum of squared distances to nearest centroid * @param[out] n_iter Number of iterations run */ -template +template void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - IndexT batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, + raft::host_matrix_view X, + IdxT batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); @@ -281,7 +279,7 @@ void fit(raft::resources const& handle, RAFT_EXPECTS(batch_size > 0, "batch_size must be positive"); RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); - RAFT_EXPECTS(static_cast(centroids.extent(0)) == n_clusters, + RAFT_EXPECTS(static_cast(centroids.extent(0)) == n_clusters, "centroids.extent(0) must equal n_clusters"); RAFT_EXPECTS(centroids.extent(1) == n_features, "centroids.extent(1) must equal n_features"); @@ -304,21 +302,21 @@ void fit(raft::resources const& handle, // Allocate device buffers // Batch buffer for data - auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); + auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); // Batch buffer for weights - auto batch_weights = raft::make_device_vector(handle, batch_size); + auto batch_weights = raft::make_device_vector(handle, batch_size); // Cluster assignment for batch auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, batch_size); + raft::make_device_vector, IdxT>(handle, batch_size); // L2 norms of batch data - auto L2NormBatch = raft::make_device_vector(handle, batch_size); + auto L2NormBatch = raft::make_device_vector(handle, batch_size); // Temporary buffer for distance computation rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); // Accumulators for centroid computation (persist across batches within an iteration) - auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); - auto cluster_counts = raft::make_device_vector(handle, n_clusters); - auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); + auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto cluster_counts = raft::make_device_vector(handle, n_clusters); + auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); // Host buffer for batch data (pinned memory for faster H2D transfer) std::vector host_batch_buffer(batch_size * n_features); @@ -345,8 +343,8 @@ void fit(raft::resources const& handle, DataT total_cost = 0; // Process all data in batches - for (IndexT offset = 0; offset < n_samples; offset += batch_size) { - IndexT current_batch_size = std::min(batch_size, n_samples - offset); + for (IdxT offset = 0; offset < n_samples; offset += batch_size) { + IdxT current_batch_size = std::min(batch_size, n_samples - offset); // Copy batch data from host to device raft::copy(batch_data.data_handle(), @@ -368,14 +366,14 @@ void fit(raft::resources const& handle, } // Create views for current batch size - auto batch_data_view = raft::make_device_matrix_view( + auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); - auto batch_weights_view = raft::make_device_vector_view( + auto batch_weights_view = raft::make_device_vector_view( batch_weights.data_handle(), current_batch_size); auto minClusterAndDistance_view = - raft::make_device_vector_view, IndexT>( + raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); - auto L2NormBatch_view = raft::make_device_vector_view( + auto L2NormBatch_view = raft::make_device_vector_view( L2NormBatch.data_handle(), current_batch_size); // Compute L2 norms for batch if needed @@ -390,12 +388,12 @@ void fit(raft::resources const& handle, } // Find nearest centroid for each sample in batch - auto centroids_const = raft::make_device_matrix_view( + auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); - auto L2NormBatch_const = raft::make_device_vector_view( + auto L2NormBatch_const = raft::make_device_vector_view( L2NormBatch.data_handle(), current_batch_size); - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( handle, batch_data_view, centroids_const, @@ -409,10 +407,10 @@ void fit(raft::resources const& handle, // Accumulate partial sums for this batch auto minClusterAndDistance_const = - raft::make_device_vector_view, IndexT>( + raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); - accumulate_batch_centroids(handle, + accumulate_batch_centroids(handle, batch_data_view, minClusterAndDistance_const, batch_weights_view, @@ -435,14 +433,14 @@ void fit(raft::resources const& handle, } // end batch loop // Finalize centroids: divide sums by counts - auto centroids_const = raft::make_device_matrix_view( + auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); - auto centroid_sums_const = raft::make_device_matrix_view( + auto centroid_sums_const = raft::make_device_matrix_view( centroid_sums.data_handle(), n_clusters, n_features); - auto cluster_counts_const = raft::make_device_vector_view( + auto cluster_counts_const = raft::make_device_vector_view( cluster_counts.data_handle(), n_clusters); - finalize_centroids( + finalize_centroids( handle, centroid_sums_const, cluster_counts_const, centroids_const, new_centroids.view()); // Compute squared norm of change in centroids @@ -481,18 +479,18 @@ void fit(raft::resources const& handle, // Compute final inertia by processing all data once more inertia[0] = 0; - for (IndexT offset = 0; offset < n_samples; offset += batch_size) { - IndexT current_batch_size = std::min(batch_size, n_samples - offset); + for (IdxT offset = 0; offset < n_samples; offset += batch_size) { + IdxT current_batch_size = std::min(batch_size, n_samples - offset); raft::copy(batch_data.data_handle(), X.data_handle() + offset * n_features, current_batch_size * n_features, stream); - auto batch_data_view = raft::make_device_matrix_view( + auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); auto minClusterAndDistance_view = - raft::make_device_vector_view, IndexT>( + raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); if (metric == cuvs::distance::DistanceType::L2Expanded || @@ -501,12 +499,12 @@ void fit(raft::resources const& handle, L2NormBatch.data_handle(), batch_data.data_handle(), n_features, current_batch_size, stream); } - auto centroids_const = raft::make_device_matrix_view( + auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); - auto L2NormBatch_const = raft::make_device_vector_view( + auto L2NormBatch_const = raft::make_device_vector_view( L2NormBatch.data_handle(), current_batch_size); - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( handle, batch_data_view, centroids_const, From 0fa00b052dd2d5372f703fb5c46d0fe7f87cd484 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 9 Jan 2026 19:09:38 -0800 Subject: [PATCH 03/81] copyright --- cpp/src/cluster/detail/kmeans_batched.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 1999d80c69..de44b309ba 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once From fcbdda59917eaaf8249355ffbd93dc91e917bcd8 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 3 Feb 2026 02:11:01 -0800 Subject: [PATCH 04/81] python test --- python/cuvs/cuvs/tests/test_kmeans.py | 80 +++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/python/cuvs/cuvs/tests/test_kmeans.py b/python/cuvs/cuvs/tests/test_kmeans.py index 6f18137b13..cc3b1cf4a4 100644 --- a/python/cuvs/cuvs/tests/test_kmeans.py +++ b/python/cuvs/cuvs/tests/test_kmeans.py @@ -69,3 +69,83 @@ def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): # need reduced tolerance for float32 tol = 1e-3 if dtype == np.float32 else 1e-6 assert np.allclose(inertia, sum(cluster_distances), rtol=tol, atol=tol) + + +@pytest.mark.parametrize("n_rows", [1000]) +@pytest.mark.parametrize("n_cols", [10]) +@pytest.mark.parametrize("n_clusters", [8]) +@pytest.mark.parametrize("dtype", [np.float32]) +@pytest.mark.parametrize( + "batch_samples_list", + [ + [32, 64, 128, 256, 512], # various batch sizes + ], +) +def test_kmeans_batch_size_determinism( + n_rows, n_cols, n_clusters, dtype, batch_samples_list +): + """ + Test that different batch sizes produce identical centroids. + + When starting from the same initial centroids, the k-means algorithm + should produce identical final centroids regardless of the batch_samples + parameter. This is because the accumulated adjustments to centroids after + the entire dataset pass should be the same. + """ + # Use fixed seed for reproducibility + rng = np.random.default_rng(42) + + # Generate random data + X_host = rng.random((n_rows, n_cols)).astype(dtype) + X = device_ndarray(X_host) + + # Generate fixed initial centroids (using first n_clusters rows) + initial_centroids_host = X_host[:n_clusters].copy() + + # Store results from each batch size + results = [] + + for batch_samples in batch_samples_list: + # Create fresh copy of initial centroids for each run + centroids = device_ndarray(initial_centroids_host.copy()) + + params = KMeansParams( + n_clusters=n_clusters, + init_method="Array", # Use provided centroids + max_iter=100, + tol=1e-10, # Very small tolerance to ensure convergence + batch_samples=batch_samples, + ) + + centroids_out, inertia, n_iter = fit(params, X, centroids) + results.append( + { + "batch_samples": batch_samples, + "centroids": centroids_out.copy_to_host(), + "inertia": inertia, + "n_iter": n_iter, + } + ) + + # Compare all results against the first one + reference = results[0] + for result in results[1:]: + # Centroids should be identical (or very close due to float precision) + assert np.allclose( + reference["centroids"], + result["centroids"], + rtol=1e-5, + atol=1e-5, + ), ( + f"Centroids differ between batch_samples=" + f"{reference['batch_samples']} and {result['batch_samples']}" + ) + + # Inertia should also be identical + assert np.allclose( + reference["inertia"], result["inertia"], rtol=1e-5, atol=1e-5 + ), ( + f"Inertia differs between batch_samples=" + f"{reference['batch_samples']} and {result['batch_samples']}: " + f"{reference['inertia']} vs {result['inertia']}" + ) From d6ed934577fffc6d24ed07053746d2411b3770ba Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 3 Feb 2026 02:31:38 -0800 Subject: [PATCH 05/81] minibatch first commit --- c/include/cuvs/cluster/kmeans.h | 26 +++ cpp/include/cuvs/cluster/kmeans.hpp | 28 ++- cpp/src/cluster/detail/kmeans_batched.cuh | 219 +++++++++++++++++---- python/cuvs/cuvs/cluster/kmeans/kmeans.pxd | 5 + python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 42 ++++ 5 files changed, 281 insertions(+), 39 deletions(-) diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index 0bb9591f63..79448af26f 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -36,6 +36,25 @@ typedef enum { Array = 2 } cuvsKMeansInitMethod; +/** + * @brief Centroid update mode for k-means algorithm + */ +typedef enum { + /** + * Standard k-means (Lloyd's algorithm): accumulate assignments over the + * entire dataset, then update centroids once per iteration. + * More accurate but requires full pass over data before each update. + */ + CUVS_KMEANS_UPDATE_FULL_BATCH = 0, + + /** + * Mini-batch k-means: update centroids after each randomly sampled batch. + * Faster convergence for large datasets, but may have slightly lower accuracy. + * Uses streaming/online centroid updates with learning rate decay. + */ + CUVS_KMEANS_UPDATE_MINI_BATCH = 1 +} cuvsKMeansCentroidUpdateMode; + /** * @brief Hyper-parameters for the kmeans algorithm */ @@ -90,6 +109,13 @@ struct cuvsKMeansParams { */ int batch_centroids; + /** + * Centroid update mode: + * - CUVS_KMEANS_UPDATE_FULL_BATCH: Standard Lloyd's algorithm, update after full dataset pass + * - CUVS_KMEANS_UPDATE_MINI_BATCH: Mini-batch k-means, update after each batch + */ + cuvsKMeansCentroidUpdateMode update_mode; + bool inertia_check; /** diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index e141947dad..4f702d74ca 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -48,6 +48,25 @@ struct params : base_params { Array }; + /** + * Centroid update mode determines when centroids are updated during training. + */ + enum CentroidUpdateMode { + /** + * Standard k-means (Lloyd's algorithm): accumulate assignments over the + * entire dataset, then update centroids once per iteration. + * More accurate but requires full pass over data before each update. + */ + FullBatch, + + /** + * Mini-batch k-means: update centroids after each randomly sampled batch. + * Faster convergence for large datasets, but may have slightly lower accuracy. + * Uses streaming/online centroid updates with learning rate decay. + */ + MiniBatch + }; + /** * The number of clusters to form as well as the number of centroids to generate (default:8). */ @@ -104,7 +123,14 @@ struct params : base_params { /** * if 0 then batch_centroids = n_clusters */ - int batch_centroids = 0; // + int batch_centroids = 0; + + /** + * Centroid update mode: + * - FullBatch: Standard Lloyd's algorithm, update centroids after full dataset pass + * - MiniBatch: Mini-batch k-means, update centroids after each batch + */ + CentroidUpdateMode update_mode = FullBatch; bool inertia_check = false; }; diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index de44b309ba..c622c9b5a3 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -4,6 +4,7 @@ */ #pragma once +#include "kmeans.cuh" #include "kmeans_common.cuh" #include @@ -36,6 +37,8 @@ #include #include +#include +#include #include #include @@ -202,6 +205,78 @@ void accumulate_batch_centroids( stream); } +/** + * @brief Update centroids using mini-batch online learning + * + * Uses the online update formula: + * learning_rate[k] = batch_count[k] / (total_count[k] + batch_count[k]) + * centroid[k] = centroid[k] + learning_rate[k] * (batch_mean[k] - centroid[k]) + * + * This is equivalent to a weighted average where total_count tracks cumulative weight. + */ +template +void minibatch_update_centroids(raft::resources const& handle, + raft::device_matrix_view centroids, + raft::device_matrix_view batch_sums, + raft::device_vector_view batch_counts, + raft::device_vector_view total_counts) +{ + auto n_clusters = centroids.extent(0); + auto n_features = centroids.extent(1); + + // Compute batch means: batch_mean = batch_sums / batch_counts + auto batch_means = raft::make_device_matrix(handle, n_clusters, n_features); + raft::copy(batch_means.data_handle(), + batch_sums.data_handle(), + batch_sums.size(), + raft::resource::get_cuda_stream(handle)); + + raft::linalg::matrix_vector_op( + handle, + raft::make_const_mdspan(batch_means.view()), + batch_counts, + batch_means.view(), + raft::div_checkzero_op{}); + + // Step 1: Update total_counts = total_counts + batch_counts + raft::linalg::add(handle, + raft::make_const_mdspan(total_counts), + batch_counts, + total_counts); + + // Step 2: Compute learning rates: lr = batch_count / total_count (after update) + auto learning_rates = raft::make_device_vector(handle, n_clusters); + raft::linalg::map(handle, + learning_rates.view(), + raft::div_checkzero_op{}, + batch_counts, + raft::make_const_mdspan(total_counts)); + + // Update centroids: centroid = centroid + lr * (batch_mean - centroid) + // = (1 - lr) * centroid + lr * batch_mean + // Using matrix_vector_op to scale each row by (1 - lr), then add lr * batch_mean + raft::linalg::matrix_vector_op( + handle, + raft::make_const_mdspan(centroids), + raft::make_const_mdspan(learning_rates.view()), + centroids, + [] __device__(DataT centroid_val, DataT lr) { return (DataT{1} - lr) * centroid_val; }); + + // Add lr * batch_mean to centroids + raft::linalg::matrix_vector_op( + handle, + raft::make_const_mdspan(batch_means.view()), + raft::make_const_mdspan(learning_rates.view()), + batch_means.view(), + [] __device__(DataT mean_val, DataT lr) { return lr * mean_val; }); + + // centroids += lr * batch_means + raft::linalg::add(handle, + raft::make_const_mdspan(centroids), + raft::make_const_mdspan(batch_means.view()), + centroids); +} + /** * @brief Finalize centroids by dividing accumulated sums by counts */ @@ -313,11 +388,14 @@ void fit(raft::resources const& handle, // Temporary buffer for distance computation rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - // Accumulators for centroid computation (persist across batches within an iteration) + // Accumulators for centroid computation auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); auto cluster_counts = raft::make_device_vector(handle, n_clusters); auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); + // For mini-batch mode: track total counts for learning rate calculation + auto total_counts = raft::make_device_vector(handle, n_clusters); + // Host buffer for batch data (pinned memory for faster H2D transfer) std::vector host_batch_buffer(batch_size * n_features); std::vector host_weight_buffer(batch_size); @@ -326,38 +404,83 @@ void fit(raft::resources const& handle, rmm::device_scalar clusterCostD(stream); DataT priorClusteringCost = 0; + // Check update mode + bool use_minibatch = + (params.update_mode == cuvs::cluster::kmeans::params::CentroidUpdateMode::MiniBatch); + + RAFT_LOG_DEBUG("KMeans batched: update_mode=%s", use_minibatch ? "MiniBatch" : "FullBatch"); + + // For mini-batch mode with random sampling, create index shuffle + std::vector sample_indices(n_samples); + std::iota(sample_indices.begin(), sample_indices.end(), 0); + std::mt19937 rng(params.rng_state.seed); + // Main iteration loop for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { RAFT_LOG_DEBUG("KMeans batched: Iteration %d", n_iter[0]); - // Zero accumulators at start of each iteration - thrust::fill(raft::resource::get_thrust_policy(handle), - centroid_sums.data_handle(), - centroid_sums.data_handle() + centroid_sums.size(), - DataT{0}); - thrust::fill(raft::resource::get_thrust_policy(handle), - cluster_counts.data_handle(), - cluster_counts.data_handle() + cluster_counts.size(), - DataT{0}); + // For full-batch mode: zero accumulators at start of each iteration + // For mini-batch mode: zero total_counts at start of each iteration + if (!use_minibatch) { + thrust::fill(raft::resource::get_thrust_policy(handle), + centroid_sums.data_handle(), + centroid_sums.data_handle() + centroid_sums.size(), + DataT{0}); + thrust::fill(raft::resource::get_thrust_policy(handle), + cluster_counts.data_handle(), + cluster_counts.data_handle() + cluster_counts.size(), + DataT{0}); + } else { + // Mini-batch mode: zero total counts for learning rate calculation + thrust::fill(raft::resource::get_thrust_policy(handle), + total_counts.data_handle(), + total_counts.data_handle() + total_counts.size(), + DataT{0}); + // Shuffle sample indices for random batch selection + std::shuffle(sample_indices.begin(), sample_indices.end(), rng); + } + + // Save old centroids for convergence check + raft::copy(new_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); DataT total_cost = 0; // Process all data in batches - for (IdxT offset = 0; offset < n_samples; offset += batch_size) { - IdxT current_batch_size = std::min(batch_size, n_samples - offset); + for (IdxT batch_idx = 0; batch_idx < n_samples; batch_idx += batch_size) { + IdxT current_batch_size = std::min(batch_size, n_samples - batch_idx); // Copy batch data from host to device - raft::copy(batch_data.data_handle(), - X.data_handle() + offset * n_features, - current_batch_size * n_features, - stream); + if (use_minibatch) { + // Mini-batch: use shuffled indices for random sampling + for (IdxT i = 0; i < current_batch_size; ++i) { + IdxT sample_idx = sample_indices[batch_idx + i]; + std::memcpy(host_batch_buffer.data() + i * n_features, + X.data_handle() + sample_idx * n_features, + n_features * sizeof(DataT)); + } + raft::copy( + batch_data.data_handle(), host_batch_buffer.data(), current_batch_size * n_features, stream); + } else { + // Full-batch: sequential access + raft::copy(batch_data.data_handle(), + X.data_handle() + batch_idx * n_features, + current_batch_size * n_features, + stream); + } // Copy or set weights for this batch if (sample_weight) { - raft::copy(batch_weights.data_handle(), - sample_weight->data_handle() + offset, - current_batch_size, - stream); + if (use_minibatch) { + for (IdxT i = 0; i < current_batch_size; ++i) { + host_weight_buffer[i] = sample_weight->data_handle()[sample_indices[batch_idx + i]]; + } + raft::copy(batch_weights.data_handle(), host_weight_buffer.data(), current_batch_size, stream); + } else { + raft::copy(batch_weights.data_handle(), + sample_weight->data_handle() + batch_idx, + current_batch_size, + stream); + } } else { thrust::fill(raft::resource::get_thrust_policy(handle), batch_weights.data_handle(), @@ -373,8 +496,6 @@ void fit(raft::resources const& handle, auto minClusterAndDistance_view = raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); - auto L2NormBatch_view = raft::make_device_vector_view( - L2NormBatch.data_handle(), current_batch_size); // Compute L2 norms for batch if needed if (metric == cuvs::distance::DistanceType::L2Expanded || @@ -410,6 +531,18 @@ void fit(raft::resources const& handle, raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); + if (use_minibatch) { + // Mini-batch mode: zero batch accumulators before each batch + thrust::fill(raft::resource::get_thrust_policy(handle), + centroid_sums.data_handle(), + centroid_sums.data_handle() + centroid_sums.size(), + DataT{0}); + thrust::fill(raft::resource::get_thrust_policy(handle), + cluster_counts.data_handle(), + cluster_counts.data_handle() + cluster_counts.size(), + DataT{0}); + } + accumulate_batch_centroids(handle, batch_data_view, minClusterAndDistance_const, @@ -418,6 +551,17 @@ void fit(raft::resources const& handle, cluster_counts.view(), workspace); + if (use_minibatch) { + // Mini-batch mode: update centroids immediately after each batch + auto centroid_sums_const = raft::make_device_matrix_view( + centroid_sums.data_handle(), n_clusters, n_features); + auto cluster_counts_const = raft::make_device_vector_view( + cluster_counts.data_handle(), n_clusters); + + minibatch_update_centroids( + handle, centroids, centroid_sums_const, cluster_counts_const, total_counts.view()); + } + // Accumulate cluster cost if checking convergence if (params.inertia_check) { cuvs::cluster::kmeans::detail::computeClusterCost( @@ -432,32 +576,31 @@ void fit(raft::resources const& handle, } } // end batch loop - // Finalize centroids: divide sums by counts - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), n_clusters, n_features); - auto centroid_sums_const = raft::make_device_matrix_view( - centroid_sums.data_handle(), n_clusters, n_features); - auto cluster_counts_const = raft::make_device_vector_view( - cluster_counts.data_handle(), n_clusters); + if (!use_minibatch) { + // Full-batch mode: finalize centroids after processing all batches + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features); + auto centroid_sums_const = raft::make_device_matrix_view( + centroid_sums.data_handle(), n_clusters, n_features); + auto cluster_counts_const = raft::make_device_vector_view( + cluster_counts.data_handle(), n_clusters); - finalize_centroids( - handle, centroid_sums_const, cluster_counts_const, centroids_const, new_centroids.view()); + finalize_centroids( + handle, centroid_sums_const, cluster_counts_const, centroids_const, centroids); + } - // Compute squared norm of change in centroids + // Compute squared norm of change in centroids (compare to saved old centroids) auto sqrdNorm = raft::make_device_scalar(handle, DataT{0}); raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), - new_centroids.size(), + centroids.size(), raft::sqdiff_op{}, stream, - centroids.data_handle(), - new_centroids.data_handle()); + new_centroids.data_handle(), // old centroids + centroids.data_handle()); // new centroids DataT sqrdNormError = 0; raft::copy(&sqrdNormError, sqrdNorm.data_handle(), 1, stream); - // Update centroids - raft::copy(centroids.data_handle(), new_centroids.data_handle(), new_centroids.size(), stream); - // Check convergence bool done = false; if (params.inertia_check) { diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd index 9f16d46c4d..d219f4e903 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -18,6 +18,10 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: Random Array + ctypedef enum cuvsKMeansCentroidUpdateMode: + CUVS_KMEANS_UPDATE_FULL_BATCH + CUVS_KMEANS_UPDATE_MINI_BATCH + ctypedef enum cuvsKMeansType: CUVS_KMEANS_TYPE_KMEANS CUVS_KMEANS_TYPE_KMEANS_BALANCED @@ -32,6 +36,7 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: double oversampling_factor, int batch_samples, int batch_centroids, + cuvsKMeansCentroidUpdateMode update_mode, bool inertia_check, bool hierarchical, int hierarchical_n_iters diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index 489d983ac7..b8ee467fb5 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -44,6 +44,12 @@ INIT_METHOD_TYPES = { INIT_METHOD_NAMES = {v: k for k, v in INIT_METHOD_TYPES.items()} +UPDATE_MODE_TYPES = { + "full_batch": cuvsKMeansCentroidUpdateMode.CUVS_KMEANS_UPDATE_FULL_BATCH, + "mini_batch": cuvsKMeansCentroidUpdateMode.CUVS_KMEANS_UPDATE_MINI_BATCH} + +UPDATE_MODE_NAMES = {v: k for k, v in UPDATE_MODE_TYPES.items()} + cdef class KMeansParams: """ Hyper-parameters for the kmeans algorithm @@ -70,6 +76,20 @@ cdef class KMeansParams: Number of instance k-means algorithm will be run with different seeds oversampling_factor : double Oversampling factor for use in the k-means|| algorithm + batch_samples : int + Number of samples to process in each batch for tiled 1NN computation. + Useful to optimize/control memory footprint. Default tile is + [batch_samples x n_clusters]. + batch_centroids : int + Number of centroids to process in each batch. If 0, uses n_clusters. + update_mode : str + Centroid update strategy. One of: + "full_batch" : Standard Lloyd's algorithm - accumulate assignments over + the entire dataset, then update centroids once per iteration. + More accurate but requires full pass over data before each update. + "mini_batch" : Mini-batch k-means - update centroids after each batch. + Faster convergence for large datasets, but may have slightly lower + accuracy. Uses online centroid updates with learning rate decay. hierarchical : bool Whether to use hierarchical (balanced) kmeans or not hierarchical_n_iters : int @@ -92,6 +112,9 @@ cdef class KMeansParams: tol=None, n_init=None, oversampling_factor=None, + batch_samples=None, + batch_centroids=None, + update_mode=None, hierarchical=None, hierarchical_n_iters=None): if metric is not None: @@ -109,6 +132,13 @@ cdef class KMeansParams: self.params.n_init = n_init if oversampling_factor is not None: self.params.oversampling_factor = oversampling_factor + if batch_samples is not None: + self.params.batch_samples = batch_samples + if batch_centroids is not None: + self.params.batch_centroids = batch_centroids + if update_mode is not None: + c_mode = UPDATE_MODE_TYPES[update_mode] + self.params.update_mode = c_mode if hierarchical is not None: self.params.hierarchical = hierarchical if hierarchical_n_iters is not None: @@ -145,6 +175,18 @@ cdef class KMeansParams: def oversampling_factor(self): return self.params.oversampling_factor + @property + def batch_samples(self): + return self.params.batch_samples + + @property + def batch_centroids(self): + return self.params.batch_centroids + + @property + def update_mode(self): + return UPDATE_MODE_NAMES[self.params.update_mode] + @property def hierarchical(self): return self.params.hierarchical From 5d4b4985b645640cbcfed514687d53ab57978216 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 3 Feb 2026 02:42:28 -0800 Subject: [PATCH 06/81] fix docs --- c/include/cuvs/cluster/kmeans.h | 3 --- cpp/include/cuvs/cluster/kmeans.hpp | 3 --- python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 3 --- 3 files changed, 9 deletions(-) diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index 79448af26f..e39b9c8523 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -43,14 +43,11 @@ typedef enum { /** * Standard k-means (Lloyd's algorithm): accumulate assignments over the * entire dataset, then update centroids once per iteration. - * More accurate but requires full pass over data before each update. */ CUVS_KMEANS_UPDATE_FULL_BATCH = 0, /** * Mini-batch k-means: update centroids after each randomly sampled batch. - * Faster convergence for large datasets, but may have slightly lower accuracy. - * Uses streaming/online centroid updates with learning rate decay. */ CUVS_KMEANS_UPDATE_MINI_BATCH = 1 } cuvsKMeansCentroidUpdateMode; diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 4f702d74ca..7ce492b136 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -55,14 +55,11 @@ struct params : base_params { /** * Standard k-means (Lloyd's algorithm): accumulate assignments over the * entire dataset, then update centroids once per iteration. - * More accurate but requires full pass over data before each update. */ FullBatch, /** * Mini-batch k-means: update centroids after each randomly sampled batch. - * Faster convergence for large datasets, but may have slightly lower accuracy. - * Uses streaming/online centroid updates with learning rate decay. */ MiniBatch }; diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index b8ee467fb5..7b0cb4b3a2 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -86,10 +86,7 @@ cdef class KMeansParams: Centroid update strategy. One of: "full_batch" : Standard Lloyd's algorithm - accumulate assignments over the entire dataset, then update centroids once per iteration. - More accurate but requires full pass over data before each update. "mini_batch" : Mini-batch k-means - update centroids after each batch. - Faster convergence for large datasets, but may have slightly lower - accuracy. Uses online centroid updates with learning rate decay. hierarchical : bool Whether to use hierarchical (balanced) kmeans or not hierarchical_n_iters : int From 72fe7892609bf6faa279381aa901034d163dcb4e Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 3 Feb 2026 03:36:16 -0800 Subject: [PATCH 07/81] replace thrust calls: --- cpp/src/cluster/detail/kmeans_batched.cuh | 54 +++++++---------------- 1 file changed, 16 insertions(+), 38 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index c622c9b5a3..4db84b12c0 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -140,27 +140,24 @@ void accumulate_batch_centroids( raft::device_vector_view, IdxT> minClusterAndDistance, raft::device_vector_view sample_weights, raft::device_matrix_view centroid_sums, - raft::device_vector_view cluster_counts, - rmm::device_uvector& workspace) + raft::device_vector_view cluster_counts) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = batch_data.extent(0); auto n_features = batch_data.extent(1); auto n_clusters = centroid_sums.extent(0); + // Get workspace from handle + auto* workspace_resource = raft::resource::get_workspace_resource(handle); + auto workspace = rmm::device_uvector(n_samples, stream, workspace_resource); + // Temporary buffers for this batch's partial results auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); auto batch_counts = raft::make_device_vector(handle, n_clusters); // Zero the batch temporaries - thrust::fill(raft::resource::get_thrust_policy(handle), - batch_sums.data_handle(), - batch_sums.data_handle() + batch_sums.size(), - DataT{0}); - thrust::fill(raft::resource::get_thrust_policy(handle), - batch_counts.data_handle(), - batch_counts.data_handle() + batch_counts.size(), - DataT{0}); + raft::matrix::fill(handle, batch_sums.view(), DataT{0}); + raft::matrix::fill(handle, batch_counts.view(), DataT{0}); // Extract cluster labels from KeyValuePair cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; @@ -168,8 +165,6 @@ void accumulate_batch_centroids( const raft::KeyValuePair*> labels_itr(minClusterAndDistance.data_handle(), conversion_op); - workspace.resize(n_samples, stream); - // Compute weighted sum of samples per cluster for this batch raft::linalg::reduce_rows_by_key(const_cast(batch_data.data_handle()), batch_data.extent(1), @@ -422,20 +417,11 @@ void fit(raft::resources const& handle, // For full-batch mode: zero accumulators at start of each iteration // For mini-batch mode: zero total_counts at start of each iteration if (!use_minibatch) { - thrust::fill(raft::resource::get_thrust_policy(handle), - centroid_sums.data_handle(), - centroid_sums.data_handle() + centroid_sums.size(), - DataT{0}); - thrust::fill(raft::resource::get_thrust_policy(handle), - cluster_counts.data_handle(), - cluster_counts.data_handle() + cluster_counts.size(), - DataT{0}); + raft::matrix::fill(handle, centroid_sums.view(), DataT{0}); + raft::matrix::fill(handle, cluster_counts.view(), DataT{0}); } else { // Mini-batch mode: zero total counts for learning rate calculation - thrust::fill(raft::resource::get_thrust_policy(handle), - total_counts.data_handle(), - total_counts.data_handle() + total_counts.size(), - DataT{0}); + raft::matrix::fill(handle, total_counts.view(), DataT{0}); // Shuffle sample indices for random batch selection std::shuffle(sample_indices.begin(), sample_indices.end(), rng); } @@ -482,10 +468,9 @@ void fit(raft::resources const& handle, stream); } } else { - thrust::fill(raft::resource::get_thrust_policy(handle), - batch_weights.data_handle(), - batch_weights.data_handle() + current_batch_size, - DataT{1}); + auto batch_weights_fill_view = raft::make_device_vector_view( + batch_weights.data_handle(), current_batch_size); + raft::matrix::fill(handle, batch_weights_fill_view, DataT{1}); } // Create views for current batch size @@ -533,14 +518,8 @@ void fit(raft::resources const& handle, if (use_minibatch) { // Mini-batch mode: zero batch accumulators before each batch - thrust::fill(raft::resource::get_thrust_policy(handle), - centroid_sums.data_handle(), - centroid_sums.data_handle() + centroid_sums.size(), - DataT{0}); - thrust::fill(raft::resource::get_thrust_policy(handle), - cluster_counts.data_handle(), - cluster_counts.data_handle() + cluster_counts.size(), - DataT{0}); + raft::matrix::fill(handle, centroid_sums.view(), DataT{0}); + raft::matrix::fill(handle, cluster_counts.view(), DataT{0}); } accumulate_batch_centroids(handle, @@ -548,8 +527,7 @@ void fit(raft::resources const& handle, minClusterAndDistance_const, batch_weights_view, centroid_sums.view(), - cluster_counts.view(), - workspace); + cluster_counts.view()); if (use_minibatch) { // Mini-batch mode: update centroids immediately after each batch From 526ac04ccc9fd831ff4d5446ea32c5265e29b11d Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 9 Feb 2026 01:50:44 -0800 Subject: [PATCH 08/81] common function in helper --- cpp/src/cluster/detail/kmeans.cuh | 35 ++---- cpp/src/cluster/detail/kmeans_batched.cuh | 136 +++++++++------------- cpp/src/cluster/detail/kmeans_common.cuh | 59 ++++++++++ 3 files changed, 127 insertions(+), 103 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 0a27d3b351..b44d44c570 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -280,31 +280,16 @@ void update_centroids(raft::resources const& handle, rmm::device_uvector& workspace) { auto n_clusters = centroids.extent(0); - auto n_samples = X.extent(0); - - workspace.resize(n_samples, raft::resource::get_cuda_stream(handle)); - - // Calculates weighted sum of all the samples assigned to cluster-i and stores the - // result in new_centroids[i] - raft::linalg::reduce_rows_by_key((DataT*)X.data_handle(), - X.extent(1), - cluster_labels, - sample_weights.data_handle(), - workspace.data(), - X.extent(0), - X.extent(1), - n_clusters, - new_centroids.data_handle(), - raft::resource::get_cuda_stream(handle)); - - // Reduce weights by key to compute weight in each cluster - raft::linalg::reduce_cols_by_key(sample_weights.data_handle(), - cluster_labels, - weight_per_cluster.data_handle(), - (IndexT)1, - (IndexT)sample_weights.extent(0), - (IndexT)n_clusters, - raft::resource::get_cuda_stream(handle)); + + // Compute weighted sums and counts per cluster + cuvs::cluster::kmeans::detail::compute_centroid_adjustments(handle, + X, + sample_weights, + cluster_labels, + static_cast(n_clusters), + new_centroids, + weight_per_cluster, + workspace); // Computes new_centroids[i] = new_centroids[i]/weight_per_cluster[i] where // new_centroids[n_clusters x n_features] - 2D array, new_centroids[i] has sum of all the diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 4db84b12c0..7eae580fda 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -56,10 +57,10 @@ void prepare_init_sample(raft::resources const& handle, raft::device_matrix_view X_sample, uint64_t seed) { - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_samples_out = X_sample.extent(0); + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_samples_out = X_sample.extent(0); std::mt19937 gen(seed); std::vector indices(n_samples); @@ -95,9 +96,9 @@ void init_centroids_from_host_sample(raft::resources const& handle, auto n_clusters = params.n_clusters; // Sample size for initialization: at least 3 * n_clusters, but not more than n_samples - size_t init_sample_size = std::min(static_cast(n_samples), - std::max(static_cast(3 * n_clusters), - static_cast(10000))); + size_t init_sample_size = + std::min(static_cast(n_samples), + std::max(static_cast(3 * n_clusters), static_cast(10000))); RAFT_LOG_DEBUG("KMeans batched: sampling %zu points for initialization", init_sample_size); @@ -116,10 +117,7 @@ void init_centroids_from_host_sample(raft::resources const& handle, workspace); } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { // Just use the first n_clusters samples - raft::copy(centroids.data_handle(), - init_sample.data_handle(), - n_clusters * n_features, - stream); + raft::copy(centroids.data_handle(), init_sample.data_handle(), n_clusters * n_features, stream); } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { // Centroids already provided, nothing to do } else { @@ -143,48 +141,32 @@ void accumulate_batch_centroids( raft::device_vector_view cluster_counts) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = batch_data.extent(0); auto n_features = batch_data.extent(1); auto n_clusters = centroid_sums.extent(0); // Get workspace from handle - auto* workspace_resource = raft::resource::get_workspace_resource(handle); - auto workspace = rmm::device_uvector(n_samples, stream, workspace_resource); + auto workspace = rmm::device_uvector( + batch_data.extent(0), stream, raft::resource::get_workspace_resource(handle)); // Temporary buffers for this batch's partial results auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); auto batch_counts = raft::make_device_vector(handle, n_clusters); - // Zero the batch temporaries - raft::matrix::fill(handle, batch_sums.view(), DataT{0}); - raft::matrix::fill(handle, batch_counts.view(), DataT{0}); - // Extract cluster labels from KeyValuePair cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; thrust::transform_iterator, const raft::KeyValuePair*> labels_itr(minClusterAndDistance.data_handle(), conversion_op); - // Compute weighted sum of samples per cluster for this batch - raft::linalg::reduce_rows_by_key(const_cast(batch_data.data_handle()), - batch_data.extent(1), - labels_itr, - sample_weights.data_handle(), - workspace.data(), - batch_data.extent(0), - batch_data.extent(1), - n_clusters, - batch_sums.data_handle(), - stream); - - // Compute sum of weights per cluster for this batch - raft::linalg::reduce_cols_by_key(sample_weights.data_handle(), - labels_itr, - batch_counts.data_handle(), - static_cast(1), - static_cast(n_samples), - static_cast(n_clusters), - stream); + // Compute weighted sums and counts per cluster for this batch + cuvs::cluster::kmeans::detail::compute_centroid_adjustments(handle, + batch_data, + sample_weights, + labels_itr, + static_cast(n_clusters), + batch_sums.view(), + batch_counts.view(), + workspace); // Add batch results to running accumulators raft::linalg::add(centroid_sums.data_handle(), @@ -234,10 +216,7 @@ void minibatch_update_centroids(raft::resources const& handle, raft::div_checkzero_op{}); // Step 1: Update total_counts = total_counts + batch_counts - raft::linalg::add(handle, - raft::make_const_mdspan(total_counts), - batch_counts, - total_counts); + raft::linalg::add(handle, raft::make_const_mdspan(total_counts), batch_counts, total_counts); // Step 2: Compute learning rates: lr = batch_count / total_count (after update) auto learning_rates = raft::make_device_vector(handle, n_clusters); @@ -292,12 +271,11 @@ void finalize_centroids(raft::resources const& handle, // Divide by counts: new_centroids[i] = centroid_sums[i] / cluster_counts[i] // When count is 0, set to 0 (will be fixed below) - raft::linalg::matrix_vector_op( - handle, - raft::make_const_mdspan(new_centroids), - cluster_counts, - new_centroids, - raft::div_checkzero_op{}); + raft::linalg::matrix_vector_op(handle, + raft::make_const_mdspan(new_centroids), + cluster_counts, + new_centroids, + raft::div_checkzero_op{}); // Copy old centroids to new centroids where cluster_counts[i] == 0 cub::ArgIndexInputIterator itr_wt(cluster_counts.data_handle()); @@ -351,17 +329,15 @@ void fit(raft::resources const& handle, RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); RAFT_EXPECTS(static_cast(centroids.extent(0)) == n_clusters, "centroids.extent(0) must equal n_clusters"); - RAFT_EXPECTS(centroids.extent(1) == n_features, - "centroids.extent(1) must equal n_features"); + RAFT_EXPECTS(centroids.extent(1) == n_features, "centroids.extent(1) must equal n_features"); raft::default_logger().set_level(params.verbosity); - RAFT_LOG_DEBUG( - "KMeans batched fit: n_samples=%zu, n_features=%zu, n_clusters=%d, batch_size=%zu", - static_cast(n_samples), - static_cast(n_features), - n_clusters, - static_cast(batch_size)); + RAFT_LOG_DEBUG("KMeans batched fit: n_samples=%zu, n_features=%zu, n_clusters=%d, batch_size=%zu", + static_cast(n_samples), + static_cast(n_features), + n_clusters, + static_cast(batch_size)); rmm::device_uvector workspace(0, stream); @@ -384,9 +360,9 @@ void fit(raft::resources const& handle, rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); // Accumulators for centroid computation - auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); - auto cluster_counts = raft::make_device_vector(handle, n_clusters); - auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); + auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto cluster_counts = raft::make_device_vector(handle, n_clusters); + auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); // For mini-batch mode: track total counts for learning rate calculation auto total_counts = raft::make_device_vector(handle, n_clusters); @@ -444,8 +420,10 @@ void fit(raft::resources const& handle, X.data_handle() + sample_idx * n_features, n_features * sizeof(DataT)); } - raft::copy( - batch_data.data_handle(), host_batch_buffer.data(), current_batch_size * n_features, stream); + raft::copy(batch_data.data_handle(), + host_batch_buffer.data(), + current_batch_size * n_features, + stream); } else { // Full-batch: sequential access raft::copy(batch_data.data_handle(), @@ -460,7 +438,8 @@ void fit(raft::resources const& handle, for (IdxT i = 0; i < current_batch_size; ++i) { host_weight_buffer[i] = sample_weight->data_handle()[sample_indices[batch_idx + i]]; } - raft::copy(batch_weights.data_handle(), host_weight_buffer.data(), current_batch_size, stream); + raft::copy( + batch_weights.data_handle(), host_weight_buffer.data(), current_batch_size, stream); } else { raft::copy(batch_weights.data_handle(), sample_weight->data_handle() + batch_idx, @@ -485,12 +464,11 @@ void fit(raft::resources const& handle, // Compute L2 norms for batch if needed if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormBatch.data_handle(), - batch_data.data_handle(), - n_features, - current_batch_size, - stream); + raft::linalg::rowNorm(L2NormBatch.data_handle(), + batch_data.data_handle(), + n_features, + current_batch_size, + stream); } // Find nearest centroid for each sample in batch @@ -523,11 +501,11 @@ void fit(raft::resources const& handle, } accumulate_batch_centroids(handle, - batch_data_view, - minClusterAndDistance_const, - batch_weights_view, - centroid_sums.view(), - cluster_counts.view()); + batch_data_view, + minClusterAndDistance_const, + batch_weights_view, + centroid_sums.view(), + cluster_counts.view()); if (use_minibatch) { // Mini-batch mode: update centroids immediately after each batch @@ -560,8 +538,8 @@ void fit(raft::resources const& handle, centroids.data_handle(), n_clusters, n_features); auto centroid_sums_const = raft::make_device_matrix_view( centroid_sums.data_handle(), n_clusters, n_features); - auto cluster_counts_const = raft::make_device_vector_view( - cluster_counts.data_handle(), n_clusters); + auto cluster_counts_const = + raft::make_device_vector_view(cluster_counts.data_handle(), n_clusters); finalize_centroids( handle, centroid_sums_const, cluster_counts_const, centroids_const, centroids); @@ -616,8 +594,11 @@ void fit(raft::resources const& handle, if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormBatch.data_handle(), batch_data.data_handle(), n_features, current_batch_size, stream); + raft::linalg::rowNorm(L2NormBatch.data_handle(), + batch_data.data_handle(), + n_features, + current_batch_size, + stream); } auto centroids_const = raft::make_device_matrix_view( @@ -652,4 +633,3 @@ void fit(raft::resources const& handle, } } // namespace cuvs::cluster::kmeans::batched::detail - diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index ea93b764f9..62065eae7b 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -469,4 +470,62 @@ void countSamplesInCluster(raft::resources const& handle, (IndexT)n_clusters, workspace); } + +/** + * @brief Compute centroid adjustments (weighted sums and counts per cluster) + * + * This helper function computes: + * 1. Weighted sum of samples per cluster using reduce_rows_by_key + * 2. Sum of weights per cluster using reduce_cols_by_key + * + * @tparam DataT Data type for samples and weights + * @tparam IndexT Index type + * @tparam LabelsIterator Iterator type for cluster labels + * + * @param[in] handle RAFT resources handle + * @param[in] X Input samples [n_samples x n_features] + * @param[in] sample_weights Weights for each sample [n_samples] + * @param[in] cluster_labels Cluster assignment for each sample (iterator) + * @param[in] n_clusters Number of clusters + * @param[out] centroid_sums Output weighted sum per cluster [n_clusters x n_features] + * @param[out] weight_per_cluster Output sum of weights per cluster [n_clusters] + * @param[inout] workspace Workspace buffer for intermediate operations + */ +template +void compute_centroid_adjustments( + raft::resources const& handle, + raft::device_matrix_view X, + raft::device_vector_view sample_weights, + LabelsIterator cluster_labels, + IndexT n_clusters, + raft::device_matrix_view centroid_sums, + raft::device_vector_view weight_per_cluster, + rmm::device_uvector& workspace) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + + workspace.resize(n_samples, stream); + + // Compute weighted sum of samples per cluster + raft::linalg::reduce_rows_by_key(const_cast(X.data_handle()), + X.extent(1), + cluster_labels, + sample_weights.data_handle(), + workspace.data(), + X.extent(0), + X.extent(1), + n_clusters, + centroid_sums.data_handle(), + stream); + + // Compute sum of weights per cluster + raft::linalg::reduce_cols_by_key(sample_weights.data_handle(), + cluster_labels, + weight_per_cluster.data_handle(), + static_cast(1), + static_cast(n_samples), + n_clusters, + stream); +} } // namespace cuvs::cluster::kmeans::detail From 9b6f1ef5a8c20602484e5025be85386ddfee62d5 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 13 Feb 2026 13:42:45 -0800 Subject: [PATCH 09/81] fix templates --- c/include/cuvs/cluster/kmeans.h | 33 ++ cpp/CMakeLists.txt | 1 + cpp/include/cuvs/cluster/kmeans.hpp | 69 ++++ cpp/src/cluster/detail/kmeans_batched.cuh | 314 +++++++++++------- .../kmeans_fit_batched_int8_uint8_half.cu | 57 ++++ cpp/src/cluster/kmeans_fit_double.cu | 23 +- cpp/src/cluster/kmeans_fit_float.cu | 23 +- python/cuvs/cuvs/cluster/kmeans/__init__.py | 6 +- python/cuvs/cuvs/cluster/kmeans/kmeans.pxd | 11 +- python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 118 ++++++- 10 files changed, 489 insertions(+), 166 deletions(-) create mode 100644 cpp/src/cluster/kmeans_fit_batched_int8_uint8_half.cu diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index e39b9c8523..99e7bfb5d3 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -235,6 +235,39 @@ cuvsError_t cuvsKMeansClusterCost(cuvsResources_t res, DLManagedTensor* X, DLManagedTensor* centroids, double* cost); + +/** + * @brief Find clusters with k-means algorithm using batched processing. + * + * This function processes data from HOST memory in batches, streaming + * to the GPU. Useful when the dataset is too large to fit in GPU memory. + * + * @param[in] res opaque C handle + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances on HOST memory. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch. + * @param[in] sample_weight Optional weights for each observation in X (on host). + * [len = n_samples] + * @param[inout] centroids [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers. + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. Must be on DEVICE memory. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. + */ +cuvsError_t cuvsKMeansFitBatched(cuvsResources_t res, + cuvsKMeansParams_t params, + DLManagedTensor* X, + int64_t batch_size, + DLManagedTensor* sample_weight, + DLManagedTensor* centroids, + double* inertia, + int64_t* n_iter); /** * @} */ diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 6313db71ca..fe27415948 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -367,6 +367,7 @@ if(NOT BUILD_CPU_ONLY) src/cluster/kmeans_balanced_fit_predict_int8.cu src/cluster/kmeans_balanced_predict_int8.cu src/cluster/kmeans_balanced_predict_uint8.cu + src/cluster/kmeans_fit_batched_int8_uint8_half.cu src/cluster/kmeans_transform_double.cu src/cluster/kmeans_transform_float.cu src/cluster/single_linkage_float.cu diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 7ce492b136..50b25dbea2 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -299,6 +299,75 @@ void fit_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter); +/** + * @brief Find clusters with k-means algorithm using batched processing for uint8 data. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances on HOST memory (uint8). + * [dim = n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch. + * @param[in] sample_weight Optional weights for each observation in X (on host, float). + * @param[inout] centroids Cluster centers on device (float, as centroids are averages). + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances to nearest centroid. + * @param[out] n_iter Number of iterations run. + */ +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Find clusters with k-means algorithm using batched processing for int8 data. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances on HOST memory (int8). + * [dim = n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch. + * @param[in] sample_weight Optional weights for each observation in X (on host, float). + * @param[inout] centroids Cluster centers on device (float, as centroids are averages). + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances to nearest centroid. + * @param[out] n_iter Number of iterations run. + */ +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Find clusters with k-means algorithm using batched processing for half data. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances on HOST memory (half). + * [dim = n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch. + * @param[in] sample_weight Optional weights for each observation in X (on host, float). + * @param[inout] centroids Cluster centers on device (float, as centroids are averages). + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances to nearest centroid. + * @param[out] n_iter Number of iterations run. + */ +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + /** * @} */ diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 7eae580fda..a9fea3882a 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -41,21 +42,25 @@ #include #include #include +#include #include namespace cuvs::cluster::kmeans::batched::detail { /** - * @brief Sample data from host to device for initialization + * @brief Sample data from host to device for initialization, with optional type conversion * - * Samples `n_samples_to_gather` rows from host data and copies to device. - * Uses uniform strided sampling for simplicity and cache-friendliness. + * @tparam T Input data type + * @tparam MathT Computation/output type + * @tparam IdxT Index type + * @tparam MappingOpT Mapping operator (T -> MathT) */ -template +template void prepare_init_sample(raft::resources const& handle, - raft::host_matrix_view X, - raft::device_matrix_view X_sample, - uint64_t seed) + raft::host_matrix_view X, + raft::device_matrix_view X_sample, + uint64_t seed, + MappingOpT mapping_op) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); @@ -67,28 +72,44 @@ void prepare_init_sample(raft::resources const& handle, std::iota(indices.begin(), indices.end(), 0); std::shuffle(indices.begin(), indices.end(), gen); - std::vector host_sample(n_samples_out * n_features); + // Sample raw T data to host buffer + std::vector host_sample(n_samples_out * n_features); #pragma omp parallel for for (IdxT i = 0; i < static_cast(n_samples_out); i++) { IdxT src_idx = indices[i]; std::memcpy(host_sample.data() + i * n_features, X.data_handle() + src_idx * n_features, - n_features * sizeof(DataT)); + n_features * sizeof(T)); } - raft::copy(X_sample.data_handle(), host_sample.data(), host_sample.size(), stream); + if constexpr (std::is_same_v) { + // Same type: direct copy + raft::copy(X_sample.data_handle(), host_sample.data(), host_sample.size(), stream); + } else { + // Different types: copy raw, then convert on GPU + auto raw_sample = raft::make_device_matrix(handle, n_samples_out, n_features); + raft::copy(raw_sample.data_handle(), host_sample.data(), host_sample.size(), stream); + raft::linalg::unaryOp( + X_sample.data_handle(), raw_sample.data_handle(), host_sample.size(), mapping_op, stream); + } } /** * @brief Initialize centroids using k-means++ on a sample of the host data + * + * @tparam T Input data type + * @tparam MathT Computation/centroid type + * @tparam IdxT Index type + * @tparam MappingOpT Mapping operator (T -> MathT) */ -template +template void init_centroids_from_host_sample(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace) + raft::host_matrix_view X, + raft::device_matrix_view centroids, + rmm::device_uvector& workspace, + MappingOpT mapping_op) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); @@ -102,16 +123,16 @@ void init_centroids_from_host_sample(raft::resources const& handle, RAFT_LOG_DEBUG("KMeans batched: sampling %zu points for initialization", init_sample_size); - // Sample data from host to device - auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); - prepare_init_sample(handle, X, init_sample.view(), params.rng_state.seed); + // Sample data from host to device (with conversion if needed) + auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); + prepare_init_sample(handle, X, init_sample.view(), params.rng_state.seed, mapping_op); // Run k-means++ on the sample if (params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { - cuvs::cluster::kmeans::detail::kmeansPlusPlus( + cuvs::cluster::kmeans::detail::kmeansPlusPlus( handle, params, - raft::make_device_matrix_view( + raft::make_device_matrix_view( init_sample.data_handle(), init_sample_size, n_features), centroids, workspace); @@ -131,14 +152,14 @@ void init_centroids_from_host_sample(raft::resources const& handle, * This function adds the partial sums from a batch to the running accumulators. * It does NOT divide - that happens once at the end of all batches. */ -template +template void accumulate_batch_centroids( raft::resources const& handle, - raft::device_matrix_view batch_data, - raft::device_vector_view, IdxT> minClusterAndDistance, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroid_sums, - raft::device_vector_view cluster_counts) + raft::device_matrix_view batch_data, + raft::device_vector_view, IdxT> minClusterAndDistance, + raft::device_vector_view sample_weights, + raft::device_matrix_view centroid_sums, + raft::device_vector_view cluster_counts) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_features = batch_data.extent(1); @@ -149,13 +170,13 @@ void accumulate_batch_centroids( batch_data.extent(0), stream, raft::resource::get_workspace_resource(handle)); // Temporary buffers for this batch's partial results - auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); - auto batch_counts = raft::make_device_vector(handle, n_clusters); + auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto batch_counts = raft::make_device_vector(handle, n_clusters); // Extract cluster labels from KeyValuePair - cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; - thrust::transform_iterator, - const raft::KeyValuePair*> + cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; + thrust::transform_iterator, + const raft::KeyValuePair*> labels_itr(minClusterAndDistance.data_handle(), conversion_op); // Compute weighted sums and counts per cluster for this batch @@ -191,18 +212,18 @@ void accumulate_batch_centroids( * * This is equivalent to a weighted average where total_count tracks cumulative weight. */ -template +template void minibatch_update_centroids(raft::resources const& handle, - raft::device_matrix_view centroids, - raft::device_matrix_view batch_sums, - raft::device_vector_view batch_counts, - raft::device_vector_view total_counts) + raft::device_matrix_view centroids, + raft::device_matrix_view batch_sums, + raft::device_vector_view batch_counts, + raft::device_vector_view total_counts) { auto n_clusters = centroids.extent(0); auto n_features = centroids.extent(1); // Compute batch means: batch_mean = batch_sums / batch_counts - auto batch_means = raft::make_device_matrix(handle, n_clusters, n_features); + auto batch_means = raft::make_device_matrix(handle, n_clusters, n_features); raft::copy(batch_means.data_handle(), batch_sums.data_handle(), batch_sums.size(), @@ -219,7 +240,7 @@ void minibatch_update_centroids(raft::resources const& handle, raft::linalg::add(handle, raft::make_const_mdspan(total_counts), batch_counts, total_counts); // Step 2: Compute learning rates: lr = batch_count / total_count (after update) - auto learning_rates = raft::make_device_vector(handle, n_clusters); + auto learning_rates = raft::make_device_vector(handle, n_clusters); raft::linalg::map(handle, learning_rates.view(), raft::div_checkzero_op{}, @@ -234,7 +255,7 @@ void minibatch_update_centroids(raft::resources const& handle, raft::make_const_mdspan(centroids), raft::make_const_mdspan(learning_rates.view()), centroids, - [] __device__(DataT centroid_val, DataT lr) { return (DataT{1} - lr) * centroid_val; }); + [] __device__(MathT centroid_val, MathT lr) { return (MathT{1} - lr) * centroid_val; }); // Add lr * batch_mean to centroids raft::linalg::matrix_vector_op( @@ -242,7 +263,7 @@ void minibatch_update_centroids(raft::resources const& handle, raft::make_const_mdspan(batch_means.view()), raft::make_const_mdspan(learning_rates.view()), batch_means.view(), - [] __device__(DataT mean_val, DataT lr) { return lr * mean_val; }); + [] __device__(MathT mean_val, MathT lr) { return lr * mean_val; }); // centroids += lr * batch_means raft::linalg::add(handle, @@ -254,12 +275,12 @@ void minibatch_update_centroids(raft::resources const& handle, /** * @brief Finalize centroids by dividing accumulated sums by counts */ -template +template void finalize_centroids(raft::resources const& handle, - raft::device_matrix_view centroid_sums, - raft::device_vector_view cluster_counts, - raft::device_matrix_view old_centroids, - raft::device_matrix_view new_centroids) + raft::device_matrix_view centroid_sums, + raft::device_vector_view cluster_counts, + raft::device_matrix_view old_centroids, + raft::device_matrix_view new_centroids) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_clusters = new_centroids.extent(0); @@ -278,7 +299,7 @@ void finalize_centroids(raft::resources const& handle, raft::div_checkzero_op{}); // Copy old centroids to new centroids where cluster_counts[i] == 0 - cub::ArgIndexInputIterator itr_wt(cluster_counts.data_handle()); + cub::ArgIndexInputIterator itr_wt(cluster_counts.data_handle()); raft::matrix::gather_if( old_centroids.data_handle(), static_cast(old_centroids.extent(1)), @@ -287,8 +308,8 @@ void finalize_centroids(raft::resources const& handle, itr_wt, static_cast(cluster_counts.size()), new_centroids.data_handle(), - [=] __device__(raft::KeyValuePair map) { - return map.value == DataT{0}; // predicate: copy when count is 0 + [=] __device__(raft::KeyValuePair map) { + return map.value == MathT{0}; // predicate: copy when count is 0 }, raft::key_op{}, stream); @@ -297,27 +318,34 @@ void finalize_centroids(raft::resources const& handle, /** * @brief Main fit function for batched k-means with host data * - * @tparam DataT Data type (float, double) - * @tparam IdxT Index type (int, int64_t) + * This is a unified function that handles both same-type (T == MathT) and + * mixed-type (T != MathT) cases, following the kmeans_balanced pattern. + * + * @tparam T Input data type (float, double, uint8_t, int8_t, half) + * @tparam MathT Computation/centroid type (typically float) + * @tparam IdxT Index type (int, int64_t) + * @tparam MappingOpT Mapping operator (T -> MathT) * * @param[in] handle RAFT resources handle * @param[in] params K-means parameters - * @param[in] X Input data on HOST [n_samples x n_features] + * @param[in] X Input data on HOST [n_samples x n_features] * @param[in] batch_size Number of samples to process per batch - * @param[in] sample_weight Optional weights per sample (on host) + * @param[in] sample_weight Optional weights per sample (on host, MathT type) * @param[inout] centroids Initial/output cluster centers [n_clusters x n_features] * @param[out] inertia Sum of squared distances to nearest centroid * @param[out] n_iter Number of iterations run + * @param[in] mapping_op Mapping operator for T -> MathT conversion */ -template +template void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, + raft::host_matrix_view X, IdxT batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter, + MappingOpT mapping_op) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); @@ -343,37 +371,39 @@ void fit(raft::resources const& handle, // Initialize centroids from a sample of host data if (params.init != cuvs::cluster::kmeans::params::InitMethod::Array) { - init_centroids_from_host_sample(handle, params, X, centroids, workspace); + init_centroids_from_host_sample(handle, params, X, centroids, workspace, mapping_op); } // Allocate device buffers - // Batch buffer for data - auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); - // Batch buffer for weights - auto batch_weights = raft::make_device_vector(handle, batch_size); - // Cluster assignment for batch + // For mixed types, we need a raw buffer for T and a converted buffer for MathT + // For same types, we only need one buffer + rmm::device_uvector batch_data_raw(0, stream); + if constexpr (!std::is_same_v) { + batch_data_raw.resize(batch_size * n_features, stream); + } + + auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); + auto batch_weights = raft::make_device_vector(handle, batch_size); auto minClusterAndDistance = - raft::make_device_vector, IdxT>(handle, batch_size); - // L2 norms of batch data - auto L2NormBatch = raft::make_device_vector(handle, batch_size); - // Temporary buffer for distance computation - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + raft::make_device_vector, IdxT>(handle, batch_size); + auto L2NormBatch = raft::make_device_vector(handle, batch_size); + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); // Accumulators for centroid computation - auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); - auto cluster_counts = raft::make_device_vector(handle, n_clusters); - auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); + auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto cluster_counts = raft::make_device_vector(handle, n_clusters); + auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); // For mini-batch mode: track total counts for learning rate calculation - auto total_counts = raft::make_device_vector(handle, n_clusters); + auto total_counts = raft::make_device_vector(handle, n_clusters); - // Host buffer for batch data (pinned memory for faster H2D transfer) - std::vector host_batch_buffer(batch_size * n_features); - std::vector host_weight_buffer(batch_size); + // Host buffer for batch data + std::vector host_batch_buffer(batch_size * n_features); + std::vector host_weight_buffer(batch_size); // Cluster cost for convergence check - rmm::device_scalar clusterCostD(stream); - DataT priorClusteringCost = 0; + rmm::device_scalar clusterCostD(stream); + MathT priorClusteringCost = 0; // Check update mode bool use_minibatch = @@ -393,11 +423,11 @@ void fit(raft::resources const& handle, // For full-batch mode: zero accumulators at start of each iteration // For mini-batch mode: zero total_counts at start of each iteration if (!use_minibatch) { - raft::matrix::fill(handle, centroid_sums.view(), DataT{0}); - raft::matrix::fill(handle, cluster_counts.view(), DataT{0}); + raft::matrix::fill(handle, centroid_sums.view(), MathT{0}); + raft::matrix::fill(handle, cluster_counts.view(), MathT{0}); } else { // Mini-batch mode: zero total counts for learning rate calculation - raft::matrix::fill(handle, total_counts.view(), DataT{0}); + raft::matrix::fill(handle, total_counts.view(), MathT{0}); // Shuffle sample indices for random batch selection std::shuffle(sample_indices.begin(), sample_indices.end(), rng); } @@ -405,7 +435,7 @@ void fit(raft::resources const& handle, // Save old centroids for convergence check raft::copy(new_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); - DataT total_cost = 0; + MathT total_cost = 0; // Process all data in batches for (IdxT batch_idx = 0; batch_idx < n_samples; batch_idx += batch_size) { @@ -414,22 +444,48 @@ void fit(raft::resources const& handle, // Copy batch data from host to device if (use_minibatch) { // Mini-batch: use shuffled indices for random sampling +#pragma omp parallel for for (IdxT i = 0; i < current_batch_size; ++i) { IdxT sample_idx = sample_indices[batch_idx + i]; std::memcpy(host_batch_buffer.data() + i * n_features, X.data_handle() + sample_idx * n_features, - n_features * sizeof(DataT)); + n_features * sizeof(T)); + } + + if constexpr (std::is_same_v) { + raft::copy(batch_data.data_handle(), + host_batch_buffer.data(), + current_batch_size * n_features, + stream); + } else { + raft::copy(batch_data_raw.data(), + host_batch_buffer.data(), + current_batch_size * n_features, + stream); + raft::linalg::unaryOp(batch_data.data_handle(), + batch_data_raw.data(), + current_batch_size * n_features, + mapping_op, + stream); } - raft::copy(batch_data.data_handle(), - host_batch_buffer.data(), - current_batch_size * n_features, - stream); } else { // Full-batch: sequential access - raft::copy(batch_data.data_handle(), - X.data_handle() + batch_idx * n_features, - current_batch_size * n_features, - stream); + if constexpr (std::is_same_v) { + raft::copy(batch_data.data_handle(), + X.data_handle() + batch_idx * n_features, + current_batch_size * n_features, + stream); + } else { + raft::copy(batch_data_raw.data(), + X.data_handle() + batch_idx * n_features, + current_batch_size * n_features, + stream); + raft::linalg::unaryOp(batch_data.data_handle(), + batch_data_raw.data(), + current_batch_size * n_features, + mapping_op, + stream); + } } // Copy or set weights for this batch @@ -447,18 +503,18 @@ void fit(raft::resources const& handle, stream); } } else { - auto batch_weights_fill_view = raft::make_device_vector_view( + auto batch_weights_fill_view = raft::make_device_vector_view( batch_weights.data_handle(), current_batch_size); - raft::matrix::fill(handle, batch_weights_fill_view, DataT{1}); + raft::matrix::fill(handle, batch_weights_fill_view, MathT{1}); } // Create views for current batch size - auto batch_data_view = raft::make_device_matrix_view( + auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); - auto batch_weights_view = raft::make_device_vector_view( + auto batch_weights_view = raft::make_device_vector_view( batch_weights.data_handle(), current_batch_size); auto minClusterAndDistance_view = - raft::make_device_vector_view, IdxT>( + raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); // Compute L2 norms for batch if needed @@ -472,12 +528,12 @@ void fit(raft::resources const& handle, } // Find nearest centroid for each sample in batch - auto centroids_const = raft::make_device_matrix_view( + auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); - auto L2NormBatch_const = raft::make_device_vector_view( + auto L2NormBatch_const = raft::make_device_vector_view( L2NormBatch.data_handle(), current_batch_size); - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( handle, batch_data_view, centroids_const, @@ -491,16 +547,16 @@ void fit(raft::resources const& handle, // Accumulate partial sums for this batch auto minClusterAndDistance_const = - raft::make_device_vector_view, IdxT>( + raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); if (use_minibatch) { // Mini-batch mode: zero batch accumulators before each batch - raft::matrix::fill(handle, centroid_sums.view(), DataT{0}); - raft::matrix::fill(handle, cluster_counts.view(), DataT{0}); + raft::matrix::fill(handle, centroid_sums.view(), MathT{0}); + raft::matrix::fill(handle, cluster_counts.view(), MathT{0}); } - accumulate_batch_centroids(handle, + accumulate_batch_centroids(handle, batch_data_view, minClusterAndDistance_const, batch_weights_view, @@ -509,12 +565,12 @@ void fit(raft::resources const& handle, if (use_minibatch) { // Mini-batch mode: update centroids immediately after each batch - auto centroid_sums_const = raft::make_device_matrix_view( + auto centroid_sums_const = raft::make_device_matrix_view( centroid_sums.data_handle(), n_clusters, n_features); - auto cluster_counts_const = raft::make_device_vector_view( + auto cluster_counts_const = raft::make_device_vector_view( cluster_counts.data_handle(), n_clusters); - minibatch_update_centroids( + minibatch_update_centroids( handle, centroids, centroid_sums_const, cluster_counts_const, total_counts.view()); } @@ -527,26 +583,26 @@ void fit(raft::resources const& handle, raft::make_device_scalar_view(clusterCostD.data()), raft::value_op{}, raft::add_op{}); - DataT batch_cost = clusterCostD.value(stream); + MathT batch_cost = clusterCostD.value(stream); total_cost += batch_cost; } } // end batch loop if (!use_minibatch) { // Full-batch mode: finalize centroids after processing all batches - auto centroids_const = raft::make_device_matrix_view( + auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); - auto centroid_sums_const = raft::make_device_matrix_view( + auto centroid_sums_const = raft::make_device_matrix_view( centroid_sums.data_handle(), n_clusters, n_features); auto cluster_counts_const = - raft::make_device_vector_view(cluster_counts.data_handle(), n_clusters); + raft::make_device_vector_view(cluster_counts.data_handle(), n_clusters); - finalize_centroids( + finalize_centroids( handle, centroid_sums_const, cluster_counts_const, centroids_const, centroids); } // Compute squared norm of change in centroids (compare to saved old centroids) - auto sqrdNorm = raft::make_device_scalar(handle, DataT{0}); + auto sqrdNorm = raft::make_device_scalar(handle, MathT{0}); raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), centroids.size(), raft::sqdiff_op{}, @@ -554,14 +610,14 @@ void fit(raft::resources const& handle, new_centroids.data_handle(), // old centroids centroids.data_handle()); // new centroids - DataT sqrdNormError = 0; + MathT sqrdNormError = 0; raft::copy(&sqrdNormError, sqrdNorm.data_handle(), 1, stream); // Check convergence bool done = false; if (params.inertia_check) { if (n_iter[0] > 1) { - DataT delta = total_cost / priorClusteringCost; + MathT delta = total_cost / priorClusteringCost; if (delta > 1 - params.tol) done = true; } priorClusteringCost = total_cost; @@ -581,15 +637,27 @@ void fit(raft::resources const& handle, for (IdxT offset = 0; offset < n_samples; offset += batch_size) { IdxT current_batch_size = std::min(batch_size, n_samples - offset); - raft::copy(batch_data.data_handle(), - X.data_handle() + offset * n_features, - current_batch_size * n_features, - stream); + if constexpr (std::is_same_v) { + raft::copy(batch_data.data_handle(), + X.data_handle() + offset * n_features, + current_batch_size * n_features, + stream); + } else { + raft::copy(batch_data_raw.data(), + X.data_handle() + offset * n_features, + current_batch_size * n_features, + stream); + raft::linalg::unaryOp(batch_data.data_handle(), + batch_data_raw.data(), + current_batch_size * n_features, + mapping_op, + stream); + } - auto batch_data_view = raft::make_device_matrix_view( + auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); auto minClusterAndDistance_view = - raft::make_device_vector_view, IdxT>( + raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); if (metric == cuvs::distance::DistanceType::L2Expanded || @@ -601,12 +669,12 @@ void fit(raft::resources const& handle, stream); } - auto centroids_const = raft::make_device_matrix_view( + auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); - auto L2NormBatch_const = raft::make_device_vector_view( + auto L2NormBatch_const = raft::make_device_vector_view( L2NormBatch.data_handle(), current_batch_size); - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( handle, batch_data_view, centroids_const, @@ -629,7 +697,7 @@ void fit(raft::resources const& handle, inertia[0] += clusterCostD.value(stream); } - RAFT_LOG_DEBUG("KMeans batched: Completed with inertia=%f", inertia[0]); + RAFT_LOG_DEBUG("KMeans batched: Completed with inertia=%f", static_cast(inertia[0])); } } // namespace cuvs::cluster::kmeans::batched::detail diff --git a/cpp/src/cluster/kmeans_fit_batched_int8_uint8_half.cu b/cpp/src/cluster/kmeans_fit_batched_int8_uint8_half.cu new file mode 100644 index 0000000000..f4f5d90b68 --- /dev/null +++ b/cpp/src/cluster/kmeans_fit_batched_int8_uint8_half.cu @@ -0,0 +1,57 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "../neighbors/detail/ann_utils.cuh" +#include "detail/kmeans_batched.cuh" +#include + +#include + +namespace cuvs::cluster::kmeans { + +// Use the mapping struct from ann_utils for T -> float conversion +using cuvs::spatial::knn::detail::utils::mapping; + +// Public API implementations - X is T (uint8/int8/half) but centroids are float +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, mapping{}); +} + +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, mapping{}); +} + +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, mapping{}); +} + +} // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/cluster/kmeans_fit_double.cu index 0962c87890..fc0c9fd335 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/cluster/kmeans_fit_double.cu @@ -31,29 +31,14 @@ namespace cuvs::cluster::kmeans { raft::host_scalar_view inertia, \ raft::host_scalar_view n_iter); -#define INSTANTIATE_FIT_BATCHED(DataT, IndexT) \ - template void batched::detail::fit( \ - raft::resources const& handle, \ - const kmeans::params& params, \ - raft::host_matrix_view X, \ - IndexT batch_size, \ - std::optional> sample_weight, \ - raft::device_matrix_view centroids, \ - raft::host_scalar_view inertia, \ - raft::host_scalar_view n_iter); - INSTANTIATE_FIT_MAIN(double, int) INSTANTIATE_FIT_MAIN(double, int64_t) INSTANTIATE_FIT(double, int) INSTANTIATE_FIT(double, int64_t) -INSTANTIATE_FIT_BATCHED(double, int) -INSTANTIATE_FIT_BATCHED(double, int64_t) - #undef INSTANTIATE_FIT_MAIN #undef INSTANTIATE_FIT -#undef INSTANTIATE_FIT_BATCHED void fit_batched(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -64,8 +49,8 @@ void fit_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::batched::detail::fit( - handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, raft::identity_op{}); } void fit_batched(raft::resources const& handle, @@ -77,8 +62,8 @@ void fit_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::batched::detail::fit( - handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, raft::identity_op{}); } void fit(raft::resources const& handle, diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/cluster/kmeans_fit_float.cu index 39cc074d9e..50ee522b4a 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/cluster/kmeans_fit_float.cu @@ -31,29 +31,14 @@ namespace cuvs::cluster::kmeans { raft::host_scalar_view inertia, \ raft::host_scalar_view n_iter); -#define INSTANTIATE_FIT_BATCHED(DataT, IndexT) \ - template void batched::detail::fit( \ - raft::resources const& handle, \ - const kmeans::params& params, \ - raft::host_matrix_view X, \ - IndexT batch_size, \ - std::optional> sample_weight, \ - raft::device_matrix_view centroids, \ - raft::host_scalar_view inertia, \ - raft::host_scalar_view n_iter); - INSTANTIATE_FIT_MAIN(float, int) INSTANTIATE_FIT_MAIN(float, int64_t) INSTANTIATE_FIT(float, int) INSTANTIATE_FIT(float, int64_t) -INSTANTIATE_FIT_BATCHED(float, int) -INSTANTIATE_FIT_BATCHED(float, int64_t) - #undef INSTANTIATE_FIT_MAIN #undef INSTANTIATE_FIT -#undef INSTANTIATE_FIT_BATCHED void fit_batched(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -64,8 +49,8 @@ void fit_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::batched::detail::fit( - handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, raft::identity_op{}); } void fit_batched(raft::resources const& handle, @@ -77,8 +62,8 @@ void fit_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::batched::detail::fit( - handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, raft::identity_op{}); } void fit(raft::resources const& handle, diff --git a/python/cuvs/cuvs/cluster/kmeans/__init__.py b/python/cuvs/cuvs/cluster/kmeans/__init__.py index f4765bcb64..56c7645ebb 100644 --- a/python/cuvs/cuvs/cluster/kmeans/__init__.py +++ b/python/cuvs/cuvs/cluster/kmeans/__init__.py @@ -1,7 +1,7 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 -from .kmeans import KMeansParams, cluster_cost, fit, predict +from .kmeans import KMeansParams, cluster_cost, fit, fit_batched, predict -__all__ = ["KMeansParams", "cluster_cost", "fit", "predict"] +__all__ = ["KMeansParams", "cluster_cost", "fit", "fit_batched", "predict"] diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd index d219f4e903..0e20d6a709 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -4,7 +4,7 @@ # # cython: language_level=3 -from libc.stdint cimport uintptr_t +from libc.stdint cimport int64_t, uintptr_t from libcpp cimport bool from cuvs.common.c_api cimport cuvsError_t, cuvsResources_t @@ -68,3 +68,12 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: DLManagedTensor* X, DLManagedTensor* centroids, double* cost) + + cuvsError_t cuvsKMeansFitBatched(cuvsResources_t res, + cuvsKMeansParams_t params, + DLManagedTensor* X, + int64_t batch_size, + DLManagedTensor* sample_weight, + DLManagedTensor* centroids, + double* inertia, + int64_t* n_iter) except + diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index 7b0cb4b3a2..4eb6cf3c6b 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # # cython: language_level=3 @@ -34,6 +34,7 @@ from libc.stdint cimport ( uint64_t, uintptr_t, ) +from libc.stdlib cimport free, malloc from cuvs.common.exceptions import check_cuvs @@ -431,3 +432,118 @@ def cluster_cost(X, centroids, resources=None): &inertia)) return inertia + + +@auto_sync_resources +@auto_convert_output +def fit_batched( + KMeansParams params, X, batch_size, centroids=None, sample_weights=None, + resources=None +): + """ + Find clusters with the k-means algorithm using batched processing. + + This function processes data from HOST memory in batches, streaming + to the GPU. Useful when the dataset is too large to fit in GPU memory. + + Parameters + ---------- + + params : KMeansParams + Parameters to use to fit KMeans model + X : numpy array or array with __array_interface__ + Input HOST memory array shape (n_samples, n_features). + Must be C-contiguous. Supported dtypes: float32, float64, uint8, int8, float16. + batch_size : int + Number of samples to process per batch. Recommended: 500K-2M + depending on GPU memory. + centroids : Optional writable CUDA array interface compliant matrix + shape (n_clusters, n_features) + sample_weights : Optional input HOST memory array shape (n_samples,) + default: None + {resources_docstring} + + Returns + ------- + centroids : raft.device_ndarray + The computed centroids for each cluster (on device) + inertia : float + Sum of squared distances of samples to their closest cluster center + n_iter : int + The number of iterations used to fit the model + + Examples + -------- + + >>> import numpy as np + >>> import cupy as cp + >>> + >>> from cuvs.cluster.kmeans import fit_batched, KMeansParams + >>> + >>> n_samples = 10_000_000 + >>> n_features = 128 + >>> n_clusters = 1000 + >>> + >>> # Data on host (numpy array) + >>> X = np.random.random((n_samples, n_features)).astype(np.float32) + >>> + >>> params = KMeansParams(n_clusters=n_clusters, max_iter=20) + >>> centroids, inertia, n_iter = fit_batched(params, X, batch_size=1_000_000) + """ + # Ensure X is a numpy array (host memory) + if not isinstance(X, np.ndarray): + X = np.asarray(X) + + if not X.flags['C_CONTIGUOUS']: + X = np.ascontiguousarray(X) + + _check_input_array(wrap_array(X), [np.dtype('float32'), np.dtype('float64'), + np.dtype('uint8'), np.dtype('int8'), + np.dtype('float16')]) + + cdef int64_t n_samples = X.shape[0] + cdef int64_t n_features = X.shape[1] + + # Create DLPack tensor for host data + cdef cydlpack.DLManagedTensor* x_dlpack = cydlpack.dlpack_c(wrap_array(X)) + cdef cydlpack.DLManagedTensor* sample_weight_dlpack = NULL + + cdef cuvsResources_t res = resources.get_c_obj() + + cdef double inertia = 0 + cdef int64_t n_iter = 0 + cdef int64_t c_batch_size = batch_size + + if centroids is None: + # For integer/half types, centroids are always float32 + # (centroids are averages, can't be represented as integers) + if X.dtype in (np.uint8, np.int8, np.float16): + centroids_dtype = np.float32 + else: + centroids_dtype = X.dtype + centroids = device_ndarray.empty((params.n_clusters, n_features), + dtype=centroids_dtype) + + centroids_ai = wrap_array(centroids) + cdef cydlpack.DLManagedTensor* centroids_dlpack = \ + cydlpack.dlpack_c(centroids_ai) + + if sample_weights is not None: + if not isinstance(sample_weights, np.ndarray): + sample_weights = np.asarray(sample_weights) + if not sample_weights.flags['C_CONTIGUOUS']: + sample_weights = np.ascontiguousarray(sample_weights) + sample_weight_dlpack = cydlpack.dlpack_c(wrap_array(sample_weights)) + + with cuda_interruptible(): + check_cuvs(cuvsKMeansFitBatched( + res, + params.params, + x_dlpack, + c_batch_size, + sample_weight_dlpack, + centroids_dlpack, + &inertia, + &n_iter)) + + return FitOutput(centroids, inertia, n_iter) From 4b65df5fc3a059d5f8573bedaa614a5b75f73e9a Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 13 Feb 2026 14:26:53 -0800 Subject: [PATCH 10/81] namespace and init fixes --- cpp/include/cuvs/cluster/kmeans.hpp | 71 +++++++++++++++---- cpp/src/cluster/detail/kmeans_batched.cuh | 22 +++--- .../kmeans_fit_batched_int8_uint8_half.cu | 6 +- cpp/src/cluster/kmeans_fit_double.cu | 4 +- 4 files changed, 73 insertions(+), 30 deletions(-) diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 50b25dbea2..5da41c7750 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -50,16 +50,20 @@ struct params : base_params { /** * Centroid update mode determines when centroids are updated during training. + * This is primarily used with fit_batched() for out-of-core / host data processing. */ enum CentroidUpdateMode { /** - * Standard k-means (Lloyd's algorithm): accumulate assignments over the + * Standard k-means (Lloyd's algorithm): accumulate partial sums over the * entire dataset, then update centroids once per iteration. */ FullBatch, /** - * Mini-batch k-means: update centroids after each randomly sampled batch. + * Mini-batch k-means: update centroids incrementally after each randomly + * sampled batch using an online learning rule. Converges faster but may + * produce slightly different results each run. Recommended for very large + * datasets where multiple full passes are too expensive. */ MiniBatch }; @@ -123,9 +127,12 @@ struct params : base_params { int batch_centroids = 0; /** - * Centroid update mode: - * - FullBatch: Standard Lloyd's algorithm, update centroids after full dataset pass - * - MiniBatch: Mini-batch k-means, update centroids after each batch + * Centroid update mode for fit_batched(): + * - FullBatch (default): Standard Lloyd's algorithm. Accumulate partial sums + * across all batches, update centroids once per iteration. Deterministic and + * mathematically equivalent to standard k-means. + * - MiniBatch: Online mini-batch k-means. Update centroids incrementally after + * each randomly sampled batch. Faster convergence but non-deterministic. */ CentroidUpdateMode update_mode = FullBatch; @@ -171,9 +178,19 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * @brief Find clusters with k-means algorithm using batched processing. * * This version supports out-of-core computation where the dataset resides - * on the host. Data is processed in batches, with partial sums accumulated - * across batches and centroids finalized at the end of each iteration. - * This is mathematically equivalent to standard kmeans. + * on the host. Data is processed in batches, streaming from host to device. + * + * Two centroid update modes are supported (controlled by params.update_mode): + * + * - **FullBatch** (default): Standard Lloyd's algorithm. Partial sums are + * accumulated across all batches, and centroids are updated once at the + * end of each iteration. This is mathematically equivalent to standard + * k-means and produces deterministic results. + * + * - **MiniBatch**: Mini-batch k-means. Centroids are updated incrementally + * after each randomly sampled batch using an online learning rule. This + * converges faster but may produce slightly different results each run. + * Useful for very large datasets where full passes are expensive. * * @code{.cpp} * #include @@ -182,6 +199,8 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * ... * raft::resources handle; * cuvs::cluster::kmeans::params params; + * params.n_clusters = 100; + * // params.update_mode = kmeans::params::MiniBatch; // for mini-batch mode * int n_features = 15; * float inertia; * int n_iter; @@ -204,7 +223,8 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * @endcode * * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. + * @param[in] params Parameters for KMeans model. Use params.update_mode + * to select FullBatch or MiniBatch mode. * @param[in] X Training instances on HOST memory. The data must * be in row-major format. * [dim = n_samples x n_features] @@ -233,8 +253,11 @@ void fit_batched(raft::resources const& handle, /** * @brief Find clusters with k-means algorithm using batched processing. * + * Supports both FullBatch (Lloyd's algorithm) and MiniBatch modes via + * params.update_mode. See the int-indexed overload for detailed documentation. + * * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. + * @param[in] params Parameters for KMeans model (including update_mode). * @param[in] X Training instances on HOST memory. * [dim = n_samples x n_features] * @param[in] batch_size Number of samples to process per batch. @@ -256,8 +279,11 @@ void fit_batched(raft::resources const& handle, /** * @brief Find clusters with k-means algorithm using batched processing. * + * Supports both FullBatch (Lloyd's algorithm) and MiniBatch modes via + * params.update_mode. See the float int-indexed overload for detailed documentation. + * * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. + * @param[in] params Parameters for KMeans model (including update_mode). * @param[in] X Training instances on HOST memory. * [dim = n_samples x n_features] * @param[in] batch_size Number of samples to process per batch. @@ -279,8 +305,11 @@ void fit_batched(raft::resources const& handle, /** * @brief Find clusters with k-means algorithm using batched processing. * + * Supports both FullBatch (Lloyd's algorithm) and MiniBatch modes via + * params.update_mode. See the float int-indexed overload for detailed documentation. + * * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. + * @param[in] params Parameters for KMeans model (including update_mode). * @param[in] X Training instances on HOST memory. * [dim = n_samples x n_features] * @param[in] batch_size Number of samples to process per batch. @@ -302,8 +331,12 @@ void fit_batched(raft::resources const& handle, /** * @brief Find clusters with k-means algorithm using batched processing for uint8 data. * + * Supports both FullBatch (Lloyd's algorithm) and MiniBatch modes via + * params.update_mode. Input data is uint8 but centroids are float (since + * centroids are averages). Conversion happens on GPU using mapping operators. + * * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. + * @param[in] params Parameters for KMeans model (including update_mode). * @param[in] X Training instances on HOST memory (uint8). * [dim = n_samples x n_features] * @param[in] batch_size Number of samples to process per batch. @@ -325,8 +358,12 @@ void fit_batched(raft::resources const& handle, /** * @brief Find clusters with k-means algorithm using batched processing for int8 data. * + * Supports both FullBatch (Lloyd's algorithm) and MiniBatch modes via + * params.update_mode. Input data is int8 but centroids are float (since + * centroids are averages). Conversion happens on GPU using mapping operators. + * * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. + * @param[in] params Parameters for KMeans model (including update_mode). * @param[in] X Training instances on HOST memory (int8). * [dim = n_samples x n_features] * @param[in] batch_size Number of samples to process per batch. @@ -348,8 +385,12 @@ void fit_batched(raft::resources const& handle, /** * @brief Find clusters with k-means algorithm using batched processing for half data. * + * Supports both FullBatch (Lloyd's algorithm) and MiniBatch modes via + * params.update_mode. Input data is half but centroids are float (for + * numerical stability). Conversion happens on GPU using mapping operators. + * * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. + * @param[in] params Parameters for KMeans model (including update_mode). * @param[in] X Training instances on HOST memory (half). * [dim = n_samples x n_features] * @param[in] batch_size Number of samples to process per batch. diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index a9fea3882a..3287676ec3 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -45,7 +45,7 @@ #include #include -namespace cuvs::cluster::kmeans::batched::detail { +namespace cuvs::cluster::kmeans::detail { /** * @brief Sample data from host to device for initialization, with optional type conversion @@ -127,15 +127,17 @@ void init_centroids_from_host_sample(raft::resources const& handle, auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); prepare_init_sample(handle, X, init_sample.view(), params.rng_state.seed, mapping_op); - // Run k-means++ on the sample if (params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { - cuvs::cluster::kmeans::detail::kmeansPlusPlus( - handle, - params, - raft::make_device_matrix_view( - init_sample.data_handle(), init_sample_size, n_features), - centroids, - workspace); + auto init_sample_view = raft::make_device_matrix_view( + init_sample.data_handle(), init_sample_size, n_features); + + if (params.oversampling_factor == 0) { + cuvs::cluster::kmeans::detail::kmeansPlusPlus( + handle, params, init_sample_view, centroids, workspace); + } else { + cuvs::cluster::kmeans::detail::initScalableKMeansPlusPlus( + handle, params, init_sample_view, centroids, workspace); + } } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { // Just use the first n_clusters samples raft::copy(centroids.data_handle(), init_sample.data_handle(), n_clusters * n_features, stream); @@ -700,4 +702,4 @@ void fit(raft::resources const& handle, RAFT_LOG_DEBUG("KMeans batched: Completed with inertia=%f", static_cast(inertia[0])); } -} // namespace cuvs::cluster::kmeans::batched::detail +} // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/kmeans_fit_batched_int8_uint8_half.cu b/cpp/src/cluster/kmeans_fit_batched_int8_uint8_half.cu index f4f5d90b68..10d7289132 100644 --- a/cpp/src/cluster/kmeans_fit_batched_int8_uint8_half.cu +++ b/cpp/src/cluster/kmeans_fit_batched_int8_uint8_half.cu @@ -24,7 +24,7 @@ void fit_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::batched::detail::fit( + cuvs::cluster::kmeans::detail::fit( handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, mapping{}); } @@ -37,7 +37,7 @@ void fit_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::batched::detail::fit( + cuvs::cluster::kmeans::detail::fit( handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, mapping{}); } @@ -50,7 +50,7 @@ void fit_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::batched::detail::fit( + cuvs::cluster::kmeans::detail::fit( handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, mapping{}); } diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/cluster/kmeans_fit_double.cu index fc0c9fd335..4986db5146 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/cluster/kmeans_fit_double.cu @@ -49,7 +49,7 @@ void fit_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::batched::detail::fit( + cuvs::cluster::kmeans::detail::fit( handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, raft::identity_op{}); } @@ -62,7 +62,7 @@ void fit_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::batched::detail::fit( + cuvs::cluster::kmeans::detail::fit( handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, raft::identity_op{}); } From 5eb2be57ff69536cf13a57d510d3283c1e70f193 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 13 Feb 2026 14:30:46 -0800 Subject: [PATCH 11/81] fix docs in main header --- cpp/include/cuvs/cluster/kmeans.hpp | 24 +----------------------- 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 5da41c7750..8bd34914bf 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -62,8 +62,7 @@ struct params : base_params { /** * Mini-batch k-means: update centroids incrementally after each randomly * sampled batch using an online learning rule. Converges faster but may - * produce slightly different results each run. Recommended for very large - * datasets where multiple full passes are too expensive. + * produce slightly different results each run. */ MiniBatch }; @@ -253,9 +252,6 @@ void fit_batched(raft::resources const& handle, /** * @brief Find clusters with k-means algorithm using batched processing. * - * Supports both FullBatch (Lloyd's algorithm) and MiniBatch modes via - * params.update_mode. See the int-indexed overload for detailed documentation. - * * @param[in] handle The raft handle. * @param[in] params Parameters for KMeans model (including update_mode). * @param[in] X Training instances on HOST memory. @@ -279,9 +275,6 @@ void fit_batched(raft::resources const& handle, /** * @brief Find clusters with k-means algorithm using batched processing. * - * Supports both FullBatch (Lloyd's algorithm) and MiniBatch modes via - * params.update_mode. See the float int-indexed overload for detailed documentation. - * * @param[in] handle The raft handle. * @param[in] params Parameters for KMeans model (including update_mode). * @param[in] X Training instances on HOST memory. @@ -305,9 +298,6 @@ void fit_batched(raft::resources const& handle, /** * @brief Find clusters with k-means algorithm using batched processing. * - * Supports both FullBatch (Lloyd's algorithm) and MiniBatch modes via - * params.update_mode. See the float int-indexed overload for detailed documentation. - * * @param[in] handle The raft handle. * @param[in] params Parameters for KMeans model (including update_mode). * @param[in] X Training instances on HOST memory. @@ -331,10 +321,6 @@ void fit_batched(raft::resources const& handle, /** * @brief Find clusters with k-means algorithm using batched processing for uint8 data. * - * Supports both FullBatch (Lloyd's algorithm) and MiniBatch modes via - * params.update_mode. Input data is uint8 but centroids are float (since - * centroids are averages). Conversion happens on GPU using mapping operators. - * * @param[in] handle The raft handle. * @param[in] params Parameters for KMeans model (including update_mode). * @param[in] X Training instances on HOST memory (uint8). @@ -358,10 +344,6 @@ void fit_batched(raft::resources const& handle, /** * @brief Find clusters with k-means algorithm using batched processing for int8 data. * - * Supports both FullBatch (Lloyd's algorithm) and MiniBatch modes via - * params.update_mode. Input data is int8 but centroids are float (since - * centroids are averages). Conversion happens on GPU using mapping operators. - * * @param[in] handle The raft handle. * @param[in] params Parameters for KMeans model (including update_mode). * @param[in] X Training instances on HOST memory (int8). @@ -385,10 +367,6 @@ void fit_batched(raft::resources const& handle, /** * @brief Find clusters with k-means algorithm using batched processing for half data. * - * Supports both FullBatch (Lloyd's algorithm) and MiniBatch modes via - * params.update_mode. Input data is half but centroids are float (for - * numerical stability). Conversion happens on GPU using mapping operators. - * * @param[in] handle The raft handle. * @param[in] params Parameters for KMeans model (including update_mode). * @param[in] X Training instances on HOST memory (half). From c23985ad6378dbf1ffc5d23bd6630d4fafce94df Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Sun, 15 Feb 2026 00:11:22 -0800 Subject: [PATCH 12/81] several fixes --- c/include/cuvs/cluster/kmeans.h | 7 + c/src/cluster/kmeans.cpp | 211 ++++++- cpp/include/cuvs/cluster/kmeans.hpp | 188 ++++++ cpp/src/cluster/detail/kmeans_batched.cuh | 541 ++++++++++++------ .../kmeans_fit_batched_int8_uint8_half.cu | 134 +++++ cpp/src/cluster/kmeans_fit_double.cu | 45 ++ cpp/src/cluster/kmeans_fit_float.cu | 49 +- python/cuvs/cuvs/tests/test_kmeans.py | 2 +- .../cuvs_bench/get_dataset/__main__.py | 295 +++++++++- 9 files changed, 1273 insertions(+), 199 deletions(-) diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index 99e7bfb5d3..5b157abd6d 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -113,8 +113,15 @@ struct cuvsKMeansParams { */ cuvsKMeansCentroidUpdateMode update_mode; + /** Check inertia during iterations for early convergence. */ bool inertia_check; + /** + * Compute final inertia after fit_batched completes (requires extra data pass). + * Only used by fit_batched; regular fit always computes final inertia. + */ + bool final_inertia_check; + /** * Whether to use hierarchical (balanced) kmeans or not */ diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index 698f387616..52b098ee2e 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -1,11 +1,13 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #include #include +#include + #include #include #include @@ -17,16 +19,20 @@ namespace { cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params) { - auto kmeans_params = cuvs::cluster::kmeans::params(); - kmeans_params.metric = static_cast(params.metric); - kmeans_params.init = static_cast(params.init); - kmeans_params.n_clusters = params.n_clusters; - kmeans_params.max_iter = params.max_iter; - kmeans_params.tol = params.tol; + auto kmeans_params = cuvs::cluster::kmeans::params(); + kmeans_params.metric = static_cast(params.metric); + kmeans_params.init = static_cast(params.init); + kmeans_params.n_clusters = params.n_clusters; + kmeans_params.max_iter = params.max_iter; + kmeans_params.tol = params.tol; + kmeans_params.n_init = params.n_init; kmeans_params.oversampling_factor = params.oversampling_factor; kmeans_params.batch_samples = params.batch_samples; kmeans_params.batch_centroids = params.batch_centroids; kmeans_params.inertia_check = params.inertia_check; + kmeans_params.final_inertia_check = params.final_inertia_check; + kmeans_params.update_mode = + static_cast(params.update_mode); return kmeans_params; } @@ -177,6 +183,139 @@ void _cluster_cost(cuvsResources_t res, *cost = cost_temp; } + +template +void _fit_batched(cuvsResources_t res, + const cuvsKMeansParams& params, + DLManagedTensor* X_tensor, + IdxT batch_size, + DLManagedTensor* sample_weight_tensor, + DLManagedTensor* centroids_tensor, + double* inertia, + IdxT* n_iter) +{ + auto X = X_tensor->dl_tensor; + auto centroids = centroids_tensor->dl_tensor; + auto res_ptr = reinterpret_cast(res); + auto n_samples = static_cast(X.shape[0]); + auto n_features = static_cast(X.shape[1]); + + // X must be on host (CPU) memory + if (X.device.device_type != kDLCPU) { + RAFT_FAIL("X dataset must be on host (CPU) memory for fit_batched"); + } + + // centroids must be on device memory + if (!cuvs::core::is_dlpack_device_compatible(centroids)) { + RAFT_FAIL("centroids must be on device memory"); + } + + using device_matrix_type = raft::device_matrix_view; + + // Create host matrix view from X + auto X_view = raft::make_host_matrix_view( + reinterpret_cast(X.data), n_samples, n_features); + + // Create device matrix view for centroids + auto centroids_view = cuvs::core::from_dlpack(centroids_tensor); + + // Handle optional sample weights + std::optional> sample_weight; + if (sample_weight_tensor != NULL) { + auto sw = sample_weight_tensor->dl_tensor; + if (sw.device.device_type != kDLCPU) { + RAFT_FAIL("sample_weight must be on host (CPU) memory for fit_batched"); + } + sample_weight = raft::make_host_vector_view( + reinterpret_cast(sw.data), n_samples); + } + + T inertia_temp; + IdxT n_iter_temp; + + auto kmeans_params = convert_params(params); + cuvs::cluster::kmeans::fit_batched(*res_ptr, + kmeans_params, + X_view, + batch_size, + sample_weight, + centroids_view, + raft::make_host_scalar_view(&inertia_temp), + raft::make_host_scalar_view(&n_iter_temp)); + + *inertia = inertia_temp; + *n_iter = n_iter_temp; +} + +// Specialized version for integer/half types where X is InputT but centroids are float +template +void _fit_batched_mixed(cuvsResources_t res, + const cuvsKMeansParams& params, + DLManagedTensor* X_tensor, + IdxT batch_size, + DLManagedTensor* sample_weight_tensor, + DLManagedTensor* centroids_tensor, + double* inertia, + IdxT* n_iter) +{ + auto X = X_tensor->dl_tensor; + auto centroids = centroids_tensor->dl_tensor; + auto res_ptr = reinterpret_cast(res); + auto n_samples = static_cast(X.shape[0]); + auto n_features = static_cast(X.shape[1]); + + // X must be on host (CPU) memory + if (X.device.device_type != kDLCPU) { + RAFT_FAIL("X dataset must be on host (CPU) memory for fit_batched"); + } + + // centroids must be on device memory and float type + if (!cuvs::core::is_dlpack_device_compatible(centroids)) { + RAFT_FAIL("centroids must be on device memory"); + } + if (centroids.dtype.code != kDLFloat || centroids.dtype.bits != 32) { + RAFT_FAIL("centroids must be float32 for integer/half input types"); + } + + using device_matrix_type = raft::device_matrix_view; + + // Create host matrix view from X (InputT) + auto X_view = raft::make_host_matrix_view( + reinterpret_cast(X.data), n_samples, n_features); + + // Create device matrix view for centroids (float) + auto centroids_view = cuvs::core::from_dlpack(centroids_tensor); + + // Handle optional sample weights (float) + std::optional> sample_weight; + if (sample_weight_tensor != NULL) { + auto sw = sample_weight_tensor->dl_tensor; + if (sw.device.device_type != kDLCPU) { + RAFT_FAIL("sample_weight must be on host (CPU) memory for fit_batched"); + } + if (sw.dtype.code != kDLFloat || sw.dtype.bits != 32) { + RAFT_FAIL("sample_weight must be float32 for integer/half input types"); + } + sample_weight = raft::make_host_vector_view( + reinterpret_cast(sw.data), n_samples); + } + + float inertia_temp; + IdxT n_iter_temp; + + auto kmeans_params = convert_params(params); + cuvs::cluster::kmeans::fit_batched(*res_ptr, + kmeans_params, + X_view, + batch_size, + sample_weight, + centroids_view, + raft::make_host_scalar_view(&inertia_temp), + raft::make_host_scalar_view(&n_iter_temp)); + + *inertia = inertia_temp; + *n_iter = n_iter_temp; +} } // namespace extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params) @@ -184,17 +323,21 @@ extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params) return cuvs::core::translate_exceptions([=] { cuvs::cluster::kmeans::params cpp_params; cuvs::cluster::kmeans::balanced_params cpp_balanced_params; - *params = - new cuvsKMeansParams{.metric = static_cast(cpp_params.metric), - .n_clusters = cpp_params.n_clusters, - .init = static_cast(cpp_params.init), - .max_iter = cpp_params.max_iter, - .tol = cpp_params.tol, - .oversampling_factor = cpp_params.oversampling_factor, - .batch_samples = cpp_params.batch_samples, - .inertia_check = cpp_params.inertia_check, - .hierarchical = false, - .hierarchical_n_iters = static_cast(cpp_balanced_params.n_iters)}; + *params = new cuvsKMeansParams{ + .metric = static_cast(cpp_params.metric), + .n_clusters = cpp_params.n_clusters, + .init = static_cast(cpp_params.init), + .max_iter = cpp_params.max_iter, + .tol = cpp_params.tol, + .n_init = cpp_params.n_init, + .oversampling_factor = cpp_params.oversampling_factor, + .batch_samples = cpp_params.batch_samples, + .batch_centroids = cpp_params.batch_centroids, + .update_mode = static_cast(cpp_params.update_mode), + .inertia_check = cpp_params.inertia_check, + .final_inertia_check = cpp_params.final_inertia_check, + .hierarchical = false, + .hierarchical_n_iters = static_cast(cpp_balanced_params.n_iters)}; }); } @@ -267,3 +410,35 @@ extern "C" cuvsError_t cuvsKMeansClusterCost(cuvsResources_t res, } }); } + +extern "C" cuvsError_t cuvsKMeansFitBatched(cuvsResources_t res, + cuvsKMeansParams_t params, + DLManagedTensor* X, + int64_t batch_size, + DLManagedTensor* sample_weight, + DLManagedTensor* centroids, + double* inertia, + int64_t* n_iter) +{ + return cuvs::core::translate_exceptions([=] { + auto dataset = X->dl_tensor; + if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { + _fit_batched(res, *params, X, batch_size, sample_weight, centroids, inertia, n_iter); + } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { + _fit_batched(res, *params, X, batch_size, sample_weight, centroids, inertia, n_iter); + } else if (dataset.dtype.code == kDLUInt && dataset.dtype.bits == 8) { + // uint8 input -> float centroids + _fit_batched_mixed(res, *params, X, batch_size, sample_weight, centroids, inertia, n_iter); + } else if (dataset.dtype.code == kDLInt && dataset.dtype.bits == 8) { + // int8 input -> float centroids + _fit_batched_mixed(res, *params, X, batch_size, sample_weight, centroids, inertia, n_iter); + } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 16) { + // half input -> float centroids + _fit_batched_mixed(res, *params, X, batch_size, sample_weight, centroids, inertia, n_iter); + } else { + RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", + dataset.dtype.code, + dataset.dtype.bits); + } + }); +} diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 8bd34914bf..b8be9d1f0a 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -135,7 +135,19 @@ struct params : base_params { */ CentroidUpdateMode update_mode = FullBatch; + /** + * If true, check inertia during iterations for early convergence (used by both fit and + * fit_batched). + */ bool inertia_check = false; + + /** + * If true, compute the final inertia after fit_batched completes. This requires an additional + * full pass over all the host data, which can be expensive for large datasets. + * Only used by fit_batched(); regular fit() always computes final inertia. + * Default: false (skip final inertia computation for performance). + */ + bool final_inertia_check = false; }; /** @@ -387,6 +399,182 @@ void fit_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter); +/** + * @defgroup predict_batched Batched K-Means Predict + * @{ + */ + +/** + * @brief Predict cluster labels for host data using batched processing. + * + * Streams data from host to GPU in batches, assigns each sample to its nearest + * centroid, and writes labels back to host memory. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Input samples on HOST memory. [dim = n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch. + * @param[in] sample_weight Optional weights for each observation (on host). + * @param[in] centroids Cluster centers on device. [dim = n_clusters x n_features] + * @param[out] labels Predicted cluster labels on HOST memory. [dim = n_samples] + * @param[in] normalize_weight Whether to normalize sample weights. + * @param[out] inertia Sum of squared distances to nearest centroid (only if + * params.final_inertia_check is true, otherwise 0). + */ +void predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia); + +/** + * @brief Predict cluster labels for host data using batched processing (double). + */ +void predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia); + +/** + * @brief Predict cluster labels for host uint8 data using batched processing. + * + * Input data is uint8 on host, centroids are float on device. + */ +void predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia); + +/** + * @brief Predict cluster labels for host int8 data using batched processing. + * + * Input data is int8 on host, centroids are float on device. + */ +void predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia); + +/** + * @brief Predict cluster labels for host half data using batched processing. + * + * Input data is half on host, centroids are float on device. + */ +void predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia); + +/** + * @} + */ + +/** + * @defgroup fit_predict_batched Batched K-Means Fit + Predict + * @{ + */ + +/** + * @brief Fit k-means and predict cluster labels using batched processing. + * + * Combines fit_batched and predict_batched into a single call. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances on HOST memory. [dim = n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch. + * @param[in] sample_weight Optional weights for each observation (on host). + * @param[inout] centroids Cluster centers on device. [dim = n_clusters x n_features] + * @param[out] labels Predicted cluster labels on HOST memory. [dim = n_samples] + * @param[out] inertia Sum of squared distances to nearest centroid. + * @param[out] n_iter Number of iterations run. + */ +void fit_predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Fit k-means and predict cluster labels using batched processing (double). + */ +void fit_predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Fit k-means and predict cluster labels for uint8 data using batched processing. + */ +void fit_predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Fit k-means and predict cluster labels for int8 data using batched processing. + */ +void fit_predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Fit k-means and predict cluster labels for half data using batched processing. + */ +void fit_predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + /** * @} */ diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 3287676ec3..834c2a93ac 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -167,21 +167,17 @@ void accumulate_batch_centroids( auto n_features = batch_data.extent(1); auto n_clusters = centroid_sums.extent(0); - // Get workspace from handle auto workspace = rmm::device_uvector( batch_data.extent(0), stream, raft::resource::get_workspace_resource(handle)); - // Temporary buffers for this batch's partial results auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); auto batch_counts = raft::make_device_vector(handle, n_clusters); - // Extract cluster labels from KeyValuePair cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; thrust::transform_iterator, const raft::KeyValuePair*> labels_itr(minClusterAndDistance.data_handle(), conversion_op); - // Compute weighted sums and counts per cluster for this batch cuvs::cluster::kmeans::detail::compute_centroid_adjustments(handle, batch_data, sample_weights, @@ -226,17 +222,8 @@ void minibatch_update_centroids(raft::resources const& handle, // Compute batch means: batch_mean = batch_sums / batch_counts auto batch_means = raft::make_device_matrix(handle, n_clusters, n_features); - raft::copy(batch_means.data_handle(), - batch_sums.data_handle(), - batch_sums.size(), - raft::resource::get_cuda_stream(handle)); - raft::linalg::matrix_vector_op( - handle, - raft::make_const_mdspan(batch_means.view()), - batch_counts, - batch_means.view(), - raft::div_checkzero_op{}); + handle, batch_sums, batch_counts, batch_means.view(), raft::div_checkzero_op{}); // Step 1: Update total_counts = total_counts + batch_counts raft::linalg::add(handle, raft::make_const_mdspan(total_counts), batch_counts, total_counts); @@ -400,8 +387,7 @@ void fit(raft::resources const& handle, auto total_counts = raft::make_device_vector(handle, n_clusters); // Host buffer for batch data - std::vector host_batch_buffer(batch_size * n_features); - std::vector host_weight_buffer(batch_size); + auto host_batch_buffer = raft::make_host_matrix(batch_size, n_features); // Cluster cost for convergence check rmm::device_scalar clusterCostD(stream); @@ -413,104 +399,77 @@ void fit(raft::resources const& handle, RAFT_LOG_DEBUG("KMeans batched: update_mode=%s", use_minibatch ? "MiniBatch" : "FullBatch"); - // For mini-batch mode with random sampling, create index shuffle - std::vector sample_indices(n_samples); - std::iota(sample_indices.begin(), sample_indices.end(), 0); + // Random number generator for mini-batch sampling std::mt19937 rng(params.rng_state.seed); + // Buffer for sampled indices (only used in mini-batch mode) + std::vector batch_indices(batch_size); + + // For mini-batch: set up sampling distribution (weighted if weights provided) + std::uniform_int_distribution uniform_dist(0, n_samples - 1); + std::discrete_distribution weighted_dist; + bool use_weighted_sampling = false; + if (use_minibatch && sample_weight) { + std::vector weights(sample_weight->data_handle(), + sample_weight->data_handle() + n_samples); + weighted_dist = std::discrete_distribution(weights.begin(), weights.end()); + use_weighted_sampling = true; + } + + // Initialize total_counts once for mini-batch mode (cumulative across all iterations) + if (use_minibatch) { raft::matrix::fill(handle, total_counts.view(), MathT{0}); } + // Main iteration loop + // For mini-batch: 1 iteration = 1 batch (process max_iter batches total) + // For full-batch: 1 iteration = 1 full pass over data for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { RAFT_LOG_DEBUG("KMeans batched: Iteration %d", n_iter[0]); - // For full-batch mode: zero accumulators at start of each iteration - // For mini-batch mode: zero total_counts at start of each iteration - if (!use_minibatch) { - raft::matrix::fill(handle, centroid_sums.view(), MathT{0}); - raft::matrix::fill(handle, cluster_counts.view(), MathT{0}); - } else { - // Mini-batch mode: zero total counts for learning rate calculation - raft::matrix::fill(handle, total_counts.view(), MathT{0}); - // Shuffle sample indices for random batch selection - std::shuffle(sample_indices.begin(), sample_indices.end(), rng); - } - // Save old centroids for convergence check raft::copy(new_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); MathT total_cost = 0; - // Process all data in batches - for (IdxT batch_idx = 0; batch_idx < n_samples; batch_idx += batch_size) { - IdxT current_batch_size = std::min(batch_size, n_samples - batch_idx); + if (use_minibatch) { + // ========== MINI-BATCH MODE: Process ONE batch per iteration ========== + IdxT current_batch_size = batch_size; - // Copy batch data from host to device - if (use_minibatch) { - // Mini-batch: use shuffled indices for random sampling -#pragma omp parallel for - for (IdxT i = 0; i < current_batch_size; ++i) { - IdxT sample_idx = sample_indices[batch_idx + i]; - std::memcpy(host_batch_buffer.data() + i * n_features, - X.data_handle() + sample_idx * n_features, - n_features * sizeof(T)); - } + // Sample indices with replacement (weighted if weights provided) + for (IdxT i = 0; i < current_batch_size; ++i) { + batch_indices[i] = use_weighted_sampling ? weighted_dist(rng) : uniform_dist(rng); + } - if constexpr (std::is_same_v) { - raft::copy(batch_data.data_handle(), - host_batch_buffer.data(), - current_batch_size * n_features, - stream); - } else { - raft::copy(batch_data_raw.data(), - host_batch_buffer.data(), - current_batch_size * n_features, - stream); - raft::linalg::unaryOp(batch_data.data_handle(), - batch_data_raw.data(), - current_batch_size * n_features, - mapping_op, - stream); - } - } else { - // Full-batch: sequential access - if constexpr (std::is_same_v) { - raft::copy(batch_data.data_handle(), - X.data_handle() + batch_idx * n_features, - current_batch_size * n_features, - stream); - } else { - raft::copy(batch_data_raw.data(), - X.data_handle() + batch_idx * n_features, - current_batch_size * n_features, - stream); - raft::linalg::unaryOp(batch_data.data_handle(), - batch_data_raw.data(), - current_batch_size * n_features, - mapping_op, - stream); - } +#pragma omp parallel for + for (IdxT i = 0; i < current_batch_size; ++i) { + IdxT sample_idx = batch_indices[i]; + std::memcpy(host_batch_buffer.data_handle() + i * n_features, + X.data_handle() + sample_idx * n_features, + n_features * sizeof(T)); } - // Copy or set weights for this batch - if (sample_weight) { - if (use_minibatch) { - for (IdxT i = 0; i < current_batch_size; ++i) { - host_weight_buffer[i] = sample_weight->data_handle()[sample_indices[batch_idx + i]]; - } - raft::copy( - batch_weights.data_handle(), host_weight_buffer.data(), current_batch_size, stream); - } else { - raft::copy(batch_weights.data_handle(), - sample_weight->data_handle() + batch_idx, - current_batch_size, - stream); - } + if constexpr (std::is_same_v) { + raft::copy(batch_data.data_handle(), + host_batch_buffer.data_handle(), + current_batch_size * n_features, + stream); } else { - auto batch_weights_fill_view = raft::make_device_vector_view( - batch_weights.data_handle(), current_batch_size); - raft::matrix::fill(handle, batch_weights_fill_view, MathT{1}); + raft::copy(batch_data_raw.data(), + host_batch_buffer.data_handle(), + current_batch_size * n_features, + stream); + raft::linalg::unaryOp(batch_data.data_handle(), + batch_data_raw.data(), + current_batch_size * n_features, + mapping_op, + stream); } - // Create views for current batch size + // Mini-batch uses uniform weights (sampling already accounts for importance) + auto batch_weights_fill_view = + raft::make_device_vector_view(batch_weights.data_handle(), current_batch_size); + raft::matrix::fill(handle, batch_weights_fill_view, MathT{1}); + + // Create views auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); auto batch_weights_view = raft::make_device_vector_view( @@ -519,7 +478,7 @@ void fit(raft::resources const& handle, raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); - // Compute L2 norms for batch if needed + // Compute L2 norms if needed if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded) { raft::linalg::rowNorm(L2NormBatch.data_handle(), @@ -529,7 +488,7 @@ void fit(raft::resources const& handle, stream); } - // Find nearest centroid for each sample in batch + // Find nearest centroid auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); auto L2NormBatch_const = raft::make_device_vector_view( @@ -547,17 +506,14 @@ void fit(raft::resources const& handle, params.batch_centroids, workspace); - // Accumulate partial sums for this batch + // Zero and accumulate for this batch + raft::matrix::fill(handle, centroid_sums.view(), MathT{0}); + raft::matrix::fill(handle, cluster_counts.view(), MathT{0}); + auto minClusterAndDistance_const = raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); - if (use_minibatch) { - // Mini-batch mode: zero batch accumulators before each batch - raft::matrix::fill(handle, centroid_sums.view(), MathT{0}); - raft::matrix::fill(handle, cluster_counts.view(), MathT{0}); - } - accumulate_batch_centroids(handle, batch_data_view, minClusterAndDistance_const, @@ -565,33 +521,116 @@ void fit(raft::resources const& handle, centroid_sums.view(), cluster_counts.view()); - if (use_minibatch) { - // Mini-batch mode: update centroids immediately after each batch - auto centroid_sums_const = raft::make_device_matrix_view( - centroid_sums.data_handle(), n_clusters, n_features); - auto cluster_counts_const = raft::make_device_vector_view( - cluster_counts.data_handle(), n_clusters); + // Update centroids with online learning + auto centroid_sums_const = raft::make_device_matrix_view( + centroid_sums.data_handle(), n_clusters, n_features); + auto cluster_counts_const = + raft::make_device_vector_view(cluster_counts.data_handle(), n_clusters); - minibatch_update_centroids( - handle, centroids, centroid_sums_const, cluster_counts_const, total_counts.view()); - } + minibatch_update_centroids( + handle, centroids, centroid_sums_const, cluster_counts_const, total_counts.view()); + + } else { + // ========== FULL-BATCH MODE: Process ALL data per iteration ========== + raft::matrix::fill(handle, centroid_sums.view(), MathT{0}); + raft::matrix::fill(handle, cluster_counts.view(), MathT{0}); + + for (IdxT batch_idx = 0; batch_idx < n_samples; batch_idx += batch_size) { + IdxT current_batch_size = std::min(batch_size, n_samples - batch_idx); + + // Sequential copy from host to device + if constexpr (std::is_same_v) { + raft::copy(batch_data.data_handle(), + X.data_handle() + batch_idx * n_features, + current_batch_size * n_features, + stream); + } else { + raft::copy(batch_data_raw.data(), + X.data_handle() + batch_idx * n_features, + current_batch_size * n_features, + stream); + raft::linalg::unaryOp(batch_data.data_handle(), + batch_data_raw.data(), + current_batch_size * n_features, + mapping_op, + stream); + } + + // Set weights + auto batch_weights_fill_view = raft::make_device_vector_view( + batch_weights.data_handle(), current_batch_size); + if (sample_weight) { + raft::copy(batch_weights.data_handle(), + sample_weight->data_handle() + batch_idx, + current_batch_size, + stream); + } else { + raft::matrix::fill(handle, batch_weights_fill_view, MathT{1}); + } + + // Create views + auto batch_data_view = raft::make_device_matrix_view( + batch_data.data_handle(), current_batch_size, n_features); + auto batch_weights_view = raft::make_device_vector_view( + batch_weights.data_handle(), current_batch_size); + auto minClusterAndDistance_view = + raft::make_device_vector_view, IdxT>( + minClusterAndDistance.data_handle(), current_batch_size); + + // Compute L2 norms if needed + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm(L2NormBatch.data_handle(), + batch_data.data_handle(), + n_features, + current_batch_size, + stream); + } + + // Find nearest centroid + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features); + auto L2NormBatch_const = raft::make_device_vector_view( + L2NormBatch.data_handle(), current_batch_size); - // Accumulate cluster cost if checking convergence - if (params.inertia_check) { - cuvs::cluster::kmeans::detail::computeClusterCost( + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( handle, + batch_data_view, + centroids_const, minClusterAndDistance_view, - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - MathT batch_cost = clusterCostD.value(stream); - total_cost += batch_cost; - } - } // end batch loop + L2NormBatch_const, + L2NormBuf_OR_DistBuf, + metric, + params.batch_samples, + params.batch_centroids, + workspace); + + // Accumulate partial sums + auto minClusterAndDistance_const = + raft::make_device_vector_view, IdxT>( + minClusterAndDistance.data_handle(), current_batch_size); + + accumulate_batch_centroids(handle, + batch_data_view, + minClusterAndDistance_const, + batch_weights_view, + centroid_sums.view(), + cluster_counts.view()); + + // Accumulate cost if checking convergence + if (params.inertia_check) { + cuvs::cluster::kmeans::detail::computeClusterCost( + handle, + minClusterAndDistance_view, + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + raft::value_op{}, + raft::add_op{}); + total_cost += clusterCostD.value(stream); + } + } // end batch loop - if (!use_minibatch) { - // Full-batch mode: finalize centroids after processing all batches + // Finalize centroids after processing all data auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); auto centroid_sums_const = raft::make_device_matrix_view( @@ -603,25 +642,23 @@ void fit(raft::resources const& handle, handle, centroid_sums_const, cluster_counts_const, centroids_const, centroids); } - // Compute squared norm of change in centroids (compare to saved old centroids) + // Compute squared norm of change in centroids auto sqrdNorm = raft::make_device_scalar(handle, MathT{0}); raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), centroids.size(), raft::sqdiff_op{}, stream, - new_centroids.data_handle(), // old centroids - centroids.data_handle()); // new centroids + new_centroids.data_handle(), + centroids.data_handle()); MathT sqrdNormError = 0; raft::copy(&sqrdNormError, sqrdNorm.data_handle(), 1, stream); // Check convergence bool done = false; - if (params.inertia_check) { - if (n_iter[0] > 1) { - MathT delta = total_cost / priorClusteringCost; - if (delta > 1 - params.tol) done = true; - } + if (!use_minibatch && params.inertia_check && n_iter[0] > 1) { + MathT delta = total_cost / priorClusteringCost; + if (delta > 1 - params.tol) done = true; priorClusteringCost = total_cost; } @@ -634,19 +671,135 @@ void fit(raft::resources const& handle, } } // end iteration loop - // Compute final inertia by processing all data once more + // Compute final inertia only if requested (requires another full pass over the data) + if (params.final_inertia_check) { + inertia[0] = 0; + for (IdxT offset = 0; offset < n_samples; offset += batch_size) { + IdxT current_batch_size = std::min(batch_size, n_samples - offset); + + if constexpr (std::is_same_v) { + raft::copy(batch_data.data_handle(), + X.data_handle() + offset * n_features, + current_batch_size * n_features, + stream); + } else { + raft::copy(batch_data_raw.data(), + X.data_handle() + offset * n_features, + current_batch_size * n_features, + stream); + raft::linalg::unaryOp(batch_data.data_handle(), + batch_data_raw.data(), + current_batch_size * n_features, + mapping_op, + stream); + } + + auto batch_data_view = raft::make_device_matrix_view( + batch_data.data_handle(), current_batch_size, n_features); + auto minClusterAndDistance_view = + raft::make_device_vector_view, IdxT>( + minClusterAndDistance.data_handle(), current_batch_size); + + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm(L2NormBatch.data_handle(), + batch_data.data_handle(), + n_features, + current_batch_size, + stream); + } + + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features); + auto L2NormBatch_const = raft::make_device_vector_view( + L2NormBatch.data_handle(), current_batch_size); + + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + batch_data_view, + centroids_const, + minClusterAndDistance_view, + L2NormBatch_const, + L2NormBuf_OR_DistBuf, + metric, + params.batch_samples, + params.batch_centroids, + workspace); + + cuvs::cluster::kmeans::detail::computeClusterCost( + handle, + minClusterAndDistance_view, + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + raft::value_op{}, + raft::add_op{}); + + inertia[0] += clusterCostD.value(stream); + } + RAFT_LOG_DEBUG("KMeans batched: Completed with inertia=%f", static_cast(inertia[0])); + } else { + inertia[0] = 0; + RAFT_LOG_DEBUG("KMeans batched: Completed (inertia computation skipped)"); + } +} + +/** + * @brief Predict cluster labels for host data using batched processing. + * + * @tparam T Input data type (float, double, uint8_t, int8_t, half) + * @tparam MathT Computation/centroid type (typically float) + * @tparam IdxT Index type (int, int64_t) + * @tparam MappingOpT Mapping operator (T -> MathT) + */ +template +void predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + IdxT batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia, + MappingOpT mapping_op) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + + RAFT_EXPECTS(batch_size > 0, "batch_size must be positive"); + RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); + RAFT_EXPECTS(centroids.extent(0) == static_cast(n_clusters), + "centroids.extent(0) must equal n_clusters"); + RAFT_EXPECTS(centroids.extent(1) == n_features, "centroids.extent(1) must equal n_features"); + RAFT_EXPECTS(labels.extent(0) == n_samples, "labels.extent(0) must equal n_samples"); + + // Allocate device buffers + rmm::device_uvector batch_data_raw(0, stream); + if constexpr (!std::is_same_v) { + batch_data_raw.resize(batch_size * n_features, stream); + } + + auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); + auto batch_weights = raft::make_device_vector(handle, batch_size); + auto batch_labels = raft::make_device_vector(handle, batch_size); + inertia[0] = 0; - for (IdxT offset = 0; offset < n_samples; offset += batch_size) { - IdxT current_batch_size = std::min(batch_size, n_samples - offset); + // Process all data in batches + for (IdxT batch_idx = 0; batch_idx < n_samples; batch_idx += batch_size) { + IdxT current_batch_size = std::min(batch_size, n_samples - batch_idx); + + // Copy batch data from host to device if constexpr (std::is_same_v) { raft::copy(batch_data.data_handle(), - X.data_handle() + offset * n_features, + X.data_handle() + batch_idx * n_features, current_batch_size * n_features, stream); } else { raft::copy(batch_data_raw.data(), - X.data_handle() + offset * n_features, + X.data_handle() + batch_idx * n_features, current_batch_size * n_features, stream); raft::linalg::unaryOp(batch_data.data_handle(), @@ -656,50 +809,86 @@ void fit(raft::resources const& handle, stream); } - auto batch_data_view = raft::make_device_matrix_view( - batch_data.data_handle(), current_batch_size, n_features); - auto minClusterAndDistance_view = - raft::make_device_vector_view, IdxT>( - minClusterAndDistance.data_handle(), current_batch_size); - - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormBatch.data_handle(), - batch_data.data_handle(), - n_features, - current_batch_size, - stream); + // Handle weights + std::optional> batch_weight_view; + if (sample_weight) { + raft::copy(batch_weights.data_handle(), + sample_weight->data_handle() + batch_idx, + current_batch_size, + stream); + batch_weight_view = raft::make_device_vector_view( + batch_weights.data_handle(), current_batch_size); } - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), n_clusters, n_features); - auto L2NormBatch_const = raft::make_device_vector_view( - L2NormBatch.data_handle(), current_batch_size); + // Create views for current batch + auto batch_data_view = raft::make_device_matrix_view( + batch_data.data_handle(), current_batch_size, n_features); + auto batch_labels_view = + raft::make_device_vector_view(batch_labels.data_handle(), current_batch_size); - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + // Call regular predict on this batch + MathT batch_inertia = 0; + cuvs::cluster::kmeans::detail::kmeans_predict( handle, + params, batch_data_view, - centroids_const, - minClusterAndDistance_view, - L2NormBatch_const, - L2NormBuf_OR_DistBuf, - metric, - params.batch_samples, - params.batch_centroids, - workspace); - - cuvs::cluster::kmeans::detail::computeClusterCost( - handle, - minClusterAndDistance_view, - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); + batch_weight_view, + centroids, + batch_labels_view, + normalize_weight, + raft::make_host_scalar_view(&batch_inertia)); + + // Copy labels back to host + raft::copy( + labels.data_handle() + batch_idx, batch_labels.data_handle(), current_batch_size, stream); - inertia[0] += clusterCostD.value(stream); + inertia[0] += batch_inertia; } - RAFT_LOG_DEBUG("KMeans batched: Completed with inertia=%f", static_cast(inertia[0])); + raft::resource::sync_stream(handle, stream); +} + +/** + * @brief Fit k-means and predict cluster labels using batched processing. + */ +template +void fit_predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + IdxT batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter, + MappingOpT mapping_op) +{ + // First fit the model + MathT fit_inertia = 0; + fit(handle, + params, + X, + batch_size, + sample_weight, + centroids, + raft::make_host_scalar_view(&fit_inertia), + n_iter, + mapping_op); + + // Then predict labels + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)); + + predict(handle, + params, + X, + batch_size, + sample_weight, + centroids_const, + labels, + false, // normalize_weight + inertia, + mapping_op); } } // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/kmeans_fit_batched_int8_uint8_half.cu b/cpp/src/cluster/kmeans_fit_batched_int8_uint8_half.cu index 10d7289132..6910eae722 100644 --- a/cpp/src/cluster/kmeans_fit_batched_int8_uint8_half.cu +++ b/cpp/src/cluster/kmeans_fit_batched_int8_uint8_half.cu @@ -54,4 +54,138 @@ void fit_batched(raft::resources const& handle, handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, mapping{}); } +// predict_batched implementations +void predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia) +{ + cuvs::cluster::kmeans::detail::predict(handle, + params, + X, + batch_size, + sample_weight, + centroids, + labels, + normalize_weight, + inertia, + mapping{}); +} + +void predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia) +{ + cuvs::cluster::kmeans::detail::predict(handle, + params, + X, + batch_size, + sample_weight, + centroids, + labels, + normalize_weight, + inertia, + mapping{}); +} + +void predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia) +{ + cuvs::cluster::kmeans::detail::predict(handle, + params, + X, + batch_size, + sample_weight, + centroids, + labels, + normalize_weight, + inertia, + mapping{}); +} + +// fit_predict_batched implementations +void fit_predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::detail::fit_predict(handle, + params, + X, + batch_size, + sample_weight, + centroids, + labels, + inertia, + n_iter, + mapping{}); +} + +void fit_predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::detail::fit_predict(handle, + params, + X, + batch_size, + sample_weight, + centroids, + labels, + inertia, + n_iter, + mapping{}); +} + +void fit_predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::detail::fit_predict(handle, + params, + X, + batch_size, + sample_weight, + centroids, + labels, + inertia, + n_iter, + mapping{}); +} + } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/cluster/kmeans_fit_double.cu index 4986db5146..6734d32dd2 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/cluster/kmeans_fit_double.cu @@ -89,4 +89,49 @@ void fit(raft::resources const& handle, cuvs::cluster::kmeans::fit( handle, params, X, sample_weight, centroids, inertia, n_iter); } + +void predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia) +{ + cuvs::cluster::kmeans::detail::predict(handle, + params, + X, + batch_size, + sample_weight, + centroids, + labels, + normalize_weight, + inertia, + raft::identity_op{}); +} + +void fit_predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::detail::fit_predict(handle, + params, + X, + batch_size, + sample_weight, + centroids, + labels, + inertia, + n_iter, + raft::identity_op{}); +} + } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/cluster/kmeans_fit_float.cu index 50ee522b4a..fc7aeccb0d 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/cluster/kmeans_fit_float.cu @@ -49,7 +49,7 @@ void fit_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::batched::detail::fit( + cuvs::cluster::kmeans::detail::fit( handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, raft::identity_op{}); } @@ -62,7 +62,7 @@ void fit_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::batched::detail::fit( + cuvs::cluster::kmeans::detail::fit( handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, raft::identity_op{}); } @@ -89,4 +89,49 @@ void fit(raft::resources const& handle, cuvs::cluster::kmeans::fit( handle, params, X, sample_weight, centroids, inertia, n_iter); } + +void predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia) +{ + cuvs::cluster::kmeans::detail::predict(handle, + params, + X, + batch_size, + sample_weight, + centroids, + labels, + normalize_weight, + inertia, + raft::identity_op{}); +} + +void fit_predict_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::detail::fit_predict(handle, + params, + X, + batch_size, + sample_weight, + centroids, + labels, + inertia, + n_iter, + raft::identity_op{}); +} + } // namespace cuvs::cluster::kmeans diff --git a/python/cuvs/cuvs/tests/test_kmeans.py b/python/cuvs/cuvs/tests/test_kmeans.py index cc3b1cf4a4..509d8aacfb 100644 --- a/python/cuvs/cuvs/tests/test_kmeans.py +++ b/python/cuvs/cuvs/tests/test_kmeans.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # diff --git a/python/cuvs_bench/cuvs_bench/get_dataset/__main__.py b/python/cuvs_bench/cuvs_bench/get_dataset/__main__.py index 3c52be00d8..af35ac453c 100644 --- a/python/cuvs_bench/cuvs_bench/get_dataset/__main__.py +++ b/python/cuvs_bench/cuvs_bench/get_dataset/__main__.py @@ -1,9 +1,13 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 +import gzip import os +import shutil +import struct import subprocess +import tarfile import click import h5py @@ -30,6 +34,266 @@ def download_dataset(url, path): f.write(chunk) +def download_with_wget(url, path): + """Download file using wget (better for large FTP files).""" + if not os.path.exists(path): + print(f"downloading {url} -> {path}...") + subprocess.run(["wget", "-O", path, url], check=True) + + +def read_bvecs(filename, n_vectors=None): + """ + Read .bvecs file format from TEXMEX. + Format: each vector is [dim (4 bytes int)] [dim bytes uint8 data] + Returns float32 array. + """ + print(f"Reading {filename}...") + with open(filename, "rb") as f: + # Read dimension from first vector + dim = struct.unpack("i", f.read(4))[0] + f.seek(0) + + # Calculate number of vectors + f.seek(0, 2) # Seek to end + file_size = f.tell() + record_size = 4 + dim + total_vectors = file_size // record_size + f.seek(0) + + if n_vectors is None: + n_vectors = total_vectors + else: + n_vectors = min(n_vectors, total_vectors) + + print(f" Reading {n_vectors:,} vectors of dimension {dim}") + + data = np.zeros((n_vectors, dim), dtype=np.float32) + for i in range(n_vectors): + f.read(4) # Skip dimension + vec = np.frombuffer(f.read(dim), dtype=np.uint8) + data[i] = vec.astype(np.float32) + + if (i + 1) % 10_000_000 == 0: + print(f" Loaded {i + 1:,} / {n_vectors:,} vectors...") + + return data + + +def read_ivecs(filename, n_vectors=None): + """ + Read .ivecs file format (groundtruth neighbors). + Format: each vector is [dim (4 bytes int)] [dim int32 values] + """ + print(f"Reading {filename}...") + with open(filename, "rb") as f: + dim = struct.unpack("i", f.read(4))[0] + f.seek(0) + + f.seek(0, 2) + file_size = f.tell() + record_size = 4 + dim * 4 + total_vectors = file_size // record_size + f.seek(0) + + if n_vectors is None: + n_vectors = total_vectors + else: + n_vectors = min(n_vectors, total_vectors) + + print(f" Reading {n_vectors:,} vectors of dimension {dim}") + + data = np.zeros((n_vectors, dim), dtype=np.int32) + for i in range(n_vectors): + d = struct.unpack("i", f.read(4))[0] + data[i] = np.frombuffer(f.read(dim * 4), dtype=np.int32) + + return data + + +def read_fvecs(filename, n_vectors=None): + """ + Read .fvecs file format. + Format: each vector is [dim (4 bytes int)] [dim float32 values] + """ + print(f"Reading {filename}...") + with open(filename, "rb") as f: + dim = struct.unpack("i", f.read(4))[0] + f.seek(0) + + f.seek(0, 2) + file_size = f.tell() + record_size = 4 + dim * 4 + total_vectors = file_size // record_size + f.seek(0) + + if n_vectors is None: + n_vectors = total_vectors + else: + n_vectors = min(n_vectors, total_vectors) + + print(f" Reading {n_vectors:,} vectors of dimension {dim}") + + data = np.zeros((n_vectors, dim), dtype=np.float32) + for i in range(n_vectors): + f.read(4) # Skip dimension + data[i] = np.frombuffer(f.read(dim * 4), dtype=np.float32) + + if (i + 1) % 10_000_000 == 0: + print(f" Loaded {i + 1:,} / {n_vectors:,} vectors...") + + return data + + +def write_fbin(filename, data): + """Write data in .fbin format (used by cuVS benchmarks).""" + print(f"Writing {filename}...") + n, d = data.shape + with open(filename, "wb") as f: + f.write(struct.pack("i", n)) + f.write(struct.pack("i", d)) + data.astype(np.float32).tofile(f) + print(f" Wrote {n:,} x {d} float32 array") + + +def write_ibin(filename, data): + """Write data in .ibin format (used by cuVS benchmarks for groundtruth).""" + print(f"Writing {filename}...") + n, d = data.shape + with open(filename, "wb") as f: + f.write(struct.pack("i", n)) + f.write(struct.pack("i", d)) + data.astype(np.int32).tofile(f) + print(f" Wrote {n:,} x {d} int32 array") + + +def download_sift1b(ann_bench_data_path, n_base_vectors=None): + """ + Download and convert SIFT1B dataset from TEXMEX corpus. + http://corpus-texmex.irisa.fr/ + + The dataset contains: + - bigann_base.bvecs: 1B base vectors (128-dim uint8) + - bigann_query.bvecs: 10K query vectors + - bigann_learn.bvecs: 100M learning vectors + - bigann_gnd.tar.gz: Groundtruth for various subset sizes + """ + base_url = "ftp://ftp.irisa.fr/local/texmex/corpus" + + # Create output directory + if n_base_vectors is None: + output_dir = os.path.join(ann_bench_data_path, "sift-1B") + else: + size_suffix = f"{n_base_vectors // 1_000_000}M" + output_dir = os.path.join(ann_bench_data_path, f"sift-{size_suffix}") + + os.makedirs(output_dir, exist_ok=True) + + # Temporary directory for downloads + tmp_dir = os.path.join(ann_bench_data_path, "tmp_sift1b") + os.makedirs(tmp_dir, exist_ok=True) + + try: + # Download and process base vectors + base_gz = os.path.join(tmp_dir, "bigann_base.bvecs.gz") + base_file = os.path.join(tmp_dir, "bigann_base.bvecs") + + if not os.path.exists(os.path.join(output_dir, "base.fbin")): + if not os.path.exists(base_file): + if not os.path.exists(base_gz): + download_with_wget( + f"{base_url}/bigann_base.bvecs.gz", base_gz + ) + print(f"Decompressing {base_gz}...") + with gzip.open(base_gz, "rb") as f_in: + with open(base_file, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + + base_data = read_bvecs(base_file, n_base_vectors) + write_fbin(os.path.join(output_dir, "base.fbin"), base_data) + del base_data + + # Download and process query vectors + query_gz = os.path.join(tmp_dir, "bigann_query.bvecs.gz") + query_file = os.path.join(tmp_dir, "bigann_query.bvecs") + + if not os.path.exists(os.path.join(output_dir, "query.fbin")): + if not os.path.exists(query_file): + if not os.path.exists(query_gz): + download_with_wget( + f"{base_url}/bigann_query.bvecs.gz", query_gz + ) + print(f"Decompressing {query_gz}...") + with gzip.open(query_gz, "rb") as f_in: + with open(query_file, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + + query_data = read_bvecs(query_file) + write_fbin(os.path.join(output_dir, "query.fbin"), query_data) + del query_data + + # Download and process groundtruth + gnd_tar = os.path.join(tmp_dir, "bigann_gnd.tar.gz") + + if not os.path.exists( + os.path.join(output_dir, "groundtruth.neighbors.ibin") + ): + if not os.path.exists(gnd_tar): + download_with_wget(f"{base_url}/bigann_gnd.tar.gz", gnd_tar) + + print(f"Extracting {gnd_tar}...") + with tarfile.open(gnd_tar, "r:gz") as tar: + tar.extractall(tmp_dir) + + # Choose appropriate groundtruth file based on n_base_vectors + if n_base_vectors is None: + gnd_file = os.path.join(tmp_dir, "gnd", "idx_1000M.ivecs") + else: + # Find closest available groundtruth + sizes = [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000] + target_m = n_base_vectors // 1_000_000 + closest = min( + sizes, + key=lambda x: abs(x - target_m) + if x >= target_m + else float("inf"), + ) + if closest > target_m: + closest = max( + [s for s in sizes if s <= target_m], default=sizes[0] + ) + gnd_file = os.path.join( + tmp_dir, "gnd", f"idx_{closest}M.ivecs" + ) + + if os.path.exists(gnd_file): + gnd_data = read_ivecs(gnd_file) + write_ibin( + os.path.join(output_dir, "groundtruth.neighbors.ibin"), + gnd_data, + ) + else: + print(f"Warning: Groundtruth file {gnd_file} not found") + # List available files + gnd_dir = os.path.join(tmp_dir, "gnd") + if os.path.exists(gnd_dir): + print( + f"Available groundtruth files: {os.listdir(gnd_dir)}" + ) + + print(f"\nSIFT1B dataset prepared in: {output_dir}") + print("Files:") + for f in os.listdir(output_dir): + fpath = os.path.join(output_dir, f) + size_mb = os.path.getsize(fpath) / 1e6 + print(f" {f}: {size_mb:.1f} MB") + + finally: + # Optionally clean up temp files + # shutil.rmtree(tmp_dir) + print(f"\nTemp files kept in: {tmp_dir}") + print("You can delete them manually after verifying the dataset.") + + def convert_hdf5_to_fbin(path, normalize): scripts_path = os.path.dirname(os.path.realpath(__file__)) ann_bench_scripts_path = os.path.join(scripts_path, "hdf5_to_fbin.py") @@ -147,7 +411,8 @@ def get_default_dataset_path(): @click.option( "--dataset", default="glove-100-angular", - help="Dataset to download.", + help="Dataset to download. Use 'sift-1B' for TEXMEX SIFT1B dataset, " + "or 'sift-100M', 'sift-10M' etc for subsets.", ) @click.option( "--test-data-n-train", @@ -185,6 +450,13 @@ def get_default_dataset_path(): is_flag=True, help="Normalize cosine distance to inner product.", ) +@click.option( + "--n-vectors", + default=None, + type=int, + help="Number of base vectors to use (for sift-1B dataset). " + "E.g., 300000000 for 300M vectors.", +) def main( dataset, test_data_n_train, @@ -194,6 +466,7 @@ def main( test_data_output_file, dataset_path, normalize, + n_vectors, ): # Compute default dataset_path if not provided. if dataset_path is None: @@ -210,6 +483,24 @@ def main( metric="euclidean", dataset_path=dataset_path, ) + elif dataset.startswith("sift-"): + # Handle SIFT1B and subsets from TEXMEX + # Parse dataset name for size hint (e.g., sift-100M, sift-1B) + size_str = dataset.split("-")[1].upper() + if n_vectors is None: + if size_str == "1B": + n_vectors = None # Use all 1B vectors + elif size_str.endswith("M"): + n_vectors = int(size_str[:-1]) * 1_000_000 + elif size_str.endswith("K"): + n_vectors = int(size_str[:-1]) * 1_000 + else: + try: + n_vectors = int(size_str) + except ValueError: + n_vectors = None + + download_sift1b(dataset_path, n_vectors) else: download(dataset, normalize, dataset_path) From 9d87a5f6b9016be35dc6522519cd10fbdd2c52c3 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Sun, 15 Feb 2026 01:17:56 -0800 Subject: [PATCH 13/81] rm lower precision --- c/src/cluster/kmeans.cpp | 81 -------- cpp/CMakeLists.txt | 1 - cpp/include/cuvs/cluster/kmeans.hpp | 156 +------------- cpp/src/cluster/detail/kmeans_batched.cuh | 8 +- cpp/src/cluster/detail/kmeans_common.cuh | 2 - .../kmeans_fit_batched_int8_uint8_half.cu | 191 ------------------ python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 14 +- 7 files changed, 8 insertions(+), 445 deletions(-) delete mode 100644 cpp/src/cluster/kmeans_fit_batched_int8_uint8_half.cu diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index 52b098ee2e..370bebdb67 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -6,8 +6,6 @@ #include #include -#include - #include #include #include @@ -246,76 +244,6 @@ void _fit_batched(cuvsResources_t res, *inertia = inertia_temp; *n_iter = n_iter_temp; } - -// Specialized version for integer/half types where X is InputT but centroids are float -template -void _fit_batched_mixed(cuvsResources_t res, - const cuvsKMeansParams& params, - DLManagedTensor* X_tensor, - IdxT batch_size, - DLManagedTensor* sample_weight_tensor, - DLManagedTensor* centroids_tensor, - double* inertia, - IdxT* n_iter) -{ - auto X = X_tensor->dl_tensor; - auto centroids = centroids_tensor->dl_tensor; - auto res_ptr = reinterpret_cast(res); - auto n_samples = static_cast(X.shape[0]); - auto n_features = static_cast(X.shape[1]); - - // X must be on host (CPU) memory - if (X.device.device_type != kDLCPU) { - RAFT_FAIL("X dataset must be on host (CPU) memory for fit_batched"); - } - - // centroids must be on device memory and float type - if (!cuvs::core::is_dlpack_device_compatible(centroids)) { - RAFT_FAIL("centroids must be on device memory"); - } - if (centroids.dtype.code != kDLFloat || centroids.dtype.bits != 32) { - RAFT_FAIL("centroids must be float32 for integer/half input types"); - } - - using device_matrix_type = raft::device_matrix_view; - - // Create host matrix view from X (InputT) - auto X_view = raft::make_host_matrix_view( - reinterpret_cast(X.data), n_samples, n_features); - - // Create device matrix view for centroids (float) - auto centroids_view = cuvs::core::from_dlpack(centroids_tensor); - - // Handle optional sample weights (float) - std::optional> sample_weight; - if (sample_weight_tensor != NULL) { - auto sw = sample_weight_tensor->dl_tensor; - if (sw.device.device_type != kDLCPU) { - RAFT_FAIL("sample_weight must be on host (CPU) memory for fit_batched"); - } - if (sw.dtype.code != kDLFloat || sw.dtype.bits != 32) { - RAFT_FAIL("sample_weight must be float32 for integer/half input types"); - } - sample_weight = raft::make_host_vector_view( - reinterpret_cast(sw.data), n_samples); - } - - float inertia_temp; - IdxT n_iter_temp; - - auto kmeans_params = convert_params(params); - cuvs::cluster::kmeans::fit_batched(*res_ptr, - kmeans_params, - X_view, - batch_size, - sample_weight, - centroids_view, - raft::make_host_scalar_view(&inertia_temp), - raft::make_host_scalar_view(&n_iter_temp)); - - *inertia = inertia_temp; - *n_iter = n_iter_temp; -} } // namespace extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params) @@ -426,15 +354,6 @@ extern "C" cuvsError_t cuvsKMeansFitBatched(cuvsResources_t res, _fit_batched(res, *params, X, batch_size, sample_weight, centroids, inertia, n_iter); } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { _fit_batched(res, *params, X, batch_size, sample_weight, centroids, inertia, n_iter); - } else if (dataset.dtype.code == kDLUInt && dataset.dtype.bits == 8) { - // uint8 input -> float centroids - _fit_batched_mixed(res, *params, X, batch_size, sample_weight, centroids, inertia, n_iter); - } else if (dataset.dtype.code == kDLInt && dataset.dtype.bits == 8) { - // int8 input -> float centroids - _fit_batched_mixed(res, *params, X, batch_size, sample_weight, centroids, inertia, n_iter); - } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 16) { - // half input -> float centroids - _fit_batched_mixed(res, *params, X, batch_size, sample_weight, centroids, inertia, n_iter); } else { RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", dataset.dtype.code, diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index fe27415948..6313db71ca 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -367,7 +367,6 @@ if(NOT BUILD_CPU_ONLY) src/cluster/kmeans_balanced_fit_predict_int8.cu src/cluster/kmeans_balanced_predict_int8.cu src/cluster/kmeans_balanced_predict_uint8.cu - src/cluster/kmeans_fit_batched_int8_uint8_half.cu src/cluster/kmeans_transform_double.cu src/cluster/kmeans_transform_float.cu src/cluster/single_linkage_float.cu diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index b8be9d1f0a..9e6fec4b67 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -330,75 +330,6 @@ void fit_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter); -/** - * @brief Find clusters with k-means algorithm using batched processing for uint8 data. - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model (including update_mode). - * @param[in] X Training instances on HOST memory (uint8). - * [dim = n_samples x n_features] - * @param[in] batch_size Number of samples to process per batch. - * @param[in] sample_weight Optional weights for each observation in X (on host, float). - * @param[inout] centroids Cluster centers on device (float, as centroids are averages). - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances to nearest centroid. - * @param[out] n_iter Number of iterations run. - */ -void fit_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); - -/** - * @brief Find clusters with k-means algorithm using batched processing for int8 data. - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model (including update_mode). - * @param[in] X Training instances on HOST memory (int8). - * [dim = n_samples x n_features] - * @param[in] batch_size Number of samples to process per batch. - * @param[in] sample_weight Optional weights for each observation in X (on host, float). - * @param[inout] centroids Cluster centers on device (float, as centroids are averages). - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances to nearest centroid. - * @param[out] n_iter Number of iterations run. - */ -void fit_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); - -/** - * @brief Find clusters with k-means algorithm using batched processing for half data. - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model (including update_mode). - * @param[in] X Training instances on HOST memory (half). - * [dim = n_samples x n_features] - * @param[in] batch_size Number of samples to process per batch. - * @param[in] sample_weight Optional weights for each observation in X (on host, float). - * @param[inout] centroids Cluster centers on device (float, as centroids are averages). - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances to nearest centroid. - * @param[out] n_iter Number of iterations run. - */ -void fit_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); - /** * @defgroup predict_batched Batched K-Means Predict * @{ @@ -418,8 +349,7 @@ void fit_batched(raft::resources const& handle, * @param[in] centroids Cluster centers on device. [dim = n_clusters x n_features] * @param[out] labels Predicted cluster labels on HOST memory. [dim = n_samples] * @param[in] normalize_weight Whether to normalize sample weights. - * @param[out] inertia Sum of squared distances to nearest centroid (only if - * params.final_inertia_check is true, otherwise 0). + * @param[out] inertia Sum of squared distances to nearest centroid. */ void predict_batched(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -444,51 +374,6 @@ void predict_batched(raft::resources const& handle, bool normalize_weight, raft::host_scalar_view inertia); -/** - * @brief Predict cluster labels for host uint8 data using batched processing. - * - * Input data is uint8 on host, centroids are float on device. - */ -void predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia); - -/** - * @brief Predict cluster labels for host int8 data using batched processing. - * - * Input data is int8 on host, centroids are float on device. - */ -void predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia); - -/** - * @brief Predict cluster labels for host half data using batched processing. - * - * Input data is half on host, centroids are float on device. - */ -void predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia); - /** * @} */ @@ -536,45 +421,6 @@ void fit_predict_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter); -/** - * @brief Fit k-means and predict cluster labels for uint8 data using batched processing. - */ -void fit_predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); - -/** - * @brief Fit k-means and predict cluster labels for int8 data using batched processing. - */ -void fit_predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); - -/** - * @brief Fit k-means and predict cluster labels for half data using batched processing. - */ -void fit_predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); - /** * @} */ diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 834c2a93ac..54bdb767cc 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -310,8 +310,8 @@ void finalize_centroids(raft::resources const& handle, * This is a unified function that handles both same-type (T == MathT) and * mixed-type (T != MathT) cases, following the kmeans_balanced pattern. * - * @tparam T Input data type (float, double, uint8_t, int8_t, half) - * @tparam MathT Computation/centroid type (typically float) + * @tparam T Input data type (float, double) + * @tparam MathT Computation/centroid type (same as T for float/double) * @tparam IdxT Index type (int, int64_t) * @tparam MappingOpT Mapping operator (T -> MathT) * @@ -746,8 +746,8 @@ void fit(raft::resources const& handle, /** * @brief Predict cluster labels for host data using batched processing. * - * @tparam T Input data type (float, double, uint8_t, int8_t, half) - * @tparam MathT Computation/centroid type (typically float) + * @tparam T Input data type (float, double) + * @tparam MathT Computation/centroid type (same as T for float/double) * @tparam IdxT Index type (int, int64_t) * @tparam MappingOpT Mapping operator (T -> MathT) */ diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 62065eae7b..0db56a2f1d 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -507,7 +507,6 @@ void compute_centroid_adjustments( workspace.resize(n_samples, stream); - // Compute weighted sum of samples per cluster raft::linalg::reduce_rows_by_key(const_cast(X.data_handle()), X.extent(1), cluster_labels, @@ -519,7 +518,6 @@ void compute_centroid_adjustments( centroid_sums.data_handle(), stream); - // Compute sum of weights per cluster raft::linalg::reduce_cols_by_key(sample_weights.data_handle(), cluster_labels, weight_per_cluster.data_handle(), diff --git a/cpp/src/cluster/kmeans_fit_batched_int8_uint8_half.cu b/cpp/src/cluster/kmeans_fit_batched_int8_uint8_half.cu deleted file mode 100644 index 6910eae722..0000000000 --- a/cpp/src/cluster/kmeans_fit_batched_int8_uint8_half.cu +++ /dev/null @@ -1,191 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -#include "../neighbors/detail/ann_utils.cuh" -#include "detail/kmeans_batched.cuh" -#include - -#include - -namespace cuvs::cluster::kmeans { - -// Use the mapping struct from ann_utils for T -> float conversion -using cuvs::spatial::knn::detail::utils::mapping; - -// Public API implementations - X is T (uint8/int8/half) but centroids are float -void fit_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - cuvs::cluster::kmeans::detail::fit( - handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, mapping{}); -} - -void fit_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - cuvs::cluster::kmeans::detail::fit( - handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, mapping{}); -} - -void fit_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - cuvs::cluster::kmeans::detail::fit( - handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, mapping{}); -} - -// predict_batched implementations -void predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia) -{ - cuvs::cluster::kmeans::detail::predict(handle, - params, - X, - batch_size, - sample_weight, - centroids, - labels, - normalize_weight, - inertia, - mapping{}); -} - -void predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia) -{ - cuvs::cluster::kmeans::detail::predict(handle, - params, - X, - batch_size, - sample_weight, - centroids, - labels, - normalize_weight, - inertia, - mapping{}); -} - -void predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia) -{ - cuvs::cluster::kmeans::detail::predict(handle, - params, - X, - batch_size, - sample_weight, - centroids, - labels, - normalize_weight, - inertia, - mapping{}); -} - -// fit_predict_batched implementations -void fit_predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - cuvs::cluster::kmeans::detail::fit_predict(handle, - params, - X, - batch_size, - sample_weight, - centroids, - labels, - inertia, - n_iter, - mapping{}); -} - -void fit_predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - cuvs::cluster::kmeans::detail::fit_predict(handle, - params, - X, - batch_size, - sample_weight, - centroids, - labels, - inertia, - n_iter, - mapping{}); -} - -void fit_predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - cuvs::cluster::kmeans::detail::fit_predict(handle, - params, - X, - batch_size, - sample_weight, - centroids, - labels, - inertia, - n_iter, - mapping{}); -} - -} // namespace cuvs::cluster::kmeans diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index 4eb6cf3c6b..dbc33ca179 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -453,7 +453,7 @@ def fit_batched( Parameters to use to fit KMeans model X : numpy array or array with __array_interface__ Input HOST memory array shape (n_samples, n_features). - Must be C-contiguous. Supported dtypes: float32, float64, uint8, int8, float16. + Must be C-contiguous. Supported dtypes: float32, float64. batch_size : int Number of samples to process per batch. Recommended: 500K-2M depending on GPU memory. @@ -497,9 +497,7 @@ def fit_batched( if not X.flags['C_CONTIGUOUS']: X = np.ascontiguousarray(X) - _check_input_array(wrap_array(X), [np.dtype('float32'), np.dtype('float64'), - np.dtype('uint8'), np.dtype('int8'), - np.dtype('float16')]) + _check_input_array(wrap_array(X), [np.dtype('float32'), np.dtype('float64')]) cdef int64_t n_samples = X.shape[0] cdef int64_t n_features = X.shape[1] @@ -515,14 +513,8 @@ def fit_batched( cdef int64_t c_batch_size = batch_size if centroids is None: - # For integer/half types, centroids are always float32 - # (centroids are averages, can't be represented as integers) - if X.dtype in (np.uint8, np.int8, np.float16): - centroids_dtype = np.float32 - else: - centroids_dtype = X.dtype centroids = device_ndarray.empty((params.n_clusters, n_features), - dtype=centroids_dtype) + dtype=X.dtype) centroids_ai = wrap_array(centroids) cdef cydlpack.DLManagedTensor* centroids_dlpack = \ From a618ed5ac96d5a085eec5326d60ef49a89799fc4 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Sun, 15 Feb 2026 21:22:15 -0800 Subject: [PATCH 14/81] rm unnecessary unary-ops --- c/src/cluster/kmeans.cpp | 6 +- cpp/include/cuvs/cluster/kmeans.hpp | 12 - cpp/src/cluster/detail/kmeans_batched.cuh | 366 ++++++++-------------- cpp/src/cluster/kmeans_fit_double.cu | 32 +- cpp/src/cluster/kmeans_fit_float.cu | 32 +- 5 files changed, 155 insertions(+), 293 deletions(-) diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index 370bebdb67..fb62daf966 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -208,16 +208,12 @@ void _fit_batched(cuvsResources_t res, RAFT_FAIL("centroids must be on device memory"); } - using device_matrix_type = raft::device_matrix_view; - // Create host matrix view from X auto X_view = raft::make_host_matrix_view( reinterpret_cast(X.data), n_samples, n_features); - // Create device matrix view for centroids - auto centroids_view = cuvs::core::from_dlpack(centroids_tensor); + auto centroids_view = cuvs::core::from_dlpack>(centroids_tensor); - // Handle optional sample weights std::optional> sample_weight; if (sample_weight_tensor != NULL) { auto sw = sample_weight_tensor->dl_tensor; diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 9e6fec4b67..cc4333c16d 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -191,18 +191,6 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * This version supports out-of-core computation where the dataset resides * on the host. Data is processed in batches, streaming from host to device. * - * Two centroid update modes are supported (controlled by params.update_mode): - * - * - **FullBatch** (default): Standard Lloyd's algorithm. Partial sums are - * accumulated across all batches, and centroids are updated once at the - * end of each iteration. This is mathematically equivalent to standard - * k-means and produces deterministic results. - * - * - **MiniBatch**: Mini-batch k-means. Centroids are updated incrementally - * after each randomly sampled batch using an online learning rule. This - * converges faster but may produce slightly different results each run. - * Useful for very large datasets where full passes are expensive. - * * @code{.cpp} * #include * #include diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 54bdb767cc..2495dab001 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -23,7 +23,6 @@ #include #include #include -#include #include #include #include @@ -42,7 +41,6 @@ #include #include #include -#include #include namespace cuvs::cluster::kmeans::detail { @@ -51,16 +49,13 @@ namespace cuvs::cluster::kmeans::detail { * @brief Sample data from host to device for initialization, with optional type conversion * * @tparam T Input data type - * @tparam MathT Computation/output type * @tparam IdxT Index type - * @tparam MappingOpT Mapping operator (T -> MathT) */ -template +template void prepare_init_sample(raft::resources const& handle, raft::host_matrix_view X, - raft::device_matrix_view X_sample, - uint64_t seed, - MappingOpT mapping_op) + raft::device_matrix_view X_sample, + uint64_t seed) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); @@ -83,33 +78,21 @@ void prepare_init_sample(raft::resources const& handle, n_features * sizeof(T)); } - if constexpr (std::is_same_v) { - // Same type: direct copy - raft::copy(X_sample.data_handle(), host_sample.data(), host_sample.size(), stream); - } else { - // Different types: copy raw, then convert on GPU - auto raw_sample = raft::make_device_matrix(handle, n_samples_out, n_features); - raft::copy(raw_sample.data_handle(), host_sample.data(), host_sample.size(), stream); - raft::linalg::unaryOp( - X_sample.data_handle(), raw_sample.data_handle(), host_sample.size(), mapping_op, stream); - } + raft::copy(X_sample.data_handle(), host_sample.data(), host_sample.size(), stream); } /** * @brief Initialize centroids using k-means++ on a sample of the host data * * @tparam T Input data type - * @tparam MathT Computation/centroid type * @tparam IdxT Index type - * @tparam MappingOpT Mapping operator (T -> MathT) */ -template +template void init_centroids_from_host_sample(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, raft::host_matrix_view X, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace, - MappingOpT mapping_op) + raft::device_matrix_view centroids, + rmm::device_uvector& workspace) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); @@ -123,19 +106,19 @@ void init_centroids_from_host_sample(raft::resources const& handle, RAFT_LOG_DEBUG("KMeans batched: sampling %zu points for initialization", init_sample_size); - // Sample data from host to device (with conversion if needed) - auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); - prepare_init_sample(handle, X, init_sample.view(), params.rng_state.seed, mapping_op); + // Sample data from host to device + auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); + prepare_init_sample(handle, X, init_sample.view(), params.rng_state.seed); if (params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { - auto init_sample_view = raft::make_device_matrix_view( + auto init_sample_view = raft::make_device_matrix_view( init_sample.data_handle(), init_sample_size, n_features); if (params.oversampling_factor == 0) { - cuvs::cluster::kmeans::detail::kmeansPlusPlus( + cuvs::cluster::kmeans::detail::kmeansPlusPlus( handle, params, init_sample_view, centroids, workspace); } else { - cuvs::cluster::kmeans::detail::initScalableKMeansPlusPlus( + cuvs::cluster::kmeans::detail::initScalableKMeansPlusPlus( handle, params, init_sample_view, centroids, workspace); } } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { @@ -307,34 +290,27 @@ void finalize_centroids(raft::resources const& handle, /** * @brief Main fit function for batched k-means with host data * - * This is a unified function that handles both same-type (T == MathT) and - * mixed-type (T != MathT) cases, following the kmeans_balanced pattern. - * * @tparam T Input data type (float, double) - * @tparam MathT Computation/centroid type (same as T for float/double) * @tparam IdxT Index type (int, int64_t) - * @tparam MappingOpT Mapping operator (T -> MathT) * * @param[in] handle RAFT resources handle * @param[in] params K-means parameters * @param[in] X Input data on HOST [n_samples x n_features] * @param[in] batch_size Number of samples to process per batch - * @param[in] sample_weight Optional weights per sample (on host, MathT type) + * @param[in] sample_weight Optional weights per sample (on host) * @param[inout] centroids Initial/output cluster centers [n_clusters x n_features] * @param[out] inertia Sum of squared distances to nearest centroid * @param[out] n_iter Number of iterations run - * @param[in] mapping_op Mapping operator for T -> MathT conversion */ -template +template void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, raft::host_matrix_view X, IdxT batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter, - MappingOpT mapping_op) + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); @@ -360,38 +336,31 @@ void fit(raft::resources const& handle, // Initialize centroids from a sample of host data if (params.init != cuvs::cluster::kmeans::params::InitMethod::Array) { - init_centroids_from_host_sample(handle, params, X, centroids, workspace, mapping_op); + init_centroids_from_host_sample(handle, params, X, centroids, workspace); } // Allocate device buffers - // For mixed types, we need a raw buffer for T and a converted buffer for MathT - // For same types, we only need one buffer - rmm::device_uvector batch_data_raw(0, stream); - if constexpr (!std::is_same_v) { - batch_data_raw.resize(batch_size * n_features, stream); - } - - auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); - auto batch_weights = raft::make_device_vector(handle, batch_size); + auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); + auto batch_weights = raft::make_device_vector(handle, batch_size); auto minClusterAndDistance = - raft::make_device_vector, IdxT>(handle, batch_size); - auto L2NormBatch = raft::make_device_vector(handle, batch_size); - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + raft::make_device_vector, IdxT>(handle, batch_size); + auto L2NormBatch = raft::make_device_vector(handle, batch_size); + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); // Accumulators for centroid computation - auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); - auto cluster_counts = raft::make_device_vector(handle, n_clusters); - auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); + auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto cluster_counts = raft::make_device_vector(handle, n_clusters); + auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); // For mini-batch mode: track total counts for learning rate calculation - auto total_counts = raft::make_device_vector(handle, n_clusters); + auto total_counts = raft::make_device_vector(handle, n_clusters); // Host buffer for batch data auto host_batch_buffer = raft::make_host_matrix(batch_size, n_features); // Cluster cost for convergence check - rmm::device_scalar clusterCostD(stream); - MathT priorClusteringCost = 0; + rmm::device_scalar clusterCostD(stream); + T priorClusteringCost = 0; // Check update mode bool use_minibatch = @@ -417,7 +386,7 @@ void fit(raft::resources const& handle, } // Initialize total_counts once for mini-batch mode (cumulative across all iterations) - if (use_minibatch) { raft::matrix::fill(handle, total_counts.view(), MathT{0}); } + if (use_minibatch) { raft::matrix::fill(handle, total_counts.view(), T{0}); } // Main iteration loop // For mini-batch: 1 iteration = 1 batch (process max_iter batches total) @@ -428,7 +397,7 @@ void fit(raft::resources const& handle, // Save old centroids for convergence check raft::copy(new_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); - MathT total_cost = 0; + T total_cost = 0; if (use_minibatch) { // ========== MINI-BATCH MODE: Process ONE batch per iteration ========== @@ -447,35 +416,23 @@ void fit(raft::resources const& handle, n_features * sizeof(T)); } - if constexpr (std::is_same_v) { - raft::copy(batch_data.data_handle(), - host_batch_buffer.data_handle(), - current_batch_size * n_features, - stream); - } else { - raft::copy(batch_data_raw.data(), - host_batch_buffer.data_handle(), - current_batch_size * n_features, - stream); - raft::linalg::unaryOp(batch_data.data_handle(), - batch_data_raw.data(), - current_batch_size * n_features, - mapping_op, - stream); - } + raft::copy(batch_data.data_handle(), + host_batch_buffer.data_handle(), + current_batch_size * n_features, + stream); // Mini-batch uses uniform weights (sampling already accounts for importance) auto batch_weights_fill_view = - raft::make_device_vector_view(batch_weights.data_handle(), current_batch_size); - raft::matrix::fill(handle, batch_weights_fill_view, MathT{1}); + raft::make_device_vector_view(batch_weights.data_handle(), current_batch_size); + raft::matrix::fill(handle, batch_weights_fill_view, T{1}); // Create views - auto batch_data_view = raft::make_device_matrix_view( + auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); - auto batch_weights_view = raft::make_device_vector_view( + auto batch_weights_view = raft::make_device_vector_view( batch_weights.data_handle(), current_batch_size); auto minClusterAndDistance_view = - raft::make_device_vector_view, IdxT>( + raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); // Compute L2 norms if needed @@ -489,12 +446,12 @@ void fit(raft::resources const& handle, } // Find nearest centroid - auto centroids_const = raft::make_device_matrix_view( + auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); - auto L2NormBatch_const = raft::make_device_vector_view( - L2NormBatch.data_handle(), current_batch_size); + auto L2NormBatch_const = + raft::make_device_vector_view(L2NormBatch.data_handle(), current_batch_size); - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( handle, batch_data_view, centroids_const, @@ -507,74 +464,62 @@ void fit(raft::resources const& handle, workspace); // Zero and accumulate for this batch - raft::matrix::fill(handle, centroid_sums.view(), MathT{0}); - raft::matrix::fill(handle, cluster_counts.view(), MathT{0}); + raft::matrix::fill(handle, centroid_sums.view(), T{0}); + raft::matrix::fill(handle, cluster_counts.view(), T{0}); auto minClusterAndDistance_const = - raft::make_device_vector_view, IdxT>( + raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); - accumulate_batch_centroids(handle, - batch_data_view, - minClusterAndDistance_const, - batch_weights_view, - centroid_sums.view(), - cluster_counts.view()); + accumulate_batch_centroids(handle, + batch_data_view, + minClusterAndDistance_const, + batch_weights_view, + centroid_sums.view(), + cluster_counts.view()); // Update centroids with online learning - auto centroid_sums_const = raft::make_device_matrix_view( + auto centroid_sums_const = raft::make_device_matrix_view( centroid_sums.data_handle(), n_clusters, n_features); auto cluster_counts_const = - raft::make_device_vector_view(cluster_counts.data_handle(), n_clusters); + raft::make_device_vector_view(cluster_counts.data_handle(), n_clusters); - minibatch_update_centroids( + minibatch_update_centroids( handle, centroids, centroid_sums_const, cluster_counts_const, total_counts.view()); } else { // ========== FULL-BATCH MODE: Process ALL data per iteration ========== - raft::matrix::fill(handle, centroid_sums.view(), MathT{0}); - raft::matrix::fill(handle, cluster_counts.view(), MathT{0}); + raft::matrix::fill(handle, centroid_sums.view(), T{0}); + raft::matrix::fill(handle, cluster_counts.view(), T{0}); for (IdxT batch_idx = 0; batch_idx < n_samples; batch_idx += batch_size) { IdxT current_batch_size = std::min(batch_size, n_samples - batch_idx); - // Sequential copy from host to device - if constexpr (std::is_same_v) { - raft::copy(batch_data.data_handle(), - X.data_handle() + batch_idx * n_features, - current_batch_size * n_features, - stream); - } else { - raft::copy(batch_data_raw.data(), - X.data_handle() + batch_idx * n_features, - current_batch_size * n_features, - stream); - raft::linalg::unaryOp(batch_data.data_handle(), - batch_data_raw.data(), - current_batch_size * n_features, - mapping_op, - stream); - } + // Copy from host to device + raft::copy(batch_data.data_handle(), + X.data_handle() + batch_idx * n_features, + current_batch_size * n_features, + stream); // Set weights - auto batch_weights_fill_view = raft::make_device_vector_view( - batch_weights.data_handle(), current_batch_size); + auto batch_weights_fill_view = + raft::make_device_vector_view(batch_weights.data_handle(), current_batch_size); if (sample_weight) { raft::copy(batch_weights.data_handle(), sample_weight->data_handle() + batch_idx, current_batch_size, stream); } else { - raft::matrix::fill(handle, batch_weights_fill_view, MathT{1}); + raft::matrix::fill(handle, batch_weights_fill_view, T{1}); } // Create views - auto batch_data_view = raft::make_device_matrix_view( + auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); - auto batch_weights_view = raft::make_device_vector_view( + auto batch_weights_view = raft::make_device_vector_view( batch_weights.data_handle(), current_batch_size); auto minClusterAndDistance_view = - raft::make_device_vector_view, IdxT>( + raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); // Compute L2 norms if needed @@ -588,12 +533,12 @@ void fit(raft::resources const& handle, } // Find nearest centroid - auto centroids_const = raft::make_device_matrix_view( + auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); - auto L2NormBatch_const = raft::make_device_vector_view( + auto L2NormBatch_const = raft::make_device_vector_view( L2NormBatch.data_handle(), current_batch_size); - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( handle, batch_data_view, centroids_const, @@ -607,15 +552,15 @@ void fit(raft::resources const& handle, // Accumulate partial sums auto minClusterAndDistance_const = - raft::make_device_vector_view, IdxT>( + raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); - accumulate_batch_centroids(handle, - batch_data_view, - minClusterAndDistance_const, - batch_weights_view, - centroid_sums.view(), - cluster_counts.view()); + accumulate_batch_centroids(handle, + batch_data_view, + minClusterAndDistance_const, + batch_weights_view, + centroid_sums.view(), + cluster_counts.view()); // Accumulate cost if checking convergence if (params.inertia_check) { @@ -631,19 +576,19 @@ void fit(raft::resources const& handle, } // end batch loop // Finalize centroids after processing all data - auto centroids_const = raft::make_device_matrix_view( + auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); - auto centroid_sums_const = raft::make_device_matrix_view( + auto centroid_sums_const = raft::make_device_matrix_view( centroid_sums.data_handle(), n_clusters, n_features); auto cluster_counts_const = - raft::make_device_vector_view(cluster_counts.data_handle(), n_clusters); + raft::make_device_vector_view(cluster_counts.data_handle(), n_clusters); - finalize_centroids( + finalize_centroids( handle, centroid_sums_const, cluster_counts_const, centroids_const, centroids); } // Compute squared norm of change in centroids - auto sqrdNorm = raft::make_device_scalar(handle, MathT{0}); + auto sqrdNorm = raft::make_device_scalar(handle, T{0}); raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), centroids.size(), raft::sqdiff_op{}, @@ -651,13 +596,13 @@ void fit(raft::resources const& handle, new_centroids.data_handle(), centroids.data_handle()); - MathT sqrdNormError = 0; + T sqrdNormError = 0; raft::copy(&sqrdNormError, sqrdNorm.data_handle(), 1, stream); // Check convergence bool done = false; if (!use_minibatch && params.inertia_check && n_iter[0] > 1) { - MathT delta = total_cost / priorClusteringCost; + T delta = total_cost / priorClusteringCost; if (delta > 1 - params.tol) done = true; priorClusteringCost = total_cost; } @@ -677,27 +622,15 @@ void fit(raft::resources const& handle, for (IdxT offset = 0; offset < n_samples; offset += batch_size) { IdxT current_batch_size = std::min(batch_size, n_samples - offset); - if constexpr (std::is_same_v) { - raft::copy(batch_data.data_handle(), - X.data_handle() + offset * n_features, - current_batch_size * n_features, - stream); - } else { - raft::copy(batch_data_raw.data(), - X.data_handle() + offset * n_features, - current_batch_size * n_features, - stream); - raft::linalg::unaryOp(batch_data.data_handle(), - batch_data_raw.data(), - current_batch_size * n_features, - mapping_op, - stream); - } + raft::copy(batch_data.data_handle(), + X.data_handle() + offset * n_features, + current_batch_size * n_features, + stream); - auto batch_data_view = raft::make_device_matrix_view( + auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); auto minClusterAndDistance_view = - raft::make_device_vector_view, IdxT>( + raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); if (metric == cuvs::distance::DistanceType::L2Expanded || @@ -709,12 +642,12 @@ void fit(raft::resources const& handle, stream); } - auto centroids_const = raft::make_device_matrix_view( + auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); - auto L2NormBatch_const = raft::make_device_vector_view( - L2NormBatch.data_handle(), current_batch_size); + auto L2NormBatch_const = + raft::make_device_vector_view(L2NormBatch.data_handle(), current_batch_size); - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( handle, batch_data_view, centroids_const, @@ -747,21 +680,18 @@ void fit(raft::resources const& handle, * @brief Predict cluster labels for host data using batched processing. * * @tparam T Input data type (float, double) - * @tparam MathT Computation/centroid type (same as T for float/double) * @tparam IdxT Index type (int, int64_t) - * @tparam MappingOpT Mapping operator (T -> MathT) */ -template +template void predict(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, raft::host_matrix_view X, IdxT batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, + std::optional> sample_weight, + raft::device_matrix_view centroids, raft::host_vector_view labels, bool normalize_weight, - raft::host_scalar_view inertia, - MappingOpT mapping_op) + raft::host_scalar_view inertia) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); @@ -776,13 +706,8 @@ void predict(raft::resources const& handle, RAFT_EXPECTS(labels.extent(0) == n_samples, "labels.extent(0) must equal n_samples"); // Allocate device buffers - rmm::device_uvector batch_data_raw(0, stream); - if constexpr (!std::is_same_v) { - batch_data_raw.resize(batch_size * n_features, stream); - } - - auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); - auto batch_weights = raft::make_device_vector(handle, batch_size); + auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); + auto batch_weights = raft::make_device_vector(handle, batch_size); auto batch_labels = raft::make_device_vector(handle, batch_size); inertia[0] = 0; @@ -792,43 +717,31 @@ void predict(raft::resources const& handle, IdxT current_batch_size = std::min(batch_size, n_samples - batch_idx); // Copy batch data from host to device - if constexpr (std::is_same_v) { - raft::copy(batch_data.data_handle(), - X.data_handle() + batch_idx * n_features, - current_batch_size * n_features, - stream); - } else { - raft::copy(batch_data_raw.data(), - X.data_handle() + batch_idx * n_features, - current_batch_size * n_features, - stream); - raft::linalg::unaryOp(batch_data.data_handle(), - batch_data_raw.data(), - current_batch_size * n_features, - mapping_op, - stream); - } + raft::copy(batch_data.data_handle(), + X.data_handle() + batch_idx * n_features, + current_batch_size * n_features, + stream); // Handle weights - std::optional> batch_weight_view; + std::optional> batch_weight_view; if (sample_weight) { raft::copy(batch_weights.data_handle(), sample_weight->data_handle() + batch_idx, current_batch_size, stream); - batch_weight_view = raft::make_device_vector_view( - batch_weights.data_handle(), current_batch_size); + batch_weight_view = raft::make_device_vector_view(batch_weights.data_handle(), + current_batch_size); } // Create views for current batch - auto batch_data_view = raft::make_device_matrix_view( + auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); auto batch_labels_view = raft::make_device_vector_view(batch_labels.data_handle(), current_batch_size); // Call regular predict on this batch - MathT batch_inertia = 0; - cuvs::cluster::kmeans::detail::kmeans_predict( + T batch_inertia = 0; + cuvs::cluster::kmeans::detail::kmeans_predict( handle, params, batch_data_view, @@ -851,44 +764,41 @@ void predict(raft::resources const& handle, /** * @brief Fit k-means and predict cluster labels using batched processing. */ -template +template void fit_predict(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, raft::host_matrix_view X, IdxT batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, + std::optional> sample_weight, + raft::device_matrix_view centroids, raft::host_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter, - MappingOpT mapping_op) + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { // First fit the model - MathT fit_inertia = 0; - fit(handle, - params, - X, - batch_size, - sample_weight, - centroids, - raft::make_host_scalar_view(&fit_inertia), - n_iter, - mapping_op); + T fit_inertia = 0; + fit(handle, + params, + X, + batch_size, + sample_weight, + centroids, + raft::make_host_scalar_view(&fit_inertia), + n_iter); // Then predict labels - auto centroids_const = raft::make_device_matrix_view( + auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), centroids.extent(0), centroids.extent(1)); - predict(handle, - params, - X, - batch_size, - sample_weight, - centroids_const, - labels, - false, // normalize_weight - inertia, - mapping_op); + predict(handle, + params, + X, + batch_size, + sample_weight, + centroids_const, + labels, + false, // normalize_weight + inertia); } } // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/cluster/kmeans_fit_double.cu index 6734d32dd2..4edd430b3b 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/cluster/kmeans_fit_double.cu @@ -49,8 +49,8 @@ void fit_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::detail::fit( - handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, raft::identity_op{}); + cuvs::cluster::kmeans::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); } void fit_batched(raft::resources const& handle, @@ -62,8 +62,8 @@ void fit_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::detail::fit( - handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, raft::identity_op{}); + cuvs::cluster::kmeans::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); } void fit(raft::resources const& handle, @@ -100,16 +100,8 @@ void predict_batched(raft::resources const& handle, bool normalize_weight, raft::host_scalar_view inertia) { - cuvs::cluster::kmeans::detail::predict(handle, - params, - X, - batch_size, - sample_weight, - centroids, - labels, - normalize_weight, - inertia, - raft::identity_op{}); + cuvs::cluster::kmeans::detail::predict( + handle, params, X, batch_size, sample_weight, centroids, labels, normalize_weight, inertia); } void fit_predict_batched(raft::resources const& handle, @@ -122,16 +114,8 @@ void fit_predict_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::detail::fit_predict(handle, - params, - X, - batch_size, - sample_weight, - centroids, - labels, - inertia, - n_iter, - raft::identity_op{}); + cuvs::cluster::kmeans::detail::fit_predict( + handle, params, X, batch_size, sample_weight, centroids, labels, inertia, n_iter); } } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/cluster/kmeans_fit_float.cu index fc7aeccb0d..e55b543f77 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/cluster/kmeans_fit_float.cu @@ -49,8 +49,8 @@ void fit_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::detail::fit( - handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, raft::identity_op{}); + cuvs::cluster::kmeans::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); } void fit_batched(raft::resources const& handle, @@ -62,8 +62,8 @@ void fit_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::detail::fit( - handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter, raft::identity_op{}); + cuvs::cluster::kmeans::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); } void fit(raft::resources const& handle, @@ -100,16 +100,8 @@ void predict_batched(raft::resources const& handle, bool normalize_weight, raft::host_scalar_view inertia) { - cuvs::cluster::kmeans::detail::predict(handle, - params, - X, - batch_size, - sample_weight, - centroids, - labels, - normalize_weight, - inertia, - raft::identity_op{}); + cuvs::cluster::kmeans::detail::predict( + handle, params, X, batch_size, sample_weight, centroids, labels, normalize_weight, inertia); } void fit_predict_batched(raft::resources const& handle, @@ -122,16 +114,8 @@ void fit_predict_batched(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::detail::fit_predict(handle, - params, - X, - batch_size, - sample_weight, - centroids, - labels, - inertia, - n_iter, - raft::identity_op{}); + cuvs::cluster::kmeans::detail::fit_predict( + handle, params, X, batch_size, sample_weight, centroids, labels, inertia, n_iter); } } // namespace cuvs::cluster::kmeans From 3b86325c75316d7fc7ddf0872c4a585ef65b98f2 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Sun, 15 Feb 2026 21:32:06 -0800 Subject: [PATCH 15/81] rm unnecessary unary-ops --- cpp/src/cluster/detail/kmeans.cuh | 30 ++---------- cpp/src/cluster/detail/kmeans_batched.cuh | 57 ----------------------- cpp/src/cluster/detail/kmeans_common.cuh | 51 ++++++++++++++++++++ 3 files changed, 55 insertions(+), 83 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index b44d44c570..4f3a33770b 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -291,35 +291,13 @@ void update_centroids(raft::resources const& handle, weight_per_cluster, workspace); - // Computes new_centroids[i] = new_centroids[i]/weight_per_cluster[i] where - // new_centroids[n_clusters x n_features] - 2D array, new_centroids[i] has sum of all the - // samples assigned to cluster-i - // weight_per_cluster[n_clusters] - 1D array, weight_per_cluster[i] contains sum of weights in - // cluster-i. - // Note - when weight_per_cluster[i] is 0, new_centroids[i] is reset to 0 - raft::linalg::matrix_vector_op( + // Divide sums by counts to get new centroids; preserve old centroids for empty clusters + cuvs::cluster::kmeans::detail::finalize_centroids( handle, raft::make_const_mdspan(new_centroids), raft::make_const_mdspan(weight_per_cluster), - new_centroids, - raft::div_checkzero_op{}); - - // copy centroids[i] to new_centroids[i] when weight_per_cluster[i] is 0 - cub::ArgIndexInputIterator itr_wt(weight_per_cluster.data_handle()); - raft::matrix::gather_if( - const_cast(centroids.data_handle()), - static_cast(centroids.extent(1)), - static_cast(centroids.extent(0)), - itr_wt, - itr_wt, - static_cast(weight_per_cluster.size()), - new_centroids.data_handle(), - [=] __device__(raft::KeyValuePair map) { // predicate - // copy when the sum of weights in the cluster is 0 - return map.value == 0; - }, - raft::key_op{}, - raft::resource::get_cuda_stream(handle)); + centroids, + new_centroids); } // TODO: Resizing is needed to use mdarray instead of rmm::device_uvector diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 2495dab001..53fbdf502b 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -16,28 +16,20 @@ #include #include #include -#include #include #include #include #include #include -#include -#include #include -#include -#include #include #include #include -#include #include -#include #include -#include #include #include #include @@ -244,49 +236,6 @@ void minibatch_update_centroids(raft::resources const& handle, centroids); } -/** - * @brief Finalize centroids by dividing accumulated sums by counts - */ -template -void finalize_centroids(raft::resources const& handle, - raft::device_matrix_view centroid_sums, - raft::device_vector_view cluster_counts, - raft::device_matrix_view old_centroids, - raft::device_matrix_view new_centroids) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_clusters = new_centroids.extent(0); - auto n_features = new_centroids.extent(1); - - // Copy sums to new_centroids first - raft::copy( - new_centroids.data_handle(), centroid_sums.data_handle(), centroid_sums.size(), stream); - - // Divide by counts: new_centroids[i] = centroid_sums[i] / cluster_counts[i] - // When count is 0, set to 0 (will be fixed below) - raft::linalg::matrix_vector_op(handle, - raft::make_const_mdspan(new_centroids), - cluster_counts, - new_centroids, - raft::div_checkzero_op{}); - - // Copy old centroids to new centroids where cluster_counts[i] == 0 - cub::ArgIndexInputIterator itr_wt(cluster_counts.data_handle()); - raft::matrix::gather_if( - old_centroids.data_handle(), - static_cast(old_centroids.extent(1)), - static_cast(old_centroids.extent(0)), - itr_wt, - itr_wt, - static_cast(cluster_counts.size()), - new_centroids.data_handle(), - [=] __device__(raft::KeyValuePair map) { - return map.value == MathT{0}; // predicate: copy when count is 0 - }, - raft::key_op{}, - stream); -} - /** * @brief Main fit function for batched k-means with host data * @@ -362,7 +311,6 @@ void fit(raft::resources const& handle, rmm::device_scalar clusterCostD(stream); T priorClusteringCost = 0; - // Check update mode bool use_minibatch = (params.update_mode == cuvs::cluster::kmeans::params::CentroidUpdateMode::MiniBatch); @@ -389,8 +337,6 @@ void fit(raft::resources const& handle, if (use_minibatch) { raft::matrix::fill(handle, total_counts.view(), T{0}); } // Main iteration loop - // For mini-batch: 1 iteration = 1 batch (process max_iter batches total) - // For full-batch: 1 iteration = 1 full pass over data for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { RAFT_LOG_DEBUG("KMeans batched: Iteration %d", n_iter[0]); @@ -426,7 +372,6 @@ void fit(raft::resources const& handle, raft::make_device_vector_view(batch_weights.data_handle(), current_batch_size); raft::matrix::fill(handle, batch_weights_fill_view, T{1}); - // Create views auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); auto batch_weights_view = raft::make_device_vector_view( @@ -445,7 +390,6 @@ void fit(raft::resources const& handle, stream); } - // Find nearest centroid auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); auto L2NormBatch_const = @@ -488,7 +432,6 @@ void fit(raft::resources const& handle, handle, centroids, centroid_sums_const, cluster_counts_const, total_counts.view()); } else { - // ========== FULL-BATCH MODE: Process ALL data per iteration ========== raft::matrix::fill(handle, centroid_sums.view(), T{0}); raft::matrix::fill(handle, cluster_counts.view(), T{0}); diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 0db56a2f1d..2fedc17f27 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -526,4 +527,54 @@ void compute_centroid_adjustments( n_clusters, stream); } +/** + * @brief Finalize centroids by dividing accumulated sums by counts. + * + * For clusters with zero count, the old centroid is preserved. + * + * @tparam DataT Data type + * @tparam IndexT Index type + * + * @param[in] handle RAFT resources handle + * @param[in] centroid_sums Accumulated weighted sums per cluster [n_clusters x n_features] + * @param[in] cluster_counts Sum of weights per cluster [n_clusters] + * @param[in] old_centroids Previous centroids (used for empty clusters) [n_clusters x n_features] + * @param[out] new_centroids Output centroids [n_clusters x n_features] + */ +template +void finalize_centroids(raft::resources const& handle, + raft::device_matrix_view centroid_sums, + raft::device_vector_view cluster_counts, + raft::device_matrix_view old_centroids, + raft::device_matrix_view new_centroids) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + // new_centroids = centroid_sums / cluster_counts (0 when count is 0) + raft::copy( + new_centroids.data_handle(), centroid_sums.data_handle(), centroid_sums.size(), stream); + + raft::linalg::matrix_vector_op(handle, + raft::make_const_mdspan(new_centroids), + cluster_counts, + new_centroids, + raft::div_checkzero_op{}); + + // For empty clusters (count == 0), copy old centroid back + cub::ArgIndexInputIterator itr_wt(cluster_counts.data_handle()); + raft::matrix::gather_if( + old_centroids.data_handle(), + static_cast(old_centroids.extent(1)), + static_cast(old_centroids.extent(0)), + itr_wt, + itr_wt, + static_cast(cluster_counts.size()), + new_centroids.data_handle(), + [=] __device__(raft::KeyValuePair map) { + return map.value == DataT{0}; + }, + raft::key_op{}, + stream); +} + } // namespace cuvs::cluster::kmeans::detail From f1b48357abb0601df50c252ca259f96a902ff6e5 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Sun, 15 Feb 2026 21:46:26 -0800 Subject: [PATCH 16/81] minibatch allocations are conditional --- cpp/src/cluster/detail/kmeans.cuh | 11 ++-- cpp/src/cluster/detail/kmeans_batched.cuh | 61 ++++++----------------- cpp/src/cluster/detail/kmeans_common.cuh | 7 ++- 3 files changed, 23 insertions(+), 56 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 4f3a33770b..2dc638e747 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -292,12 +292,11 @@ void update_centroids(raft::resources const& handle, workspace); // Divide sums by counts to get new centroids; preserve old centroids for empty clusters - cuvs::cluster::kmeans::detail::finalize_centroids( - handle, - raft::make_const_mdspan(new_centroids), - raft::make_const_mdspan(weight_per_cluster), - centroids, - new_centroids); + cuvs::cluster::kmeans::detail::finalize_centroids(handle, + raft::make_const_mdspan(new_centroids), + raft::make_const_mdspan(weight_per_cluster), + centroids, + new_centroids); } // TODO: Resizing is needed to use mdarray instead of rmm::device_uvector diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 53fbdf502b..13d960b44a 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -283,12 +283,14 @@ void fit(raft::resources const& handle, rmm::device_uvector workspace(0, stream); - // Initialize centroids from a sample of host data if (params.init != cuvs::cluster::kmeans::params::InitMethod::Array) { init_centroids_from_host_sample(handle, params, X, centroids, workspace); } - // Allocate device buffers + bool use_minibatch = + (params.update_mode == cuvs::cluster::kmeans::params::CentroidUpdateMode::MiniBatch); + RAFT_LOG_DEBUG("KMeans batched: update_mode=%s", use_minibatch ? "MiniBatch" : "FullBatch"); + auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); auto batch_weights = raft::make_device_vector(handle, batch_size); auto minClusterAndDistance = @@ -296,33 +298,20 @@ void fit(raft::resources const& handle, auto L2NormBatch = raft::make_device_vector(handle, batch_size); rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - // Accumulators for centroid computation auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); auto cluster_counts = raft::make_device_vector(handle, n_clusters); auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); - // For mini-batch mode: track total counts for learning rate calculation - auto total_counts = raft::make_device_vector(handle, n_clusters); - - // Host buffer for batch data - auto host_batch_buffer = raft::make_host_matrix(batch_size, n_features); - - // Cluster cost for convergence check rmm::device_scalar clusterCostD(stream); T priorClusteringCost = 0; - bool use_minibatch = - (params.update_mode == cuvs::cluster::kmeans::params::CentroidUpdateMode::MiniBatch); - - RAFT_LOG_DEBUG("KMeans batched: update_mode=%s", use_minibatch ? "MiniBatch" : "FullBatch"); - - // Random number generator for mini-batch sampling + // Mini-batch only state + auto total_counts = raft::make_device_vector(handle, use_minibatch ? n_clusters : 0); + auto host_batch_buffer = use_minibatch ? raft::make_host_matrix(batch_size, n_features) + : raft::make_host_matrix(0, n_features); + auto batch_indices = use_minibatch ? raft::make_host_vector(batch_size) + : raft::make_host_vector(0); std::mt19937 rng(params.rng_state.seed); - - // Buffer for sampled indices (only used in mini-batch mode) - std::vector batch_indices(batch_size); - - // For mini-batch: set up sampling distribution (weighted if weights provided) std::uniform_int_distribution uniform_dist(0, n_samples - 1); std::discrete_distribution weighted_dist; bool use_weighted_sampling = false; @@ -332,31 +321,26 @@ void fit(raft::resources const& handle, weighted_dist = std::discrete_distribution(weights.begin(), weights.end()); use_weighted_sampling = true; } - - // Initialize total_counts once for mini-batch mode (cumulative across all iterations) if (use_minibatch) { raft::matrix::fill(handle, total_counts.view(), T{0}); } - // Main iteration loop for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { RAFT_LOG_DEBUG("KMeans batched: Iteration %d", n_iter[0]); - // Save old centroids for convergence check raft::copy(new_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); T total_cost = 0; if (use_minibatch) { - // ========== MINI-BATCH MODE: Process ONE batch per iteration ========== IdxT current_batch_size = batch_size; - // Sample indices with replacement (weighted if weights provided) for (IdxT i = 0; i < current_batch_size; ++i) { - batch_indices[i] = use_weighted_sampling ? weighted_dist(rng) : uniform_dist(rng); + batch_indices.data_handle()[i] = + use_weighted_sampling ? weighted_dist(rng) : uniform_dist(rng); } #pragma omp parallel for for (IdxT i = 0; i < current_batch_size; ++i) { - IdxT sample_idx = batch_indices[i]; + IdxT sample_idx = batch_indices.data_handle()[i]; std::memcpy(host_batch_buffer.data_handle() + i * n_features, X.data_handle() + sample_idx * n_features, n_features * sizeof(T)); @@ -367,7 +351,6 @@ void fit(raft::resources const& handle, current_batch_size * n_features, stream); - // Mini-batch uses uniform weights (sampling already accounts for importance) auto batch_weights_fill_view = raft::make_device_vector_view(batch_weights.data_handle(), current_batch_size); raft::matrix::fill(handle, batch_weights_fill_view, T{1}); @@ -380,7 +363,6 @@ void fit(raft::resources const& handle, raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); - // Compute L2 norms if needed if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded) { raft::linalg::rowNorm(L2NormBatch.data_handle(), @@ -407,7 +389,6 @@ void fit(raft::resources const& handle, params.batch_centroids, workspace); - // Zero and accumulate for this batch raft::matrix::fill(handle, centroid_sums.view(), T{0}); raft::matrix::fill(handle, cluster_counts.view(), T{0}); @@ -422,7 +403,6 @@ void fit(raft::resources const& handle, centroid_sums.view(), cluster_counts.view()); - // Update centroids with online learning auto centroid_sums_const = raft::make_device_matrix_view( centroid_sums.data_handle(), n_clusters, n_features); auto cluster_counts_const = @@ -438,13 +418,11 @@ void fit(raft::resources const& handle, for (IdxT batch_idx = 0; batch_idx < n_samples; batch_idx += batch_size) { IdxT current_batch_size = std::min(batch_size, n_samples - batch_idx); - // Copy from host to device raft::copy(batch_data.data_handle(), X.data_handle() + batch_idx * n_features, current_batch_size * n_features, stream); - // Set weights auto batch_weights_fill_view = raft::make_device_vector_view(batch_weights.data_handle(), current_batch_size); if (sample_weight) { @@ -456,7 +434,6 @@ void fit(raft::resources const& handle, raft::matrix::fill(handle, batch_weights_fill_view, T{1}); } - // Create views auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); auto batch_weights_view = raft::make_device_vector_view( @@ -465,7 +442,6 @@ void fit(raft::resources const& handle, raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); - // Compute L2 norms if needed if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded) { raft::linalg::rowNorm(L2NormBatch.data_handle(), @@ -475,7 +451,6 @@ void fit(raft::resources const& handle, stream); } - // Find nearest centroid auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); auto L2NormBatch_const = raft::make_device_vector_view( @@ -493,7 +468,6 @@ void fit(raft::resources const& handle, params.batch_centroids, workspace); - // Accumulate partial sums auto minClusterAndDistance_const = raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); @@ -505,7 +479,6 @@ void fit(raft::resources const& handle, centroid_sums.view(), cluster_counts.view()); - // Accumulate cost if checking convergence if (params.inertia_check) { cuvs::cluster::kmeans::detail::computeClusterCost( handle, @@ -516,9 +489,8 @@ void fit(raft::resources const& handle, raft::add_op{}); total_cost += clusterCostD.value(stream); } - } // end batch loop + } - // Finalize centroids after processing all data auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); auto centroid_sums_const = raft::make_device_matrix_view( @@ -530,7 +502,6 @@ void fit(raft::resources const& handle, handle, centroid_sums_const, cluster_counts_const, centroids_const, centroids); } - // Compute squared norm of change in centroids auto sqrdNorm = raft::make_device_scalar(handle, T{0}); raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), centroids.size(), @@ -542,7 +513,6 @@ void fit(raft::resources const& handle, T sqrdNormError = 0; raft::copy(&sqrdNormError, sqrdNorm.data_handle(), 1, stream); - // Check convergence bool done = false; if (!use_minibatch && params.inertia_check && n_iter[0] > 1) { T delta = total_cost / priorClusteringCost; @@ -557,9 +527,8 @@ void fit(raft::resources const& handle, RAFT_LOG_DEBUG("KMeans batched: Converged after %d iterations", n_iter[0]); break; } - } // end iteration loop + } - // Compute final inertia only if requested (requires another full pass over the data) if (params.final_inertia_check) { inertia[0] = 0; for (IdxT offset = 0; offset < n_samples; offset += batch_size) { diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 2fedc17f27..4a4f8b72b5 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -538,7 +538,8 @@ void compute_centroid_adjustments( * @param[in] handle RAFT resources handle * @param[in] centroid_sums Accumulated weighted sums per cluster [n_clusters x n_features] * @param[in] cluster_counts Sum of weights per cluster [n_clusters] - * @param[in] old_centroids Previous centroids (used for empty clusters) [n_clusters x n_features] + * @param[in] old_centroids Previous centroids (used for empty clusters) [n_clusters x + * n_features] * @param[out] new_centroids Output centroids [n_clusters x n_features] */ template @@ -570,9 +571,7 @@ void finalize_centroids(raft::resources const& handle, itr_wt, static_cast(cluster_counts.size()), new_centroids.data_handle(), - [=] __device__(raft::KeyValuePair map) { - return map.value == DataT{0}; - }, + [=] __device__(raft::KeyValuePair map) { return map.value == DataT{0}; }, raft::key_op{}, stream); } From d2d3f4b4ea152550d4426b38d082b5b20cf228e8 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Sun, 15 Feb 2026 21:56:58 -0800 Subject: [PATCH 17/81] cleanup extraneous docs --- cpp/src/cluster/detail/kmeans_batched.cuh | 30 +++-------------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 13d960b44a..a0d1295454 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -59,9 +59,7 @@ void prepare_init_sample(raft::resources const& handle, std::iota(indices.begin(), indices.end(), 0); std::shuffle(indices.begin(), indices.end(), gen); - // Sample raw T data to host buffer std::vector host_sample(n_samples_out * n_features); - #pragma omp parallel for for (IdxT i = 0; i < static_cast(n_samples_out); i++) { IdxT src_idx = indices[i]; @@ -91,14 +89,11 @@ void init_centroids_from_host_sample(raft::resources const& handle, auto n_features = X.extent(1); auto n_clusters = params.n_clusters; - // Sample size for initialization: at least 3 * n_clusters, but not more than n_samples size_t init_sample_size = std::min(static_cast(n_samples), std::max(static_cast(3 * n_clusters), static_cast(10000))); - RAFT_LOG_DEBUG("KMeans batched: sampling %zu points for initialization", init_sample_size); - // Sample data from host to device auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); prepare_init_sample(handle, X, init_sample.view(), params.rng_state.seed); @@ -114,10 +109,9 @@ void init_centroids_from_host_sample(raft::resources const& handle, handle, params, init_sample_view, centroids, workspace); } } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { - // Just use the first n_clusters samples raft::copy(centroids.data_handle(), init_sample.data_handle(), n_clusters * n_features, stream); } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { - // Centroids already provided, nothing to do + // already provided } else { RAFT_FAIL("Unknown initialization method"); } @@ -162,7 +156,6 @@ void accumulate_batch_centroids( batch_counts.view(), workspace); - // Add batch results to running accumulators raft::linalg::add(centroid_sums.data_handle(), centroid_sums.data_handle(), batch_sums.data_handle(), @@ -195,15 +188,13 @@ void minibatch_update_centroids(raft::resources const& handle, auto n_clusters = centroids.extent(0); auto n_features = centroids.extent(1); - // Compute batch means: batch_mean = batch_sums / batch_counts auto batch_means = raft::make_device_matrix(handle, n_clusters, n_features); raft::linalg::matrix_vector_op( handle, batch_sums, batch_counts, batch_means.view(), raft::div_checkzero_op{}); - // Step 1: Update total_counts = total_counts + batch_counts raft::linalg::add(handle, raft::make_const_mdspan(total_counts), batch_counts, total_counts); - // Step 2: Compute learning rates: lr = batch_count / total_count (after update) + // lr[k] = batch_count[k] / total_count[k] (after update) auto learning_rates = raft::make_device_vector(handle, n_clusters); raft::linalg::map(handle, learning_rates.view(), @@ -211,9 +202,7 @@ void minibatch_update_centroids(raft::resources const& handle, batch_counts, raft::make_const_mdspan(total_counts)); - // Update centroids: centroid = centroid + lr * (batch_mean - centroid) - // = (1 - lr) * centroid + lr * batch_mean - // Using matrix_vector_op to scale each row by (1 - lr), then add lr * batch_mean + // centroid = (1 - lr) * centroid + lr * batch_mean raft::linalg::matrix_vector_op( handle, raft::make_const_mdspan(centroids), @@ -221,7 +210,6 @@ void minibatch_update_centroids(raft::resources const& handle, centroids, [] __device__(MathT centroid_val, MathT lr) { return (MathT{1} - lr) * centroid_val; }); - // Add lr * batch_mean to centroids raft::linalg::matrix_vector_op( handle, raft::make_const_mdspan(batch_means.view()), @@ -229,7 +217,6 @@ void minibatch_update_centroids(raft::resources const& handle, batch_means.view(), [] __device__(MathT mean_val, MathT lr) { return lr * mean_val; }); - // centroids += lr * batch_means raft::linalg::add(handle, raft::make_const_mdspan(centroids), raft::make_const_mdspan(batch_means.view()), @@ -617,24 +604,20 @@ void predict(raft::resources const& handle, RAFT_EXPECTS(centroids.extent(1) == n_features, "centroids.extent(1) must equal n_features"); RAFT_EXPECTS(labels.extent(0) == n_samples, "labels.extent(0) must equal n_samples"); - // Allocate device buffers auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); auto batch_weights = raft::make_device_vector(handle, batch_size); auto batch_labels = raft::make_device_vector(handle, batch_size); inertia[0] = 0; - // Process all data in batches for (IdxT batch_idx = 0; batch_idx < n_samples; batch_idx += batch_size) { IdxT current_batch_size = std::min(batch_size, n_samples - batch_idx); - // Copy batch data from host to device raft::copy(batch_data.data_handle(), X.data_handle() + batch_idx * n_features, current_batch_size * n_features, stream); - // Handle weights std::optional> batch_weight_view; if (sample_weight) { raft::copy(batch_weights.data_handle(), @@ -645,13 +628,11 @@ void predict(raft::resources const& handle, current_batch_size); } - // Create views for current batch auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); auto batch_labels_view = raft::make_device_vector_view(batch_labels.data_handle(), current_batch_size); - // Call regular predict on this batch T batch_inertia = 0; cuvs::cluster::kmeans::detail::kmeans_predict( handle, @@ -663,7 +644,6 @@ void predict(raft::resources const& handle, normalize_weight, raft::make_host_scalar_view(&batch_inertia)); - // Copy labels back to host raft::copy( labels.data_handle() + batch_idx, batch_labels.data_handle(), current_batch_size, stream); @@ -687,7 +667,6 @@ void fit_predict(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - // First fit the model T fit_inertia = 0; fit(handle, params, @@ -698,7 +677,6 @@ void fit_predict(raft::resources const& handle, raft::make_host_scalar_view(&fit_inertia), n_iter); - // Then predict labels auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), centroids.extent(0), centroids.extent(1)); @@ -709,7 +687,7 @@ void fit_predict(raft::resources const& handle, sample_weight, centroids_const, labels, - false, // normalize_weight + false, inertia); } From 639147a9cc74d9ce2376590f6288669852316e3b Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Sun, 15 Feb 2026 22:38:30 -0800 Subject: [PATCH 18/81] revert changes to get_dataset --- cpp/src/cluster/detail/kmeans_batched.cuh | 11 +- .../cuvs_bench/get_dataset/__main__.py | 295 +----------------- 2 files changed, 4 insertions(+), 302 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index a0d1295454..24d1737dae 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -680,15 +680,8 @@ void fit_predict(raft::resources const& handle, auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), centroids.extent(0), centroids.extent(1)); - predict(handle, - params, - X, - batch_size, - sample_weight, - centroids_const, - labels, - false, - inertia); + predict( + handle, params, X, batch_size, sample_weight, centroids_const, labels, false, inertia); } } // namespace cuvs::cluster::kmeans::detail diff --git a/python/cuvs_bench/cuvs_bench/get_dataset/__main__.py b/python/cuvs_bench/cuvs_bench/get_dataset/__main__.py index af35ac453c..3c52be00d8 100644 --- a/python/cuvs_bench/cuvs_bench/get_dataset/__main__.py +++ b/python/cuvs_bench/cuvs_bench/get_dataset/__main__.py @@ -1,13 +1,9 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 -import gzip import os -import shutil -import struct import subprocess -import tarfile import click import h5py @@ -34,266 +30,6 @@ def download_dataset(url, path): f.write(chunk) -def download_with_wget(url, path): - """Download file using wget (better for large FTP files).""" - if not os.path.exists(path): - print(f"downloading {url} -> {path}...") - subprocess.run(["wget", "-O", path, url], check=True) - - -def read_bvecs(filename, n_vectors=None): - """ - Read .bvecs file format from TEXMEX. - Format: each vector is [dim (4 bytes int)] [dim bytes uint8 data] - Returns float32 array. - """ - print(f"Reading {filename}...") - with open(filename, "rb") as f: - # Read dimension from first vector - dim = struct.unpack("i", f.read(4))[0] - f.seek(0) - - # Calculate number of vectors - f.seek(0, 2) # Seek to end - file_size = f.tell() - record_size = 4 + dim - total_vectors = file_size // record_size - f.seek(0) - - if n_vectors is None: - n_vectors = total_vectors - else: - n_vectors = min(n_vectors, total_vectors) - - print(f" Reading {n_vectors:,} vectors of dimension {dim}") - - data = np.zeros((n_vectors, dim), dtype=np.float32) - for i in range(n_vectors): - f.read(4) # Skip dimension - vec = np.frombuffer(f.read(dim), dtype=np.uint8) - data[i] = vec.astype(np.float32) - - if (i + 1) % 10_000_000 == 0: - print(f" Loaded {i + 1:,} / {n_vectors:,} vectors...") - - return data - - -def read_ivecs(filename, n_vectors=None): - """ - Read .ivecs file format (groundtruth neighbors). - Format: each vector is [dim (4 bytes int)] [dim int32 values] - """ - print(f"Reading {filename}...") - with open(filename, "rb") as f: - dim = struct.unpack("i", f.read(4))[0] - f.seek(0) - - f.seek(0, 2) - file_size = f.tell() - record_size = 4 + dim * 4 - total_vectors = file_size // record_size - f.seek(0) - - if n_vectors is None: - n_vectors = total_vectors - else: - n_vectors = min(n_vectors, total_vectors) - - print(f" Reading {n_vectors:,} vectors of dimension {dim}") - - data = np.zeros((n_vectors, dim), dtype=np.int32) - for i in range(n_vectors): - d = struct.unpack("i", f.read(4))[0] - data[i] = np.frombuffer(f.read(dim * 4), dtype=np.int32) - - return data - - -def read_fvecs(filename, n_vectors=None): - """ - Read .fvecs file format. - Format: each vector is [dim (4 bytes int)] [dim float32 values] - """ - print(f"Reading {filename}...") - with open(filename, "rb") as f: - dim = struct.unpack("i", f.read(4))[0] - f.seek(0) - - f.seek(0, 2) - file_size = f.tell() - record_size = 4 + dim * 4 - total_vectors = file_size // record_size - f.seek(0) - - if n_vectors is None: - n_vectors = total_vectors - else: - n_vectors = min(n_vectors, total_vectors) - - print(f" Reading {n_vectors:,} vectors of dimension {dim}") - - data = np.zeros((n_vectors, dim), dtype=np.float32) - for i in range(n_vectors): - f.read(4) # Skip dimension - data[i] = np.frombuffer(f.read(dim * 4), dtype=np.float32) - - if (i + 1) % 10_000_000 == 0: - print(f" Loaded {i + 1:,} / {n_vectors:,} vectors...") - - return data - - -def write_fbin(filename, data): - """Write data in .fbin format (used by cuVS benchmarks).""" - print(f"Writing {filename}...") - n, d = data.shape - with open(filename, "wb") as f: - f.write(struct.pack("i", n)) - f.write(struct.pack("i", d)) - data.astype(np.float32).tofile(f) - print(f" Wrote {n:,} x {d} float32 array") - - -def write_ibin(filename, data): - """Write data in .ibin format (used by cuVS benchmarks for groundtruth).""" - print(f"Writing {filename}...") - n, d = data.shape - with open(filename, "wb") as f: - f.write(struct.pack("i", n)) - f.write(struct.pack("i", d)) - data.astype(np.int32).tofile(f) - print(f" Wrote {n:,} x {d} int32 array") - - -def download_sift1b(ann_bench_data_path, n_base_vectors=None): - """ - Download and convert SIFT1B dataset from TEXMEX corpus. - http://corpus-texmex.irisa.fr/ - - The dataset contains: - - bigann_base.bvecs: 1B base vectors (128-dim uint8) - - bigann_query.bvecs: 10K query vectors - - bigann_learn.bvecs: 100M learning vectors - - bigann_gnd.tar.gz: Groundtruth for various subset sizes - """ - base_url = "ftp://ftp.irisa.fr/local/texmex/corpus" - - # Create output directory - if n_base_vectors is None: - output_dir = os.path.join(ann_bench_data_path, "sift-1B") - else: - size_suffix = f"{n_base_vectors // 1_000_000}M" - output_dir = os.path.join(ann_bench_data_path, f"sift-{size_suffix}") - - os.makedirs(output_dir, exist_ok=True) - - # Temporary directory for downloads - tmp_dir = os.path.join(ann_bench_data_path, "tmp_sift1b") - os.makedirs(tmp_dir, exist_ok=True) - - try: - # Download and process base vectors - base_gz = os.path.join(tmp_dir, "bigann_base.bvecs.gz") - base_file = os.path.join(tmp_dir, "bigann_base.bvecs") - - if not os.path.exists(os.path.join(output_dir, "base.fbin")): - if not os.path.exists(base_file): - if not os.path.exists(base_gz): - download_with_wget( - f"{base_url}/bigann_base.bvecs.gz", base_gz - ) - print(f"Decompressing {base_gz}...") - with gzip.open(base_gz, "rb") as f_in: - with open(base_file, "wb") as f_out: - shutil.copyfileobj(f_in, f_out) - - base_data = read_bvecs(base_file, n_base_vectors) - write_fbin(os.path.join(output_dir, "base.fbin"), base_data) - del base_data - - # Download and process query vectors - query_gz = os.path.join(tmp_dir, "bigann_query.bvecs.gz") - query_file = os.path.join(tmp_dir, "bigann_query.bvecs") - - if not os.path.exists(os.path.join(output_dir, "query.fbin")): - if not os.path.exists(query_file): - if not os.path.exists(query_gz): - download_with_wget( - f"{base_url}/bigann_query.bvecs.gz", query_gz - ) - print(f"Decompressing {query_gz}...") - with gzip.open(query_gz, "rb") as f_in: - with open(query_file, "wb") as f_out: - shutil.copyfileobj(f_in, f_out) - - query_data = read_bvecs(query_file) - write_fbin(os.path.join(output_dir, "query.fbin"), query_data) - del query_data - - # Download and process groundtruth - gnd_tar = os.path.join(tmp_dir, "bigann_gnd.tar.gz") - - if not os.path.exists( - os.path.join(output_dir, "groundtruth.neighbors.ibin") - ): - if not os.path.exists(gnd_tar): - download_with_wget(f"{base_url}/bigann_gnd.tar.gz", gnd_tar) - - print(f"Extracting {gnd_tar}...") - with tarfile.open(gnd_tar, "r:gz") as tar: - tar.extractall(tmp_dir) - - # Choose appropriate groundtruth file based on n_base_vectors - if n_base_vectors is None: - gnd_file = os.path.join(tmp_dir, "gnd", "idx_1000M.ivecs") - else: - # Find closest available groundtruth - sizes = [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000] - target_m = n_base_vectors // 1_000_000 - closest = min( - sizes, - key=lambda x: abs(x - target_m) - if x >= target_m - else float("inf"), - ) - if closest > target_m: - closest = max( - [s for s in sizes if s <= target_m], default=sizes[0] - ) - gnd_file = os.path.join( - tmp_dir, "gnd", f"idx_{closest}M.ivecs" - ) - - if os.path.exists(gnd_file): - gnd_data = read_ivecs(gnd_file) - write_ibin( - os.path.join(output_dir, "groundtruth.neighbors.ibin"), - gnd_data, - ) - else: - print(f"Warning: Groundtruth file {gnd_file} not found") - # List available files - gnd_dir = os.path.join(tmp_dir, "gnd") - if os.path.exists(gnd_dir): - print( - f"Available groundtruth files: {os.listdir(gnd_dir)}" - ) - - print(f"\nSIFT1B dataset prepared in: {output_dir}") - print("Files:") - for f in os.listdir(output_dir): - fpath = os.path.join(output_dir, f) - size_mb = os.path.getsize(fpath) / 1e6 - print(f" {f}: {size_mb:.1f} MB") - - finally: - # Optionally clean up temp files - # shutil.rmtree(tmp_dir) - print(f"\nTemp files kept in: {tmp_dir}") - print("You can delete them manually after verifying the dataset.") - - def convert_hdf5_to_fbin(path, normalize): scripts_path = os.path.dirname(os.path.realpath(__file__)) ann_bench_scripts_path = os.path.join(scripts_path, "hdf5_to_fbin.py") @@ -411,8 +147,7 @@ def get_default_dataset_path(): @click.option( "--dataset", default="glove-100-angular", - help="Dataset to download. Use 'sift-1B' for TEXMEX SIFT1B dataset, " - "or 'sift-100M', 'sift-10M' etc for subsets.", + help="Dataset to download.", ) @click.option( "--test-data-n-train", @@ -450,13 +185,6 @@ def get_default_dataset_path(): is_flag=True, help="Normalize cosine distance to inner product.", ) -@click.option( - "--n-vectors", - default=None, - type=int, - help="Number of base vectors to use (for sift-1B dataset). " - "E.g., 300000000 for 300M vectors.", -) def main( dataset, test_data_n_train, @@ -466,7 +194,6 @@ def main( test_data_output_file, dataset_path, normalize, - n_vectors, ): # Compute default dataset_path if not provided. if dataset_path is None: @@ -483,24 +210,6 @@ def main( metric="euclidean", dataset_path=dataset_path, ) - elif dataset.startswith("sift-"): - # Handle SIFT1B and subsets from TEXMEX - # Parse dataset name for size hint (e.g., sift-100M, sift-1B) - size_str = dataset.split("-")[1].upper() - if n_vectors is None: - if size_str == "1B": - n_vectors = None # Use all 1B vectors - elif size_str.endswith("M"): - n_vectors = int(size_str[:-1]) * 1_000_000 - elif size_str.endswith("K"): - n_vectors = int(size_str[:-1]) * 1_000 - else: - try: - n_vectors = int(size_str) - except ValueError: - n_vectors = None - - download_sift1b(dataset_path, n_vectors) else: download(dataset, normalize, dataset_path) From 624628994b984fcc6a638fb7acb24aa0f68dbd99 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 16 Feb 2026 02:53:13 -0800 Subject: [PATCH 19/81] fix python tests --- python/cuvs/cuvs/tests/test_kmeans.py | 263 +++++++++++++++++++------- 1 file changed, 195 insertions(+), 68 deletions(-) diff --git a/python/cuvs/cuvs/tests/test_kmeans.py b/python/cuvs/cuvs/tests/test_kmeans.py index 509d8aacfb..8850fc5525 100644 --- a/python/cuvs/cuvs/tests/test_kmeans.py +++ b/python/cuvs/cuvs/tests/test_kmeans.py @@ -6,7 +6,13 @@ import pytest from pylibraft.common import device_ndarray -from cuvs.cluster.kmeans import KMeansParams, cluster_cost, fit, predict +from cuvs.cluster.kmeans import ( + KMeansParams, + cluster_cost, + fit, + fit_batched, + predict, +) from cuvs.distance import pairwise_distance @@ -71,81 +77,202 @@ def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): assert np.allclose(inertia, sum(cluster_distances), rtol=tol, atol=tol) +@pytest.mark.parametrize("n_rows", [1000]) +@pytest.mark.parametrize("n_cols", [10, 50]) +@pytest.mark.parametrize("n_clusters", [5, 20]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_fit_batched_fullbatch(n_rows, n_cols, n_clusters, dtype): + """ + Test that fit_batched in FullBatch mode produces centroids that reduce + inertia compared to the initial centroids. + """ + rng = np.random.default_rng(42) + X = rng.random((n_rows, n_cols)).astype(dtype) + + initial_centroids = device_ndarray(X[:n_clusters].copy()) + X_device = device_ndarray(X) + original_inertia = cluster_cost(X_device, initial_centroids) + + params = KMeansParams( + n_clusters=n_clusters, + init_method="Array", + max_iter=50, + ) + + centroids, inertia, n_iter = fit_batched( + params, X, batch_size=256, centroids=initial_centroids + ) + assert n_iter >= 1 + + fitted_inertia = cluster_cost(X_device, centroids) + assert fitted_inertia < original_inertia + + +@pytest.mark.parametrize("n_rows", [1000]) +@pytest.mark.parametrize("n_cols", [10]) +@pytest.mark.parametrize("n_clusters", [8]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_fit_batched_minibatch(n_rows, n_cols, n_clusters, dtype): + """ + Test that fit_batched in MiniBatch mode converges (reduces inertia). + """ + rng = np.random.default_rng(123) + X = rng.random((n_rows, n_cols)).astype(dtype) + + initial_centroids = device_ndarray(X[:n_clusters].copy()) + X_device = device_ndarray(X) + original_inertia = cluster_cost(X_device, initial_centroids) + + params = KMeansParams( + n_clusters=n_clusters, + init_method="Array", + max_iter=200, + update_mode="mini_batch", + ) + + centroids, inertia, n_iter = fit_batched( + params, X, batch_size=128, centroids=initial_centroids + ) + assert n_iter >= 1 + + fitted_inertia = cluster_cost(X_device, centroids) + assert fitted_inertia < original_inertia + + @pytest.mark.parametrize("n_rows", [1000]) @pytest.mark.parametrize("n_cols", [10]) @pytest.mark.parametrize("n_clusters", [8]) @pytest.mark.parametrize("dtype", [np.float32]) -@pytest.mark.parametrize( - "batch_samples_list", - [ - [32, 64, 128, 256, 512], # various batch sizes - ], -) -def test_kmeans_batch_size_determinism( - n_rows, n_cols, n_clusters, dtype, batch_samples_list +def test_fit_batched_matches_fit(n_rows, n_cols, n_clusters, dtype): + """ + Test that fit_batched FullBatch produces the same centroids as regular fit + when given the same initial centroids. + """ + rng = np.random.default_rng(99) + X_host = rng.random((n_rows, n_cols)).astype(dtype) + initial_centroids_host = X_host[:n_clusters].copy() + + # Regular fit (device data) + params = KMeansParams( + n_clusters=n_clusters, + init_method="Array", + max_iter=20, + tol=1e-10, + ) + centroids_regular, _, _ = fit( + params, + device_ndarray(X_host), + device_ndarray(initial_centroids_host.copy()), + ) + centroids_regular = centroids_regular.copy_to_host() + + # Batched fit (host data, full batch mode) + centroids_batched, _, _ = fit_batched( + params, + X_host, + batch_size=256, + centroids=device_ndarray(initial_centroids_host.copy()), + ) + centroids_batched = centroids_batched.copy_to_host() + + assert np.allclose( + centroids_regular, centroids_batched, rtol=1e-4, atol=1e-4 + ), f"max diff: {np.max(np.abs(centroids_regular - centroids_batched))}" + + +@pytest.mark.parametrize("n_rows", [500]) +@pytest.mark.parametrize("n_cols", [10]) +@pytest.mark.parametrize("n_clusters", [5]) +@pytest.mark.parametrize("dtype", [np.float32]) +@pytest.mark.parametrize("batch_size", [64, 128, 256, 500]) +def test_fit_batched_batch_size_determinism( + n_rows, n_cols, n_clusters, dtype, batch_size ): """ - Test that different batch sizes produce identical centroids. + Test that fit_batched FullBatch produces identical centroids regardless + of batch_size, since the full dataset is accumulated before updating. + """ + rng = np.random.default_rng(77) + X = rng.random((n_rows, n_cols)).astype(dtype) + initial_centroids_host = X[:n_clusters].copy() + + params = KMeansParams( + n_clusters=n_clusters, + init_method="Array", + max_iter=20, + tol=1e-10, + ) - When starting from the same initial centroids, the k-means algorithm - should produce identical final centroids regardless of the batch_samples - parameter. This is because the accumulated adjustments to centroids after - the entire dataset pass should be the same. + # Reference: batch_size = full dataset + centroids_ref, _, _ = fit_batched( + params, + X, + batch_size=n_rows, + centroids=device_ndarray(initial_centroids_host.copy()), + ) + centroids_ref = centroids_ref.copy_to_host() + + centroids_test, _, _ = fit_batched( + params, + X, + batch_size=batch_size, + centroids=device_ndarray(initial_centroids_host.copy()), + ) + centroids_test = centroids_test.copy_to_host() + + assert np.allclose(centroids_ref, centroids_test, rtol=1e-5, atol=1e-5), ( + f"batch_size={batch_size}: max diff=" + f"{np.max(np.abs(centroids_ref - centroids_test))}" + ) + + +@pytest.mark.parametrize("n_rows", [1000]) +@pytest.mark.parametrize("n_cols", [10]) +@pytest.mark.parametrize("n_clusters", [8]) +@pytest.mark.parametrize("dtype", [np.float32]) +def test_fit_batched_auto_init(n_rows, n_cols, n_clusters, dtype): """ - # Use fixed seed for reproducibility - rng = np.random.default_rng(42) + Test fit_batched without providing initial centroids (auto-initialization). + """ + rng = np.random.default_rng(55) + X = rng.random((n_rows, n_cols)).astype(dtype) - # Generate random data - X_host = rng.random((n_rows, n_cols)).astype(dtype) - X = device_ndarray(X_host) + params = KMeansParams(n_clusters=n_clusters, max_iter=50) - # Generate fixed initial centroids (using first n_clusters rows) - initial_centroids_host = X_host[:n_clusters].copy() + centroids, inertia, n_iter = fit_batched(params, X, batch_size=256) + assert centroids.shape == (n_clusters, n_cols) + assert n_iter >= 1 + + +@pytest.mark.parametrize("n_rows", [500]) +@pytest.mark.parametrize("n_cols", [10]) +@pytest.mark.parametrize("n_clusters", [5]) +@pytest.mark.parametrize("dtype", [np.float32]) +def test_fit_batched_with_sample_weights(n_rows, n_cols, n_clusters, dtype): + """ + Test that fit_batched accepts and runs with sample weights. + """ + rng = np.random.default_rng(66) + X = rng.random((n_rows, n_cols)).astype(dtype) + weights = np.ones(n_rows, dtype=dtype) + + initial_centroids = device_ndarray(X[:n_clusters].copy()) + X_device = device_ndarray(X) + original_inertia = cluster_cost(X_device, initial_centroids) + + params = KMeansParams( + n_clusters=n_clusters, + init_method="Array", + max_iter=50, + ) + + centroids, inertia, n_iter = fit_batched( + params, + X, + batch_size=128, + centroids=initial_centroids, + sample_weights=weights, + ) - # Store results from each batch size - results = [] - - for batch_samples in batch_samples_list: - # Create fresh copy of initial centroids for each run - centroids = device_ndarray(initial_centroids_host.copy()) - - params = KMeansParams( - n_clusters=n_clusters, - init_method="Array", # Use provided centroids - max_iter=100, - tol=1e-10, # Very small tolerance to ensure convergence - batch_samples=batch_samples, - ) - - centroids_out, inertia, n_iter = fit(params, X, centroids) - results.append( - { - "batch_samples": batch_samples, - "centroids": centroids_out.copy_to_host(), - "inertia": inertia, - "n_iter": n_iter, - } - ) - - # Compare all results against the first one - reference = results[0] - for result in results[1:]: - # Centroids should be identical (or very close due to float precision) - assert np.allclose( - reference["centroids"], - result["centroids"], - rtol=1e-5, - atol=1e-5, - ), ( - f"Centroids differ between batch_samples=" - f"{reference['batch_samples']} and {result['batch_samples']}" - ) - - # Inertia should also be identical - assert np.allclose( - reference["inertia"], result["inertia"], rtol=1e-5, atol=1e-5 - ), ( - f"Inertia differs between batch_samples=" - f"{reference['batch_samples']} and {result['batch_samples']}: " - f"{reference['inertia']} vs {result['inertia']}" - ) + fitted_inertia = cluster_cost(X_device, centroids) + assert fitted_inertia < original_inertia From b58076004ba396fbb2048aab8722d20e014ab8d8 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 17 Feb 2026 03:35:26 -0800 Subject: [PATCH 20/81] address sklearn inconsistency --- cpp/src/cluster/detail/kmeans_batched.cuh | 8 +- python/cuvs/cuvs/tests/test_kmeans.py | 247 +++++++--------------- 2 files changed, 78 insertions(+), 177 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 24d1737dae..d3a3cfde06 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -308,9 +308,13 @@ void fit(raft::resources const& handle, weighted_dist = std::discrete_distribution(weights.begin(), weights.end()); use_weighted_sampling = true; } - if (use_minibatch) { raft::matrix::fill(handle, total_counts.view(), T{0}); } + IdxT n_steps = params.max_iter; + if (use_minibatch) { + raft::matrix::fill(handle, total_counts.view(), T{0}); + n_steps = (params.max_iter * n_samples) / batch_size; + } - for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { + for (n_iter[0] = 1; n_iter[0] <= n_steps; ++n_iter[0]) { RAFT_LOG_DEBUG("KMeans batched: Iteration %d", n_iter[0]); raft::copy(new_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); diff --git a/python/cuvs/cuvs/tests/test_kmeans.py b/python/cuvs/cuvs/tests/test_kmeans.py index 8850fc5525..f354e51943 100644 --- a/python/cuvs/cuvs/tests/test_kmeans.py +++ b/python/cuvs/cuvs/tests/test_kmeans.py @@ -15,6 +15,8 @@ ) from cuvs.distance import pairwise_distance +from sklearn.cluster import MiniBatchKMeans + @pytest.mark.parametrize("n_rows", [100]) @pytest.mark.parametrize("n_cols", [5, 25]) @@ -77,202 +79,97 @@ def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): assert np.allclose(inertia, sum(cluster_distances), rtol=tol, atol=tol) -@pytest.mark.parametrize("n_rows", [1000]) -@pytest.mark.parametrize("n_cols", [10, 50]) -@pytest.mark.parametrize("n_clusters", [5, 20]) -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_fit_batched_fullbatch(n_rows, n_cols, n_clusters, dtype): - """ - Test that fit_batched in FullBatch mode produces centroids that reduce - inertia compared to the initial centroids. - """ - rng = np.random.default_rng(42) - X = rng.random((n_rows, n_cols)).astype(dtype) - - initial_centroids = device_ndarray(X[:n_clusters].copy()) - X_device = device_ndarray(X) - original_inertia = cluster_cost(X_device, initial_centroids) - - params = KMeansParams( - n_clusters=n_clusters, - init_method="Array", - max_iter=50, - ) - - centroids, inertia, n_iter = fit_batched( - params, X, batch_size=256, centroids=initial_centroids - ) - assert n_iter >= 1 - - fitted_inertia = cluster_cost(X_device, centroids) - assert fitted_inertia < original_inertia - - -@pytest.mark.parametrize("n_rows", [1000]) -@pytest.mark.parametrize("n_cols", [10]) -@pytest.mark.parametrize("n_clusters", [8]) -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_fit_batched_minibatch(n_rows, n_cols, n_clusters, dtype): - """ - Test that fit_batched in MiniBatch mode converges (reduces inertia). - """ - rng = np.random.default_rng(123) - X = rng.random((n_rows, n_cols)).astype(dtype) - - initial_centroids = device_ndarray(X[:n_clusters].copy()) - X_device = device_ndarray(X) - original_inertia = cluster_cost(X_device, initial_centroids) - - params = KMeansParams( - n_clusters=n_clusters, - init_method="Array", - max_iter=200, - update_mode="mini_batch", - ) - - centroids, inertia, n_iter = fit_batched( - params, X, batch_size=128, centroids=initial_centroids - ) - assert n_iter >= 1 - - fitted_inertia = cluster_cost(X_device, centroids) - assert fitted_inertia < original_inertia +# @pytest.mark.parametrize("n_rows", [1000]) +# @pytest.mark.parametrize("n_cols", [10]) +# @pytest.mark.parametrize("n_clusters", [8]) +# @pytest.mark.parametrize("dtype", [np.float32]) +# @pytest.mark.parametrize("update_mode", ["full_batch", "mini_batch"]) +# def test_fit_batched_matches_fit(n_rows, n_cols, n_clusters, dtype, update_mode): +# """ +# Test that fit_batched FullBatch produces the same centroids as regular fit +# when given the same initial centroids. +# """ +# rng = np.random.default_rng(99) +# X_host = rng.random((n_rows, n_cols)).astype(dtype) +# initial_centroids_host = X_host[:n_clusters].copy() + +# # Regular fit (device data) +# params = KMeansParams( +# n_clusters=n_clusters, +# init_method="Array", +# max_iter=20, +# tol=1e-10, +# update_mode=update_mode, +# ) +# centroids_regular, _, _ = fit( +# params, +# device_ndarray(X_host), +# device_ndarray(initial_centroids_host.copy()), +# ) +# centroids_regular = centroids_regular.copy_to_host() + +# # Batched fit (host data, full batch mode) +# centroids_batched, _, _ = fit_batched( +# params, +# X_host, +# batch_size=1000, +# centroids=device_ndarray(initial_centroids_host.copy()), +# ) +# centroids_batched = centroids_batched.copy_to_host() + +# assert np.allclose( +# centroids_regular, centroids_batched, rtol=0.05, atol=0.05 +# ), f"max diff: {np.max(np.abs(centroids_regular - centroids_batched))}" @pytest.mark.parametrize("n_rows", [1000]) @pytest.mark.parametrize("n_cols", [10]) @pytest.mark.parametrize("n_clusters", [8]) @pytest.mark.parametrize("dtype", [np.float32]) -def test_fit_batched_matches_fit(n_rows, n_cols, n_clusters, dtype): +def test_minibatch_sklearn(n_rows, n_cols, n_clusters, dtype): """ - Test that fit_batched FullBatch produces the same centroids as regular fit - when given the same initial centroids. + Test that fit_batched matches sklearn's KMeans implementation. """ rng = np.random.default_rng(99) X_host = rng.random((n_rows, n_cols)).astype(dtype) initial_centroids_host = X_host[:n_clusters].copy() - # Regular fit (device data) - params = KMeansParams( - n_clusters=n_clusters, - init_method="Array", - max_iter=20, - tol=1e-10, - ) - centroids_regular, _, _ = fit( - params, - device_ndarray(X_host), - device_ndarray(initial_centroids_host.copy()), - ) - centroids_regular = centroids_regular.copy_to_host() - - # Batched fit (host data, full batch mode) - centroids_batched, _, _ = fit_batched( - params, - X_host, + # Sklearn fit + kmeans = MiniBatchKMeans( + n_clusters=8, + init=initial_centroids_host, + max_iter=200, + verbose=0, + random_state=None, + tol=0.0, + max_no_improvement=10, + init_size=None, + n_init="auto", + reassignment_ratio=0.01, batch_size=256, - centroids=device_ndarray(initial_centroids_host.copy()), ) - centroids_batched = centroids_batched.copy_to_host() - - assert np.allclose( - centroids_regular, centroids_batched, rtol=1e-4, atol=1e-4 - ), f"max diff: {np.max(np.abs(centroids_regular - centroids_batched))}" + # kmeans = KMeans(n_clusters=8, init=initial_centroids_host, n_init='auto', max_iter=20, tol=0.0001, verbose=0, random_state=None, copy_x=True, algorithm='lloyd') + kmeans.fit(X_host) + centroids_sklearn = kmeans.cluster_centers_ -@pytest.mark.parametrize("n_rows", [500]) -@pytest.mark.parametrize("n_cols", [10]) -@pytest.mark.parametrize("n_clusters", [5]) -@pytest.mark.parametrize("dtype", [np.float32]) -@pytest.mark.parametrize("batch_size", [64, 128, 256, 500]) -def test_fit_batched_batch_size_determinism( - n_rows, n_cols, n_clusters, dtype, batch_size -): - """ - Test that fit_batched FullBatch produces identical centroids regardless - of batch_size, since the full dataset is accumulated before updating. - """ - rng = np.random.default_rng(77) - X = rng.random((n_rows, n_cols)).astype(dtype) - initial_centroids_host = X[:n_clusters].copy() - + # cuvs fit params = KMeansParams( n_clusters=n_clusters, init_method="Array", - max_iter=20, - tol=1e-10, - ) - - # Reference: batch_size = full dataset - centroids_ref, _, _ = fit_batched( - params, - X, - batch_size=n_rows, - centroids=device_ndarray(initial_centroids_host.copy()), + max_iter=200, + tol=1e-4, + update_mode="mini_batch", ) - centroids_ref = centroids_ref.copy_to_host() - - centroids_test, _, _ = fit_batched( + centroids_cuvs, _, _ = fit_batched( params, - X, - batch_size=batch_size, + X_host, + batch_size=256, centroids=device_ndarray(initial_centroids_host.copy()), ) - centroids_test = centroids_test.copy_to_host() - - assert np.allclose(centroids_ref, centroids_test, rtol=1e-5, atol=1e-5), ( - f"batch_size={batch_size}: max diff=" - f"{np.max(np.abs(centroids_ref - centroids_test))}" - ) - + centroids_cuvs = centroids_cuvs.copy_to_host() + print(centroids_cuvs) -@pytest.mark.parametrize("n_rows", [1000]) -@pytest.mark.parametrize("n_cols", [10]) -@pytest.mark.parametrize("n_clusters", [8]) -@pytest.mark.parametrize("dtype", [np.float32]) -def test_fit_batched_auto_init(n_rows, n_cols, n_clusters, dtype): - """ - Test fit_batched without providing initial centroids (auto-initialization). - """ - rng = np.random.default_rng(55) - X = rng.random((n_rows, n_cols)).astype(dtype) - - params = KMeansParams(n_clusters=n_clusters, max_iter=50) - - centroids, inertia, n_iter = fit_batched(params, X, batch_size=256) - assert centroids.shape == (n_clusters, n_cols) - assert n_iter >= 1 - - -@pytest.mark.parametrize("n_rows", [500]) -@pytest.mark.parametrize("n_cols", [10]) -@pytest.mark.parametrize("n_clusters", [5]) -@pytest.mark.parametrize("dtype", [np.float32]) -def test_fit_batched_with_sample_weights(n_rows, n_cols, n_clusters, dtype): - """ - Test that fit_batched accepts and runs with sample weights. - """ - rng = np.random.default_rng(66) - X = rng.random((n_rows, n_cols)).astype(dtype) - weights = np.ones(n_rows, dtype=dtype) - - initial_centroids = device_ndarray(X[:n_clusters].copy()) - X_device = device_ndarray(X) - original_inertia = cluster_cost(X_device, initial_centroids) - - params = KMeansParams( - n_clusters=n_clusters, - init_method="Array", - max_iter=50, - ) - - centroids, inertia, n_iter = fit_batched( - params, - X, - batch_size=128, - centroids=initial_centroids, - sample_weights=weights, - ) - - fitted_inertia = cluster_cost(X_device, centroids) - assert fitted_inertia < original_inertia + assert np.allclose( + centroids_sklearn, centroids_cuvs, rtol=0.3, atol=0.3 + ), f"max diff: {np.max(np.abs(centroids_sklearn - centroids_cuvs))}" From 491c900aac5f7e40f7ed83dcad0d3c56005bce63 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 17 Feb 2026 04:10:04 -0800 Subject: [PATCH 21/81] fix call to finalize_centroids --- cpp/src/cluster/detail/kmeans_batched.cuh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index d3a3cfde06..ce08fe5553 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -490,7 +490,9 @@ void fit(raft::resources const& handle, raft::make_device_vector_view(cluster_counts.data_handle(), n_clusters); finalize_centroids( - handle, centroid_sums_const, cluster_counts_const, centroids_const, centroids); + handle, centroid_sums_const, cluster_counts_const, centroids_const, new_centroids.view()); + + raft::copy(centroids.data_handle(), new_centroids.data_handle(), centroids.size(), stream); } auto sqrdNorm = raft::make_device_scalar(handle, T{0}); From 6886fb7f880ff4d208bf96b9d5270a2ce0e42109 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 17 Feb 2026 04:36:32 -0800 Subject: [PATCH 22/81] bug fixes and python tests --- cpp/src/cluster/detail/kmeans_batched.cuh | 6 +- python/cuvs/cuvs/tests/test_kmeans.py | 81 +++++++++++------------ 2 files changed, 44 insertions(+), 43 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index ce08fe5553..7e67e261a4 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -491,8 +491,6 @@ void fit(raft::resources const& handle, finalize_centroids( handle, centroid_sums_const, cluster_counts_const, centroids_const, new_centroids.view()); - - raft::copy(centroids.data_handle(), new_centroids.data_handle(), centroids.size(), stream); } auto sqrdNorm = raft::make_device_scalar(handle, T{0}); @@ -503,6 +501,10 @@ void fit(raft::resources const& handle, new_centroids.data_handle(), centroids.data_handle()); + if (!use_minibatch) { + raft::copy(centroids.data_handle(), new_centroids.data_handle(), centroids.size(), stream); + } + T sqrdNormError = 0; raft::copy(&sqrdNormError, sqrdNorm.data_handle(), 1, stream); diff --git a/python/cuvs/cuvs/tests/test_kmeans.py b/python/cuvs/cuvs/tests/test_kmeans.py index f354e51943..1930c1a964 100644 --- a/python/cuvs/cuvs/tests/test_kmeans.py +++ b/python/cuvs/cuvs/tests/test_kmeans.py @@ -79,47 +79,46 @@ def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): assert np.allclose(inertia, sum(cluster_distances), rtol=tol, atol=tol) -# @pytest.mark.parametrize("n_rows", [1000]) -# @pytest.mark.parametrize("n_cols", [10]) -# @pytest.mark.parametrize("n_clusters", [8]) -# @pytest.mark.parametrize("dtype", [np.float32]) -# @pytest.mark.parametrize("update_mode", ["full_batch", "mini_batch"]) -# def test_fit_batched_matches_fit(n_rows, n_cols, n_clusters, dtype, update_mode): -# """ -# Test that fit_batched FullBatch produces the same centroids as regular fit -# when given the same initial centroids. -# """ -# rng = np.random.default_rng(99) -# X_host = rng.random((n_rows, n_cols)).astype(dtype) -# initial_centroids_host = X_host[:n_clusters].copy() - -# # Regular fit (device data) -# params = KMeansParams( -# n_clusters=n_clusters, -# init_method="Array", -# max_iter=20, -# tol=1e-10, -# update_mode=update_mode, -# ) -# centroids_regular, _, _ = fit( -# params, -# device_ndarray(X_host), -# device_ndarray(initial_centroids_host.copy()), -# ) -# centroids_regular = centroids_regular.copy_to_host() - -# # Batched fit (host data, full batch mode) -# centroids_batched, _, _ = fit_batched( -# params, -# X_host, -# batch_size=1000, -# centroids=device_ndarray(initial_centroids_host.copy()), -# ) -# centroids_batched = centroids_batched.copy_to_host() - -# assert np.allclose( -# centroids_regular, centroids_batched, rtol=0.05, atol=0.05 -# ), f"max diff: {np.max(np.abs(centroids_regular - centroids_batched))}" +@pytest.mark.parametrize("n_rows", [1000, 5000]) +@pytest.mark.parametrize("n_cols", [10, 100]) +@pytest.mark.parametrize("n_clusters", [8, 16]) +@pytest.mark.parametrize("batch_size", [100, 500]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_fit_batched_matches_fit( + n_rows, n_cols, n_clusters, batch_size, dtype +): + """ + Test that fit_batched FullBatch produces the same centroids as regular fit + when given the same initial centroids. + """ + rng = np.random.default_rng(99) + X_host = rng.random((n_rows, n_cols)).astype(dtype) + initial_centroids_host = X_host[:n_clusters].copy() + + params = KMeansParams( + n_clusters=n_clusters, + init_method="Array", + max_iter=100, + tol=1e-10, + ) + centroids_regular, _, _ = fit( + params, + device_ndarray(X_host), + device_ndarray(initial_centroids_host.copy()), + ) + centroids_regular = centroids_regular.copy_to_host() + + centroids_batched, _, _ = fit_batched( + params, + X_host, + batch_size=batch_size, + centroids=device_ndarray(initial_centroids_host.copy()), + ) + centroids_batched = centroids_batched.copy_to_host() + + assert np.allclose( + centroids_regular, centroids_batched, rtol=0.05, atol=0.05 + ), f"max diff: {np.max(np.abs(centroids_regular - centroids_batched))}" @pytest.mark.parametrize("n_rows", [1000]) From ab366f5e8b9dee137b966d429ad8ec773e794f12 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 18 Feb 2026 22:53:12 -0800 Subject: [PATCH 23/81] add early stopping criteria from sklearn --- c/include/cuvs/cluster/kmeans.h | 7 ++ c/src/cluster/kmeans.cpp | 2 + cpp/include/cuvs/cluster/kmeans.hpp | 8 ++ cpp/src/cluster/detail/kmeans_batched.cuh | 132 +++++++++++++++++---- python/cuvs/cuvs/cluster/kmeans/kmeans.pxd | 2 + 5 files changed, 129 insertions(+), 22 deletions(-) diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index 5b157abd6d..b1c25b6dec 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -122,6 +122,13 @@ struct cuvsKMeansParams { */ bool final_inertia_check; + /** + * Maximum number of consecutive mini-batch steps without improvement in smoothed inertia + * before early stopping. Only used when update_mode is CUVS_KMEANS_UPDATE_MINI_BATCH. + * If 0, this convergence criterion is disabled. + */ + int max_no_improvement; + /** * Whether to use hierarchical (balanced) kmeans or not */ diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index fb62daf966..77e9b5126e 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -29,6 +29,7 @@ cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params) kmeans_params.batch_centroids = params.batch_centroids; kmeans_params.inertia_check = params.inertia_check; kmeans_params.final_inertia_check = params.final_inertia_check; + kmeans_params.max_no_improvement = params.max_no_improvement; kmeans_params.update_mode = static_cast(params.update_mode); return kmeans_params; @@ -260,6 +261,7 @@ extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params) .update_mode = static_cast(cpp_params.update_mode), .inertia_check = cpp_params.inertia_check, .final_inertia_check = cpp_params.final_inertia_check, + .max_no_improvement = cpp_params.max_no_improvement, .hierarchical = false, .hierarchical_n_iters = static_cast(cpp_balanced_params.n_iters)}; }); diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index cc4333c16d..b991ed2621 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -148,6 +148,14 @@ struct params : base_params { * Default: false (skip final inertia computation for performance). */ bool final_inertia_check = false; + + /** + * Maximum number of consecutive mini-batch steps without improvement in smoothed inertia + * before early stopping. Only used when update_mode is MiniBatch. + * If None/0, this convergence criterion is disabled. + * Default: 10 + */ + int max_no_improvement = 10; }; /** diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 7e67e261a4..16c6753f4c 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -309,6 +309,14 @@ void fit(raft::resources const& handle, use_weighted_sampling = true; } IdxT n_steps = params.max_iter; + + // Mini-batch convergence tracking + T ewa_inertia = T{0}; + T ewa_inertia_min = T{0}; + int no_improvement = 0; + bool ewa_initialized = false; + auto prev_centroids = raft::make_device_matrix(handle, n_clusters, n_features); + if (use_minibatch) { raft::matrix::fill(handle, total_counts.view(), T{0}); n_steps = (params.max_iter * n_samples) / batch_size; @@ -363,6 +371,9 @@ void fit(raft::resources const& handle, stream); } + // Save centroids before update for convergence check + raft::copy(prev_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); + auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); auto L2NormBatch_const = @@ -380,6 +391,17 @@ void fit(raft::resources const& handle, params.batch_centroids, workspace); + // Compute batch inertia (normalized by batch_size for comparison) + T batch_inertia = 0; + cuvs::cluster::kmeans::detail::computeClusterCost( + handle, + minClusterAndDistance_view, + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + raft::value_op{}, + raft::add_op{}); + batch_inertia = clusterCostD.value(stream) / static_cast(current_batch_size); + raft::matrix::fill(handle, centroid_sums.view(), T{0}); raft::matrix::fill(handle, cluster_counts.view(), T{0}); @@ -402,6 +424,71 @@ void fit(raft::resources const& handle, minibatch_update_centroids( handle, centroids, centroid_sums_const, cluster_counts_const, total_counts.view()); + // Compute squared difference of centers (for convergence check) + auto sqrdNorm = raft::make_device_scalar(handle, T{0}); + raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), + centroids.size(), + raft::sqdiff_op{}, + stream, + prev_centroids.data_handle(), + centroids.data_handle()); + T centers_squared_diff = 0; + raft::copy(¢ers_squared_diff, sqrdNorm.data_handle(), 1, stream); + raft::resource::sync_stream(handle, stream); + + // Skip first step (inertia from initialization) + if (n_iter[0] > 1) { + // Update Exponentially Weighted Average of inertia + T alpha = static_cast(current_batch_size * 2.0) / static_cast(n_samples + 1); + alpha = std::min(alpha, T{1}); + + if (!ewa_initialized) { + ewa_inertia = batch_inertia; + ewa_inertia_min = batch_inertia; + ewa_initialized = true; + } else { + ewa_inertia = ewa_inertia * (T{1} - alpha) + batch_inertia * alpha; + } + + RAFT_LOG_DEBUG( + "KMeans minibatch step %d/%d: batch_inertia=%f, ewa_inertia=%f, centers_squared_diff=%f", + n_iter[0], + n_steps, + static_cast(batch_inertia), + static_cast(ewa_inertia), + static_cast(centers_squared_diff)); + + // Early stopping: absolute tolerance on squared change of centers + // Disabled if tol == 0.0 + if (params.tol > 0.0 && centers_squared_diff <= params.tol) { + RAFT_LOG_DEBUG( + "KMeans minibatch: Converged (small centers change) at step %d/%d", n_iter[0], n_steps); + break; + } + + // Early stopping: lack of improvement in smoothed inertia + // Disabled if max_no_improvement == 0 + if (params.max_no_improvement > 0) { + if (ewa_inertia < ewa_inertia_min) { + no_improvement = 0; + ewa_inertia_min = ewa_inertia; + } else { + no_improvement++; + } + + if (no_improvement >= params.max_no_improvement) { + RAFT_LOG_DEBUG("KMeans minibatch: Converged (lack of improvement) at step %d/%d", + n_iter[0], + n_steps); + break; + } + } + } else { + RAFT_LOG_DEBUG("KMeans minibatch step %d/%d: mean batch inertia: %f", + n_iter[0], + n_steps, + static_cast(batch_inertia)); + } } else { raft::matrix::fill(handle, centroid_sums.view(), T{0}); raft::matrix::fill(handle, cluster_counts.view(), T{0}); @@ -493,34 +580,35 @@ void fit(raft::resources const& handle, handle, centroid_sums_const, cluster_counts_const, centroids_const, new_centroids.view()); } - auto sqrdNorm = raft::make_device_scalar(handle, T{0}); - raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), - centroids.size(), - raft::sqdiff_op{}, - stream, - new_centroids.data_handle(), - centroids.data_handle()); - + // Convergence check for full-batch mode only if (!use_minibatch) { + auto sqrdNorm = raft::make_device_scalar(handle, T{0}); + raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), + centroids.size(), + raft::sqdiff_op{}, + stream, + new_centroids.data_handle(), + centroids.data_handle()); + raft::copy(centroids.data_handle(), new_centroids.data_handle(), centroids.size(), stream); - } - T sqrdNormError = 0; - raft::copy(&sqrdNormError, sqrdNorm.data_handle(), 1, stream); + T sqrdNormError = 0; + raft::copy(&sqrdNormError, sqrdNorm.data_handle(), 1, stream); - bool done = false; - if (!use_minibatch && params.inertia_check && n_iter[0] > 1) { - T delta = total_cost / priorClusteringCost; - if (delta > 1 - params.tol) done = true; - priorClusteringCost = total_cost; - } + bool done = false; + if (params.inertia_check && n_iter[0] > 1) { + T delta = total_cost / priorClusteringCost; + if (delta > 1 - params.tol) done = true; + priorClusteringCost = total_cost; + } - raft::resource::sync_stream(handle, stream); - if (sqrdNormError < params.tol) done = true; + raft::resource::sync_stream(handle, stream); + if (sqrdNormError < params.tol) done = true; - if (done) { - RAFT_LOG_DEBUG("KMeans batched: Converged after %d iterations", n_iter[0]); - break; + if (done) { + RAFT_LOG_DEBUG("KMeans batched: Converged after %d iterations", n_iter[0]); + break; + } } } diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd index 0e20d6a709..8fbd7577ed 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -38,6 +38,8 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: int batch_centroids, cuvsKMeansCentroidUpdateMode update_mode, bool inertia_check, + bool final_inertia_check, + int max_no_improvement, bool hierarchical, int hierarchical_n_iters From dbfd1a84c6c70e760144852aaba7ecf0193210b1 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 19 Feb 2026 23:33:42 -0800 Subject: [PATCH 24/81] fixes --- cpp/src/cluster/detail/kmeans_batched.cuh | 33 +++++++++------------- python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 33 ++++++++++++++++++++++ 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 16c6753f4c..f62001e498 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -350,15 +350,15 @@ void fit(raft::resources const& handle, current_batch_size * n_features, stream); - auto batch_weights_fill_view = + auto batch_weights_view = raft::make_device_vector_view(batch_weights.data_handle(), current_batch_size); - raft::matrix::fill(handle, batch_weights_fill_view, T{1}); + raft::matrix::fill(handle, batch_weights_view, T{1}); auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); auto batch_weights_view = raft::make_device_vector_view( batch_weights.data_handle(), current_batch_size); - auto minClusterAndDistance_view = + auto minClusterAndDistance.view() = raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); @@ -383,7 +383,7 @@ void fit(raft::resources const& handle, handle, batch_data_view, centroids_const, - minClusterAndDistance_view, + minClusterAndDistance.view(), L2NormBatch_const, L2NormBuf_OR_DistBuf, metric, @@ -395,7 +395,7 @@ void fit(raft::resources const& handle, T batch_inertia = 0; cuvs::cluster::kmeans::detail::computeClusterCost( handle, - minClusterAndDistance_view, + minClusterAndDistance.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), raft::value_op{}, @@ -406,8 +406,7 @@ void fit(raft::resources const& handle, raft::matrix::fill(handle, cluster_counts.view(), T{0}); auto minClusterAndDistance_const = - raft::make_device_vector_view, IdxT>( - minClusterAndDistance.data_handle(), current_batch_size); + raft::make_const_mdspan(minClusterAndDistance.view()); accumulate_batch_centroids(handle, batch_data_view, @@ -501,7 +500,7 @@ void fit(raft::resources const& handle, current_batch_size * n_features, stream); - auto batch_weights_fill_view = + auto batch_weights_view = raft::make_device_vector_view(batch_weights.data_handle(), current_batch_size); if (sample_weight) { raft::copy(batch_weights.data_handle(), @@ -509,16 +508,13 @@ void fit(raft::resources const& handle, current_batch_size, stream); } else { - raft::matrix::fill(handle, batch_weights_fill_view, T{1}); + raft::matrix::fill(handle, batch_weights_view, T{1}); } auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); auto batch_weights_view = raft::make_device_vector_view( batch_weights.data_handle(), current_batch_size); - auto minClusterAndDistance_view = - raft::make_device_vector_view, IdxT>( - minClusterAndDistance.data_handle(), current_batch_size); if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded) { @@ -538,7 +534,7 @@ void fit(raft::resources const& handle, handle, batch_data_view, centroids_const, - minClusterAndDistance_view, + minClusterAndDistance.view(), L2NormBatch_const, L2NormBuf_OR_DistBuf, metric, @@ -547,8 +543,7 @@ void fit(raft::resources const& handle, workspace); auto minClusterAndDistance_const = - raft::make_device_vector_view, IdxT>( - minClusterAndDistance.data_handle(), current_batch_size); + raft::make_const_mdspan(minClusterAndDistance.view()); accumulate_batch_centroids(handle, batch_data_view, @@ -560,7 +555,7 @@ void fit(raft::resources const& handle, if (params.inertia_check) { cuvs::cluster::kmeans::detail::computeClusterCost( handle, - minClusterAndDistance_view, + minClusterAndDistance.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), raft::value_op{}, @@ -624,7 +619,7 @@ void fit(raft::resources const& handle, auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); - auto minClusterAndDistance_view = + auto minClusterAndDistance.view() = raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); @@ -646,7 +641,7 @@ void fit(raft::resources const& handle, handle, batch_data_view, centroids_const, - minClusterAndDistance_view, + minClusterAndDistance.view(), L2NormBatch_const, L2NormBuf_OR_DistBuf, metric, @@ -656,7 +651,7 @@ void fit(raft::resources const& handle, cuvs::cluster::kmeans::detail::computeClusterCost( handle, - minClusterAndDistance_view, + minClusterAndDistance.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), raft::value_op{}, diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index dbc33ca179..8e2b32c7e7 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -88,6 +88,18 @@ cdef class KMeansParams: "full_batch" : Standard Lloyd's algorithm - accumulate assignments over the entire dataset, then update centroids once per iteration. "mini_batch" : Mini-batch k-means - update centroids after each batch. + inertia_check : bool + If True, check inertia during iterations for early convergence. + final_inertia_check : bool + If True, compute the final inertia after fit_batched completes. + This requires an additional full pass over all the host data. + Only used by fit_batched(); regular fit() always computes final inertia. + Default: False (skip final inertia computation for performance). + max_no_improvement : int + Maximum number of consecutive mini-batch steps without improvement + in smoothed inertia before early stopping. Only used when update_mode + is "mini_batch". If 0, this convergence criterion is disabled. + Default: 10 (matches sklearn's default). hierarchical : bool Whether to use hierarchical (balanced) kmeans or not hierarchical_n_iters : int @@ -113,6 +125,9 @@ cdef class KMeansParams: batch_samples=None, batch_centroids=None, update_mode=None, + inertia_check=None, + final_inertia_check=None, + max_no_improvement=None, hierarchical=None, hierarchical_n_iters=None): if metric is not None: @@ -137,6 +152,12 @@ cdef class KMeansParams: if update_mode is not None: c_mode = UPDATE_MODE_TYPES[update_mode] self.params.update_mode = c_mode + if inertia_check is not None: + self.params.inertia_check = inertia_check + if final_inertia_check is not None: + self.params.final_inertia_check = final_inertia_check + if max_no_improvement is not None: + self.params.max_no_improvement = max_no_improvement if hierarchical is not None: self.params.hierarchical = hierarchical if hierarchical_n_iters is not None: @@ -185,6 +206,18 @@ cdef class KMeansParams: def update_mode(self): return UPDATE_MODE_NAMES[self.params.update_mode] + @property + def inertia_check(self): + return self.params.inertia_check + + @property + def final_inertia_check(self): + return self.params.final_inertia_check + + @property + def max_no_improvement(self): + return self.params.max_no_improvement + @property def hierarchical(self): return self.params.hierarchical From e76624f1e6290b4c685d837332ff90b3664c6b71 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 24 Feb 2026 06:27:45 -0800 Subject: [PATCH 25/81] fix test by normalizing data --- cpp/src/cluster/detail/kmeans_batched.cuh | 34 +++++++++++++---------- python/cuvs/cuvs/tests/test_kmeans.py | 11 +++++--- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index f62001e498..01b43ff906 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -142,6 +142,9 @@ void accumulate_batch_centroids( auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); auto batch_counts = raft::make_device_vector(handle, n_clusters); + raft::matrix::fill(handle, batch_sums.view(), MathT{0}); + raft::matrix::fill(handle, batch_counts.view(), MathT{0}); + cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; thrust::transform_iterator, const raft::KeyValuePair*> @@ -356,9 +359,9 @@ void fit(raft::resources const& handle, auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); - auto batch_weights_view = raft::make_device_vector_view( + auto batch_weights_view_const = raft::make_device_vector_view( batch_weights.data_handle(), current_batch_size); - auto minClusterAndDistance.view() = + auto minClusterAndDistance_view = raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); @@ -383,7 +386,7 @@ void fit(raft::resources const& handle, handle, batch_data_view, centroids_const, - minClusterAndDistance.view(), + minClusterAndDistance_view, L2NormBatch_const, L2NormBuf_OR_DistBuf, metric, @@ -395,7 +398,7 @@ void fit(raft::resources const& handle, T batch_inertia = 0; cuvs::cluster::kmeans::detail::computeClusterCost( handle, - minClusterAndDistance.view(), + minClusterAndDistance_view, workspace, raft::make_device_scalar_view(clusterCostD.data()), raft::value_op{}, @@ -406,12 +409,12 @@ void fit(raft::resources const& handle, raft::matrix::fill(handle, cluster_counts.view(), T{0}); auto minClusterAndDistance_const = - raft::make_const_mdspan(minClusterAndDistance.view()); + raft::make_const_mdspan(minClusterAndDistance_view); accumulate_batch_centroids(handle, batch_data_view, minClusterAndDistance_const, - batch_weights_view, + batch_weights_view_const, centroid_sums.view(), cluster_counts.view()); @@ -492,6 +495,10 @@ void fit(raft::resources const& handle, raft::matrix::fill(handle, centroid_sums.view(), T{0}); raft::matrix::fill(handle, cluster_counts.view(), T{0}); + // Use centroids from start of iteration for all batches (must remain constant) + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features); + for (IdxT batch_idx = 0; batch_idx < n_samples; batch_idx += batch_size) { IdxT current_batch_size = std::min(batch_size, n_samples - batch_idx); @@ -500,7 +507,7 @@ void fit(raft::resources const& handle, current_batch_size * n_features, stream); - auto batch_weights_view = + auto batch_weights_fill_view = raft::make_device_vector_view(batch_weights.data_handle(), current_batch_size); if (sample_weight) { raft::copy(batch_weights.data_handle(), @@ -508,7 +515,7 @@ void fit(raft::resources const& handle, current_batch_size, stream); } else { - raft::matrix::fill(handle, batch_weights_view, T{1}); + raft::matrix::fill(handle, batch_weights_fill_view, T{1}); } auto batch_data_view = raft::make_device_matrix_view( @@ -525,8 +532,6 @@ void fit(raft::resources const& handle, stream); } - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), n_clusters, n_features); auto L2NormBatch_const = raft::make_device_vector_view( L2NormBatch.data_handle(), current_batch_size); @@ -562,10 +567,9 @@ void fit(raft::resources const& handle, raft::add_op{}); total_cost += clusterCostD.value(stream); } + } - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), n_clusters, n_features); auto centroid_sums_const = raft::make_device_matrix_view( centroid_sums.data_handle(), n_clusters, n_features); auto cluster_counts_const = @@ -619,7 +623,7 @@ void fit(raft::resources const& handle, auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); - auto minClusterAndDistance.view() = + auto minClusterAndDistance_view = raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); @@ -641,7 +645,7 @@ void fit(raft::resources const& handle, handle, batch_data_view, centroids_const, - minClusterAndDistance.view(), + minClusterAndDistance_view, L2NormBatch_const, L2NormBuf_OR_DistBuf, metric, @@ -651,7 +655,7 @@ void fit(raft::resources const& handle, cuvs::cluster::kmeans::detail::computeClusterCost( handle, - minClusterAndDistance.view(), + minClusterAndDistance_view, workspace, raft::make_device_scalar_view(clusterCostD.data()), raft::value_op{}, diff --git a/python/cuvs/cuvs/tests/test_kmeans.py b/python/cuvs/cuvs/tests/test_kmeans.py index 1930c1a964..5239d2f04c 100644 --- a/python/cuvs/cuvs/tests/test_kmeans.py +++ b/python/cuvs/cuvs/tests/test_kmeans.py @@ -83,7 +83,7 @@ def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): @pytest.mark.parametrize("n_cols", [10, 100]) @pytest.mark.parametrize("n_clusters", [8, 16]) @pytest.mark.parametrize("batch_size", [100, 500]) -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("dtype", [np.float64]) def test_fit_batched_matches_fit( n_rows, n_cols, n_clusters, batch_size, dtype ): @@ -93,6 +93,11 @@ def test_fit_batched_matches_fit( """ rng = np.random.default_rng(99) X_host = rng.random((n_rows, n_cols)).astype(dtype) + + norms = np.linalg.norm(X_host, ord=1, axis=1, keepdims=True) + norms = np.where(norms == 0, 1.0, norms) + X_host = X_host / norms + initial_centroids_host = X_host[:n_clusters].copy() params = KMeansParams( @@ -117,7 +122,7 @@ def test_fit_batched_matches_fit( centroids_batched = centroids_batched.copy_to_host() assert np.allclose( - centroids_regular, centroids_batched, rtol=0.05, atol=0.05 + centroids_regular, centroids_batched, rtol=1e-3, atol=1e-3 ), f"max diff: {np.max(np.abs(centroids_regular - centroids_batched))}" @@ -147,7 +152,6 @@ def test_minibatch_sklearn(n_rows, n_cols, n_clusters, dtype): reassignment_ratio=0.01, batch_size=256, ) - # kmeans = KMeans(n_clusters=8, init=initial_centroids_host, n_init='auto', max_iter=20, tol=0.0001, verbose=0, random_state=None, copy_x=True, algorithm='lloyd') kmeans.fit(X_host) centroids_sklearn = kmeans.cluster_centers_ @@ -167,7 +171,6 @@ def test_minibatch_sklearn(n_rows, n_cols, n_clusters, dtype): centroids=device_ndarray(initial_centroids_host.copy()), ) centroids_cuvs = centroids_cuvs.copy_to_host() - print(centroids_cuvs) assert np.allclose( centroids_sklearn, centroids_cuvs, rtol=0.3, atol=0.3 From aacd54314261c545a8d8d63c35abed0da5ce0d53 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 24 Feb 2026 20:02:58 -0800 Subject: [PATCH 26/81] rejection sampling --- cpp/src/cluster/detail/kmeans_batched.cuh | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 01b43ff906..c04275f3d6 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -33,6 +33,7 @@ #include #include #include +#include #include namespace cuvs::cluster::kmeans::detail { @@ -55,9 +56,19 @@ void prepare_init_sample(raft::resources const& handle, auto n_samples_out = X_sample.extent(0); std::mt19937 gen(seed); - std::vector indices(n_samples); - std::iota(indices.begin(), indices.end(), 0); - std::shuffle(indices.begin(), indices.end(), gen); + std::uniform_int_distribution dist(0, n_samples - 1); + + // Generate n_samples_out unique random indices using rejection sampling + // Since n_samples_out << n_samples, collisions are rare and this is O(n_samples_out) + std::unordered_set selected_indices; + selected_indices.reserve(n_samples_out); + + while (static_cast(selected_indices.size()) < n_samples_out) { + IdxT idx = dist(gen); + selected_indices.insert(idx); + } + + std::vector indices(selected_indices.begin(), selected_indices.end()); std::vector host_sample(n_samples_out * n_features); #pragma omp parallel for From d746af2f8b9f4858a2b0cb13851c3aab7a583b03 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 24 Feb 2026 20:20:19 -0800 Subject: [PATCH 27/81] style --- cpp/src/cluster/detail/kmeans_batched.cuh | 58 ++++++++++------------- python/cuvs/cuvs/tests/test_kmeans.py | 6 +-- 2 files changed, 27 insertions(+), 37 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index c04275f3d6..19806f182d 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -57,17 +58,17 @@ void prepare_init_sample(raft::resources const& handle, std::mt19937 gen(seed); std::uniform_int_distribution dist(0, n_samples - 1); - + // Generate n_samples_out unique random indices using rejection sampling // Since n_samples_out << n_samples, collisions are rare and this is O(n_samples_out) std::unordered_set selected_indices; selected_indices.reserve(n_samples_out); - + while (static_cast(selected_indices.size()) < n_samples_out) { IdxT idx = dist(gen); selected_indices.insert(idx); } - + std::vector indices(selected_indices.begin(), selected_indices.end()); std::vector host_sample(n_samples_out * n_features); @@ -419,8 +420,7 @@ void fit(raft::resources const& handle, raft::matrix::fill(handle, centroid_sums.view(), T{0}); raft::matrix::fill(handle, cluster_counts.view(), T{0}); - auto minClusterAndDistance_const = - raft::make_const_mdspan(minClusterAndDistance_view); + auto minClusterAndDistance_const = raft::make_const_mdspan(minClusterAndDistance_view); accumulate_batch_centroids(handle, batch_data_view, @@ -510,37 +510,34 @@ void fit(raft::resources const& handle, auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); - for (IdxT batch_idx = 0; batch_idx < n_samples; batch_idx += batch_size) { - IdxT current_batch_size = std::min(batch_size, n_samples - batch_idx); + using namespace cuvs::spatial::knn::detail::utils; + batch_load_iterator data_batches( + X.data_handle(), n_samples, n_features, batch_size, stream); + + for (const auto& data_batch : data_batches) { + IdxT current_batch_size = static_cast(data_batch.size()); - raft::copy(batch_data.data_handle(), - X.data_handle() + batch_idx * n_features, - current_batch_size * n_features, - stream); + auto batch_data_view = raft::make_device_matrix_view( + data_batch.data(), current_batch_size, n_features); auto batch_weights_fill_view = raft::make_device_vector_view(batch_weights.data_handle(), current_batch_size); if (sample_weight) { raft::copy(batch_weights.data_handle(), - sample_weight->data_handle() + batch_idx, + sample_weight->data_handle() + data_batch.offset(), current_batch_size, stream); } else { raft::matrix::fill(handle, batch_weights_fill_view, T{1}); } - auto batch_data_view = raft::make_device_matrix_view( - batch_data.data_handle(), current_batch_size, n_features); auto batch_weights_view = raft::make_device_vector_view( batch_weights.data_handle(), current_batch_size); if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormBatch.data_handle(), - batch_data.data_handle(), - n_features, - current_batch_size, - stream); + raft::linalg::rowNorm( + L2NormBatch.data_handle(), data_batch.data(), n_features, current_batch_size, stream); } auto L2NormBatch_const = raft::make_device_vector_view( @@ -558,8 +555,7 @@ void fit(raft::resources const& handle, params.batch_centroids, workspace); - auto minClusterAndDistance_const = - raft::make_const_mdspan(minClusterAndDistance.view()); + auto minClusterAndDistance_const = raft::make_const_mdspan(minClusterAndDistance.view()); accumulate_batch_centroids(handle, batch_data_view, @@ -578,7 +574,6 @@ void fit(raft::resources const& handle, raft::add_op{}); total_cost += clusterCostD.value(stream); } - } auto centroid_sums_const = raft::make_device_matrix_view( @@ -624,27 +619,22 @@ void fit(raft::resources const& handle, if (params.final_inertia_check) { inertia[0] = 0; - for (IdxT offset = 0; offset < n_samples; offset += batch_size) { - IdxT current_batch_size = std::min(batch_size, n_samples - offset); + using namespace cuvs::spatial::knn::detail::utils; + batch_load_iterator data_batches(X.data_handle(), n_samples, n_features, batch_size, stream); - raft::copy(batch_data.data_handle(), - X.data_handle() + offset * n_features, - current_batch_size * n_features, - stream); + for (const auto& data_batch : data_batches) { + IdxT current_batch_size = static_cast(data_batch.size()); auto batch_data_view = raft::make_device_matrix_view( - batch_data.data_handle(), current_batch_size, n_features); + data_batch.data(), current_batch_size, n_features); auto minClusterAndDistance_view = raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormBatch.data_handle(), - batch_data.data_handle(), - n_features, - current_batch_size, - stream); + raft::linalg::rowNorm( + L2NormBatch.data_handle(), data_batch.data(), n_features, current_batch_size, stream); } auto centroids_const = raft::make_device_matrix_view( diff --git a/python/cuvs/cuvs/tests/test_kmeans.py b/python/cuvs/cuvs/tests/test_kmeans.py index 5239d2f04c..623c448726 100644 --- a/python/cuvs/cuvs/tests/test_kmeans.py +++ b/python/cuvs/cuvs/tests/test_kmeans.py @@ -93,11 +93,11 @@ def test_fit_batched_matches_fit( """ rng = np.random.default_rng(99) X_host = rng.random((n_rows, n_cols)).astype(dtype) - + norms = np.linalg.norm(X_host, ord=1, axis=1, keepdims=True) norms = np.where(norms == 0, 1.0, norms) X_host = X_host / norms - + initial_centroids_host = X_host[:n_clusters].copy() params = KMeansParams( @@ -142,7 +142,7 @@ def test_minibatch_sklearn(n_rows, n_cols, n_clusters, dtype): kmeans = MiniBatchKMeans( n_clusters=8, init=initial_centroids_host, - max_iter=200, + max_iter=100, verbose=0, random_state=None, tol=0.0, From 6d01aed3d39c1c02df35fd361ba8a361fdbc1a75 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 24 Feb 2026 21:01:11 -0800 Subject: [PATCH 28/81] update test with inertia check --- cpp/src/cluster/detail/kmeans_batched.cuh | 2 +- python/cuvs/cuvs/tests/test_kmeans.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 19806f182d..e6cfe1086d 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -7,9 +7,9 @@ #include "kmeans.cuh" #include "kmeans_common.cuh" +#include "../../neighbors/detail/ann_utils.cuh" #include #include -#include #include #include diff --git a/python/cuvs/cuvs/tests/test_kmeans.py b/python/cuvs/cuvs/tests/test_kmeans.py index 623c448726..c87aaa5e28 100644 --- a/python/cuvs/cuvs/tests/test_kmeans.py +++ b/python/cuvs/cuvs/tests/test_kmeans.py @@ -155,16 +155,19 @@ def test_minibatch_sklearn(n_rows, n_cols, n_clusters, dtype): kmeans.fit(X_host) centroids_sklearn = kmeans.cluster_centers_ + inertia_sklearn = kmeans.inertia_ # cuvs fit params = KMeansParams( n_clusters=n_clusters, init_method="Array", - max_iter=200, + max_iter=100, tol=1e-4, update_mode="mini_batch", + final_inertia_check=True, + max_no_improvement=10, ) - centroids_cuvs, _, _ = fit_batched( + centroids_cuvs, inertia_cuvs, _ = fit_batched( params, X_host, batch_size=256, @@ -172,6 +175,12 @@ def test_minibatch_sklearn(n_rows, n_cols, n_clusters, dtype): ) centroids_cuvs = centroids_cuvs.copy_to_host() + # Compare centroids assert np.allclose( centroids_sklearn, centroids_cuvs, rtol=0.3, atol=0.3 ), f"max diff: {np.max(np.abs(centroids_sklearn - centroids_cuvs))}" + + inertia_diff = abs(inertia_sklearn - inertia_cuvs) + assert np.allclose( + inertia_sklearn, inertia_cuvs, rtol=0.1, atol=0.1 + ), f"inertia diff: sklearn={inertia_sklearn}, cuvs={inertia_cuvs}, diff={inertia_diff}" From 2f1189ff3113ff078c2d02243b01507796f95d57 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 24 Feb 2026 21:05:37 -0800 Subject: [PATCH 29/81] fix style --- python/cuvs/cuvs/tests/test_kmeans.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/cuvs/cuvs/tests/test_kmeans.py b/python/cuvs/cuvs/tests/test_kmeans.py index c87aaa5e28..963105d8c6 100644 --- a/python/cuvs/cuvs/tests/test_kmeans.py +++ b/python/cuvs/cuvs/tests/test_kmeans.py @@ -175,12 +175,11 @@ def test_minibatch_sklearn(n_rows, n_cols, n_clusters, dtype): ) centroids_cuvs = centroids_cuvs.copy_to_host() - # Compare centroids assert np.allclose( centroids_sklearn, centroids_cuvs, rtol=0.3, atol=0.3 ), f"max diff: {np.max(np.abs(centroids_sklearn - centroids_cuvs))}" inertia_diff = abs(inertia_sklearn - inertia_cuvs) - assert np.allclose( - inertia_sklearn, inertia_cuvs, rtol=0.1, atol=0.1 - ), f"inertia diff: sklearn={inertia_sklearn}, cuvs={inertia_cuvs}, diff={inertia_diff}" + assert np.allclose(inertia_sklearn, inertia_cuvs, rtol=0.1, atol=0.1), ( + f"inertia diff: sklearn={inertia_sklearn}, cuvs={inertia_cuvs}, diff={inertia_diff}" + ) From 0862da6e317a0e062e58a378e0e82960bdb15efc Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 25 Feb 2026 10:22:16 -0800 Subject: [PATCH 30/81] add reassignment; update minibatch params struct --- c/include/cuvs/cluster/kmeans.h | 8 + c/src/cluster/kmeans.cpp | 14 +- cpp/include/cuvs/cluster/kmeans.hpp | 72 ++++--- cpp/src/cluster/detail/kmeans_batched.cuh | 232 ++++++++++++++++++++- python/cuvs/cuvs/cluster/kmeans/kmeans.pxd | 1 + python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 12 ++ 6 files changed, 302 insertions(+), 37 deletions(-) diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index b1c25b6dec..f01395aaa4 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -129,6 +129,14 @@ struct cuvsKMeansParams { */ int max_no_improvement; + /** + * Control the fraction of the maximum number of counts for a center to be reassigned. + * Centers with count < reassignment_ratio * max(counts) are randomly reassigned to + * observations from the current batch. Only used when update_mode is CUVS_KMEANS_UPDATE_MINI_BATCH. + * If 0.0, reassignment is disabled. Default: 0.01 + */ + double reassignment_ratio; + /** * Whether to use hierarchical (balanced) kmeans or not */ diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index 77e9b5126e..e2c9156c1a 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -28,9 +28,10 @@ cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params) kmeans_params.batch_samples = params.batch_samples; kmeans_params.batch_centroids = params.batch_centroids; kmeans_params.inertia_check = params.inertia_check; - kmeans_params.final_inertia_check = params.final_inertia_check; - kmeans_params.max_no_improvement = params.max_no_improvement; - kmeans_params.update_mode = + kmeans_params.batched.final_inertia_check = params.final_inertia_check; + kmeans_params.batched.minibatch.max_no_improvement = params.max_no_improvement; + kmeans_params.batched.minibatch.reassignment_ratio = params.reassignment_ratio; + kmeans_params.batched.update_mode = static_cast(params.update_mode); return kmeans_params; } @@ -258,10 +259,11 @@ extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params) .oversampling_factor = cpp_params.oversampling_factor, .batch_samples = cpp_params.batch_samples, .batch_centroids = cpp_params.batch_centroids, - .update_mode = static_cast(cpp_params.update_mode), + .update_mode = static_cast(cpp_params.batched.update_mode), .inertia_check = cpp_params.inertia_check, - .final_inertia_check = cpp_params.final_inertia_check, - .max_no_improvement = cpp_params.max_no_improvement, + .final_inertia_check = cpp_params.batched.final_inertia_check, + .max_no_improvement = cpp_params.batched.minibatch.max_no_improvement, + .reassignment_ratio = cpp_params.batched.minibatch.reassignment_ratio, .hierarchical = false, .hierarchical_n_iters = static_cast(cpp_balanced_params.n_iters)}; }); diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index b991ed2621..59178d994b 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -125,16 +125,6 @@ struct params : base_params { */ int batch_centroids = 0; - /** - * Centroid update mode for fit_batched(): - * - FullBatch (default): Standard Lloyd's algorithm. Accumulate partial sums - * across all batches, update centroids once per iteration. Deterministic and - * mathematically equivalent to standard k-means. - * - MiniBatch: Online mini-batch k-means. Update centroids incrementally after - * each randomly sampled batch. Faster convergence but non-deterministic. - */ - CentroidUpdateMode update_mode = FullBatch; - /** * If true, check inertia during iterations for early convergence (used by both fit and * fit_batched). @@ -142,20 +132,54 @@ struct params : base_params { bool inertia_check = false; /** - * If true, compute the final inertia after fit_batched completes. This requires an additional - * full pass over all the host data, which can be expensive for large datasets. - * Only used by fit_batched(); regular fit() always computes final inertia. - * Default: false (skip final inertia computation for performance). + * Parameters specific to batched k-means (fit_batched). + * These parameters are only used when calling fit_batched() and are ignored by regular fit(). */ - bool final_inertia_check = false; + struct batched_params { + /** + * Centroid update mode for fit_batched(): + * - FullBatch (default): Standard Lloyd's algorithm. Accumulate partial sums + * across all batches, update centroids once per iteration. Deterministic and + * mathematically equivalent to standard k-means. + * - MiniBatch: Online mini-batch k-means. Update centroids incrementally after + * each randomly sampled batch. Faster convergence but non-deterministic. + */ + CentroidUpdateMode update_mode = FullBatch; - /** - * Maximum number of consecutive mini-batch steps without improvement in smoothed inertia - * before early stopping. Only used when update_mode is MiniBatch. - * If None/0, this convergence criterion is disabled. - * Default: 10 - */ - int max_no_improvement = 10; + /** + * If true, compute the final inertia after fit_batched completes. This requires an additional + * full pass over all the host data, which can be expensive for large datasets. + * Only used by fit_batched(); regular fit() always computes final inertia. + * Default: false (skip final inertia computation for performance). + */ + bool final_inertia_check = false; + + /** + * Parameters specific to mini-batch k-means mode. + * These parameters are only used when update_mode is MiniBatch. + */ + struct minibatch_params { + /** + * Maximum number of consecutive mini-batch steps without improvement in smoothed inertia + * before early stopping. Only used when update_mode is MiniBatch. + * If 0, this convergence criterion is disabled. + * Default: 10 + */ + int max_no_improvement = 10; + + /** + * Control the fraction of the maximum number of counts for a center to be reassigned. + * Centers with count < reassignment_ratio * max(counts) are randomly reassigned to + * observations from the current batch. A higher value means that low count centers are + * more likely to be reassigned, which means that the model will take longer to converge, + * but should converge in a better clustering. + * Only used when update_mode is MiniBatch. + * If 0.0, reassignment is disabled. + * Default: 0.01 (matching scikit-learn) + */ + double reassignment_ratio = 0.01; + } minibatch; + } batched; }; /** @@ -207,7 +231,7 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * raft::resources handle; * cuvs::cluster::kmeans::params params; * params.n_clusters = 100; - * // params.update_mode = kmeans::params::MiniBatch; // for mini-batch mode + * // params.batched.update_mode = kmeans::params::MiniBatch; // for mini-batch mode * int n_features = 15; * float inertia; * int n_iter; @@ -230,7 +254,7 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * @endcode * * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. Use params.update_mode + * @param[in] params Parameters for KMeans model. Use params.batched.update_mode * to select FullBatch or MiniBatch mode. * @param[in] X Training instances on HOST memory. The data must * be in row-major format. diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index e6cfe1086d..b9a0a699c7 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -17,18 +17,27 @@ #include #include #include +#include #include #include #include #include #include +#include +#include #include +#include #include #include #include #include +#include +#include +#include +#include +#include #include #include @@ -192,16 +201,23 @@ void accumulate_batch_centroids( * centroid[k] = centroid[k] + learning_rate[k] * (batch_mean[k] - centroid[k]) * * This is equivalent to a weighted average where total_count tracks cumulative weight. + * + * Optionally reassigns low-count clusters to random samples from the current batch. */ template void minibatch_update_centroids(raft::resources const& handle, raft::device_matrix_view centroids, raft::device_matrix_view batch_sums, raft::device_vector_view batch_counts, - raft::device_vector_view total_counts) + raft::device_vector_view total_counts, + raft::device_matrix_view batch_data, + double reassignment_ratio, + IdxT current_batch_size, + std::mt19937& rng) { auto n_clusters = centroids.extent(0); auto n_features = centroids.extent(1); + cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto batch_means = raft::make_device_matrix(handle, n_clusters, n_features); raft::linalg::matrix_vector_op( @@ -236,6 +252,201 @@ void minibatch_update_centroids(raft::resources const& handle, raft::make_const_mdspan(centroids), raft::make_const_mdspan(batch_means.view()), centroids); + + // Reassignment logic: reassign low-count clusters to random samples from current batch + if (reassignment_ratio > 0.0) { + auto max_count_scalar = raft::make_device_scalar(handle, MathT{0}); + size_t temp_storage_bytes = 0; + cub::DeviceReduce::Max(nullptr, + temp_storage_bytes, + total_counts.data_handle(), + max_count_scalar.data_handle(), + n_clusters, + stream); + rmm::device_uvector temp_storage(temp_storage_bytes, stream); + cub::DeviceReduce::Max(temp_storage.data(), + temp_storage_bytes, + total_counts.data_handle(), + max_count_scalar.data_handle(), + n_clusters, + stream); + MathT max_count = max_count_scalar.value(stream); + raft::resource::sync_stream(handle, stream); + + // Identify clusters to reassign on device: total_count < reassignment_ratio * max_count or total_count == 0 + MathT threshold = static_cast(reassignment_ratio) * max_count; + auto reassign_flags = raft::make_device_vector(handle, n_clusters); + + raft::linalg::unaryOp(total_counts.data_handle(), + reassign_flags.data_handle(), + n_clusters, + [=] __device__(MathT count) { + return (count < threshold || count == MathT{0}) ? uint8_t{1} : uint8_t{0}; + }, + stream); + + // Count how many clusters need reassignment using RAFT mapThenSumReduce + auto num_reassign_scalar = raft::make_device_scalar(handle, IdxT{0}); + raft::linalg::mapThenSumReduce(num_reassign_scalar.data_handle(), + n_clusters, + raft::identity_op{}, + stream, + reassign_flags.data_handle()); + IdxT num_to_reassign = num_reassign_scalar.value(stream); + raft::resource::sync_stream(handle, stream); + + // Limit to 50% of batch size + IdxT max_reassign = static_cast(0.5 * current_batch_size); + if (num_to_reassign > max_reassign) { + // Need to select only the worst ones - do sorting on device + // First, get all cluster indices that need reassignment + auto all_reassign_indices = raft::make_device_vector(handle, num_to_reassign); + auto counting_iter = thrust::counting_iterator(0); + thrust::device_ptr flags_ptr(reassign_flags.data_handle()); + + auto end_iter = thrust::copy_if(raft::resource::get_thrust_policy(handle), + counting_iter, + counting_iter + n_clusters, + flags_ptr, + thrust::device_pointer_cast(all_reassign_indices.data_handle()), + [] __device__(uint8_t flag) { return flag == 1; }); + + // Get counts for these clusters using RAFT matrix::gather + auto reassign_counts = raft::make_device_vector(handle, num_to_reassign); + auto total_counts_matrix_view = raft::make_device_matrix_view( + total_counts.data_handle(), n_clusters, 1); + auto reassign_indices_view = raft::make_device_vector_view( + all_reassign_indices.data_handle(), num_to_reassign); + auto reassign_counts_matrix_view = raft::make_device_matrix_view( + reassign_counts.data_handle(), num_to_reassign, 1); + raft::matrix::gather(handle, total_counts_matrix_view, reassign_indices_view, reassign_counts_matrix_view); + + thrust::sort_by_key(raft::resource::get_thrust_policy(handle), + reassign_counts.data_handle(), + reassign_counts.data_handle() + num_to_reassign, + all_reassign_indices.data_handle()); + + // Reset all flags + raft::matrix::fill(handle, reassign_flags.view(), uint8_t{0}); + + // Set flags only for worst max_reassign clusters + auto worst_indices = raft::make_device_vector(handle, max_reassign); + raft::copy(worst_indices.data_handle(), + all_reassign_indices.data_handle(), + max_reassign, + stream); + + // Use RAFT matrix::scatter to set flags + auto flags_scatter = raft::make_device_vector(handle, max_reassign); + raft::matrix::fill(handle, flags_scatter.view(), uint8_t{1}); + auto flags_scatter_matrix_view = raft::make_device_matrix_view( + flags_scatter.data_handle(), max_reassign, 1); + auto worst_indices_view = raft::make_device_vector_view( + worst_indices.data_handle(), max_reassign); + auto reassign_flags_matrix_view = raft::make_device_matrix_view( + reassign_flags.data_handle(), n_clusters, 1); + raft::matrix::scatter(handle, flags_scatter_matrix_view, worst_indices_view, reassign_flags_matrix_view); + + num_to_reassign = max_reassign; + } + + if (num_to_reassign > 0) { + // Get list of cluster indices to reassign using thrust::copy_if + auto reassign_indices = raft::make_device_vector(handle, num_to_reassign); + auto counting_iter = thrust::counting_iterator(0); + thrust::device_ptr flags_ptr(reassign_flags.data_handle()); + + auto end_iter = thrust::copy_if(raft::resource::get_thrust_policy(handle), + counting_iter, + counting_iter + n_clusters, + flags_ptr, + thrust::device_pointer_cast(reassign_indices.data_handle()), + [] __device__(uint8_t flag) { return flag == 1; }); + + // Verify actual count using RAFT (in case flags were modified) + auto actual_count_scalar = raft::make_device_scalar(handle, IdxT{0}); + raft::linalg::mapThenSumReduce(actual_count_scalar.data_handle(), + n_clusters, + raft::identity_op{}, + stream, + reassign_flags.data_handle()); + num_to_reassign = actual_count_scalar.value(stream); + raft::resource::sync_stream(handle, stream); + + auto reassign_indices_host = raft::make_host_vector(handle, num_to_reassign); + raft::copy(reassign_indices_host.data_handle(), reassign_indices.data_handle(), num_to_reassign, stream); + raft::resource::sync_stream(handle, stream); + + // Pick random samples from current batch (without replacement) on host + std::uniform_int_distribution batch_dist(0, current_batch_size - 1); + std::unordered_set selected_indices; + selected_indices.reserve(num_to_reassign); + + while (static_cast(selected_indices.size()) < num_to_reassign) { + IdxT idx = batch_dist(rng); + selected_indices.insert(idx); + } + + std::vector new_center_indices(selected_indices.begin(), selected_indices.end()); + + // Update centroids for reassigned clusters (device-to-device copy from batch_data) + for (IdxT i = 0; i < num_to_reassign; ++i) { + IdxT cluster_idx = reassign_indices_host.data_handle()[i]; + IdxT sample_idx = new_center_indices[i]; + raft::copy(centroids.data_handle() + cluster_idx * n_features, + batch_data.data_handle() + sample_idx * n_features, + n_features, + stream); + } + + // Reset total_counts for reassigned clusters to min of non-reassigned clusters on device + // Find min of non-reassigned clusters on device + auto masked_counts = raft::make_device_vector(handle, n_clusters); + auto total_counts_ptr = total_counts.data_handle(); + auto reassign_flags_ptr = reassign_flags.data_handle(); + // Use RAFT map_offset to create masked array (replaces thrust::make_counting_iterator) + raft::linalg::map_offset(handle, + masked_counts.view(), + [=] __device__(IdxT k) { + if (reassign_flags_ptr[k] == 0 && total_counts_ptr[k] > MathT{0}) { + return total_counts_ptr[k]; + } + return max_count; + }); + + auto min_non_reassigned_scalar = raft::make_device_scalar(handle, max_count); + size_t min_temp_storage_bytes = 0; + cub::DeviceReduce::Min(nullptr, + min_temp_storage_bytes, + masked_counts.data_handle(), + min_non_reassigned_scalar.data_handle(), + n_clusters, + stream); + rmm::device_uvector min_temp_storage(min_temp_storage_bytes, stream); + cub::DeviceReduce::Min(min_temp_storage.data(), + min_temp_storage_bytes, + masked_counts.data_handle(), + min_non_reassigned_scalar.data_handle(), + n_clusters, + stream); + MathT min_non_reassigned = min_non_reassigned_scalar.value(stream); + if (min_non_reassigned == max_count) { + min_non_reassigned = MathT{1}; // Fallback if all clusters were reassigned + } + + // Update total_counts on device for reassigned clusters + // reassign_indices_host is already available from earlier + for (IdxT i = 0; i < num_to_reassign; ++i) { + IdxT cluster_idx = reassign_indices_host.data_handle()[i]; + raft::copy(total_counts.data_handle() + cluster_idx, + &min_non_reassigned, + 1, + stream); + } + + RAFT_LOG_DEBUG("KMeans minibatch: Reassigned %zu cluster centers", static_cast(num_to_reassign)); + } + } } /** @@ -290,7 +501,7 @@ void fit(raft::resources const& handle, } bool use_minibatch = - (params.update_mode == cuvs::cluster::kmeans::params::CentroidUpdateMode::MiniBatch); + (params.batched.update_mode == cuvs::cluster::kmeans::params::CentroidUpdateMode::MiniBatch); RAFT_LOG_DEBUG("KMeans batched: update_mode=%s", use_minibatch ? "MiniBatch" : "FullBatch"); auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); @@ -434,8 +645,15 @@ void fit(raft::resources const& handle, auto cluster_counts_const = raft::make_device_vector_view(cluster_counts.data_handle(), n_clusters); - minibatch_update_centroids( - handle, centroids, centroid_sums_const, cluster_counts_const, total_counts.view()); + minibatch_update_centroids(handle, + centroids, + centroid_sums_const, + cluster_counts_const, + total_counts.view(), + batch_data_view, + params.batched.minibatch.reassignment_ratio, + current_batch_size, + rng); // Compute squared difference of centers (for convergence check) auto sqrdNorm = raft::make_device_scalar(handle, T{0}); @@ -481,7 +699,7 @@ void fit(raft::resources const& handle, // Early stopping: lack of improvement in smoothed inertia // Disabled if max_no_improvement == 0 - if (params.max_no_improvement > 0) { + if (params.batched.minibatch.max_no_improvement > 0) { if (ewa_inertia < ewa_inertia_min) { no_improvement = 0; ewa_inertia_min = ewa_inertia; @@ -489,7 +707,7 @@ void fit(raft::resources const& handle, no_improvement++; } - if (no_improvement >= params.max_no_improvement) { + if (no_improvement >= params.batched.minibatch.max_no_improvement) { RAFT_LOG_DEBUG("KMeans minibatch: Converged (lack of improvement) at step %d/%d", n_iter[0], n_steps); @@ -617,7 +835,7 @@ void fit(raft::resources const& handle, } } - if (params.final_inertia_check) { + if (params.batched.final_inertia_check) { inertia[0] = 0; using namespace cuvs::spatial::knn::detail::utils; batch_load_iterator data_batches(X.data_handle(), n_samples, n_features, batch_size, stream); diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd index 8fbd7577ed..f8ffc02acf 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -40,6 +40,7 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: bool inertia_check, bool final_inertia_check, int max_no_improvement, + double reassignment_ratio, bool hierarchical, int hierarchical_n_iters diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index 8e2b32c7e7..c48d36a18a 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -100,6 +100,11 @@ cdef class KMeansParams: in smoothed inertia before early stopping. Only used when update_mode is "mini_batch". If 0, this convergence criterion is disabled. Default: 10 (matches sklearn's default). + reassignment_ratio : float + Control the fraction of the maximum number of counts for a center to be reassigned. + Centers with count < reassignment_ratio * max(counts) are randomly reassigned to + observations from the current batch. Only used when update_mode is "mini_batch". + If 0.0, reassignment is disabled. Default: 0.01 (matches sklearn's default). hierarchical : bool Whether to use hierarchical (balanced) kmeans or not hierarchical_n_iters : int @@ -128,6 +133,7 @@ cdef class KMeansParams: inertia_check=None, final_inertia_check=None, max_no_improvement=None, + reassignment_ratio=None, hierarchical=None, hierarchical_n_iters=None): if metric is not None: @@ -158,6 +164,8 @@ cdef class KMeansParams: self.params.final_inertia_check = final_inertia_check if max_no_improvement is not None: self.params.max_no_improvement = max_no_improvement + if reassignment_ratio is not None: + self.params.reassignment_ratio = reassignment_ratio if hierarchical is not None: self.params.hierarchical = hierarchical if hierarchical_n_iters is not None: @@ -218,6 +226,10 @@ cdef class KMeansParams: def max_no_improvement(self): return self.params.max_no_improvement + @property + def reassignment_ratio(self): + return self.params.reassignment_ratio + @property def hierarchical(self): return self.params.hierarchical From 8b448f66e09e23487f56fc3852819a279889335a Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 25 Feb 2026 10:36:55 -0800 Subject: [PATCH 31/81] style --- cpp/src/cluster/detail/kmeans_batched.cuh | 200 ++++++++++++---------- 1 file changed, 108 insertions(+), 92 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index b9a0a699c7..077accfa1f 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -26,18 +26,18 @@ #include #include #include -#include #include #include #include -#include -#include +#include #include #include +#include +#include +#include #include -#include #include #include @@ -215,8 +215,8 @@ void minibatch_update_centroids(raft::resources const& handle, IdxT current_batch_size, std::mt19937& rng) { - auto n_clusters = centroids.extent(0); - auto n_features = centroids.extent(1); + auto n_clusters = centroids.extent(0); + auto n_features = centroids.extent(1); cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto batch_means = raft::make_device_matrix(handle, n_clusters, n_features); @@ -255,7 +255,7 @@ void minibatch_update_centroids(raft::resources const& handle, // Reassignment logic: reassign low-count clusters to random samples from current batch if (reassignment_ratio > 0.0) { - auto max_count_scalar = raft::make_device_scalar(handle, MathT{0}); + auto max_count_scalar = raft::make_device_scalar(handle, MathT{0}); size_t temp_storage_bytes = 0; cub::DeviceReduce::Max(nullptr, temp_storage_bytes, @@ -270,20 +270,24 @@ void minibatch_update_centroids(raft::resources const& handle, max_count_scalar.data_handle(), n_clusters, stream); - MathT max_count = max_count_scalar.value(stream); + auto max_count_host = raft::make_host_scalar(0); + raft::copy(max_count_host.data_handle(), max_count_scalar.data_handle(), 1, stream); raft::resource::sync_stream(handle, stream); + MathT max_count = max_count_host.data_handle()[0]; - // Identify clusters to reassign on device: total_count < reassignment_ratio * max_count or total_count == 0 - MathT threshold = static_cast(reassignment_ratio) * max_count; + // Identify clusters to reassign on device: total_count < reassignment_ratio * max_count or + // total_count == 0 + MathT threshold = static_cast(reassignment_ratio) * max_count; auto reassign_flags = raft::make_device_vector(handle, n_clusters); - - raft::linalg::unaryOp(total_counts.data_handle(), - reassign_flags.data_handle(), - n_clusters, - [=] __device__(MathT count) { - return (count < threshold || count == MathT{0}) ? uint8_t{1} : uint8_t{0}; - }, - stream); + + raft::linalg::unaryOp( + total_counts.data_handle(), + reassign_flags.data_handle(), + n_clusters, + [=] __device__(MathT count) { + return (count < threshold || count == MathT{0}) ? uint8_t{1} : uint8_t{0}; + }, + stream); // Count how many clusters need reassignment using RAFT mapThenSumReduce auto num_reassign_scalar = raft::make_device_scalar(handle, IdxT{0}); @@ -292,8 +296,10 @@ void minibatch_update_centroids(raft::resources const& handle, raft::identity_op{}, stream, reassign_flags.data_handle()); - IdxT num_to_reassign = num_reassign_scalar.value(stream); + auto num_reassign_host = raft::make_host_scalar(0); + raft::copy(num_reassign_host.data_handle(), num_reassign_scalar.data_handle(), 1, stream); raft::resource::sync_stream(handle, stream); + IdxT num_to_reassign = num_reassign_host.data_handle()[0]; // Limit to 50% of batch size IdxT max_reassign = static_cast(0.5 * current_batch_size); @@ -301,68 +307,65 @@ void minibatch_update_centroids(raft::resources const& handle, // Need to select only the worst ones - do sorting on device // First, get all cluster indices that need reassignment auto all_reassign_indices = raft::make_device_vector(handle, num_to_reassign); - auto counting_iter = thrust::counting_iterator(0); + auto counting_iter = thrust::counting_iterator(0); thrust::device_ptr flags_ptr(reassign_flags.data_handle()); - - auto end_iter = thrust::copy_if(raft::resource::get_thrust_policy(handle), - counting_iter, - counting_iter + n_clusters, - flags_ptr, - thrust::device_pointer_cast(all_reassign_indices.data_handle()), - [] __device__(uint8_t flag) { return flag == 1; }); - + + thrust::copy_if(raft::resource::get_thrust_policy(handle), + counting_iter, + counting_iter + n_clusters, + flags_ptr, + thrust::device_pointer_cast(all_reassign_indices.data_handle()), + [] __device__(uint8_t flag) { return flag == 1; }); + // Get counts for these clusters using RAFT matrix::gather auto reassign_counts = raft::make_device_vector(handle, num_to_reassign); - auto total_counts_matrix_view = raft::make_device_matrix_view( - total_counts.data_handle(), n_clusters, 1); + auto total_counts_matrix_view = + raft::make_device_matrix_view(total_counts.data_handle(), n_clusters, 1); auto reassign_indices_view = raft::make_device_vector_view( all_reassign_indices.data_handle(), num_to_reassign); auto reassign_counts_matrix_view = raft::make_device_matrix_view( reassign_counts.data_handle(), num_to_reassign, 1); - raft::matrix::gather(handle, total_counts_matrix_view, reassign_indices_view, reassign_counts_matrix_view); - + raft::matrix::gather( + handle, total_counts_matrix_view, reassign_indices_view, reassign_counts_matrix_view); + thrust::sort_by_key(raft::resource::get_thrust_policy(handle), reassign_counts.data_handle(), reassign_counts.data_handle() + num_to_reassign, all_reassign_indices.data_handle()); - + // Reset all flags raft::matrix::fill(handle, reassign_flags.view(), uint8_t{0}); - + // Set flags only for worst max_reassign clusters auto worst_indices = raft::make_device_vector(handle, max_reassign); - raft::copy(worst_indices.data_handle(), - all_reassign_indices.data_handle(), - max_reassign, - stream); - - // Use RAFT matrix::scatter to set flags + raft::copy( + worst_indices.data_handle(), all_reassign_indices.data_handle(), max_reassign, stream); + + // Scatter flags: set reassign_flags[worst_indices[i]] = 1 auto flags_scatter = raft::make_device_vector(handle, max_reassign); raft::matrix::fill(handle, flags_scatter.view(), uint8_t{1}); - auto flags_scatter_matrix_view = raft::make_device_matrix_view( - flags_scatter.data_handle(), max_reassign, 1); - auto worst_indices_view = raft::make_device_vector_view( - worst_indices.data_handle(), max_reassign); - auto reassign_flags_matrix_view = raft::make_device_matrix_view( - reassign_flags.data_handle(), n_clusters, 1); - raft::matrix::scatter(handle, flags_scatter_matrix_view, worst_indices_view, reassign_flags_matrix_view); - + thrust::scatter(raft::resource::get_thrust_policy(handle), + flags_scatter.data_handle(), + flags_scatter.data_handle() + max_reassign, + worst_indices.data_handle(), + reassign_flags.data_handle()); + num_to_reassign = max_reassign; } if (num_to_reassign > 0) { // Get list of cluster indices to reassign using thrust::copy_if auto reassign_indices = raft::make_device_vector(handle, num_to_reassign); - auto counting_iter = thrust::counting_iterator(0); + auto counting_iter = thrust::counting_iterator(0); thrust::device_ptr flags_ptr(reassign_flags.data_handle()); - - auto end_iter = thrust::copy_if(raft::resource::get_thrust_policy(handle), - counting_iter, - counting_iter + n_clusters, - flags_ptr, - thrust::device_pointer_cast(reassign_indices.data_handle()), - [] __device__(uint8_t flag) { return flag == 1; }); - + + thrust::copy_if(raft::resource::get_thrust_policy(handle), + counting_iter, + counting_iter + n_clusters, + flags_ptr, + thrust::device_pointer_cast(reassign_indices.data_handle()), + [] __device__(uint8_t flag) { return flag == 1; }); + // Verify actual count using RAFT (in case flags were modified) auto actual_count_scalar = raft::make_device_scalar(handle, IdxT{0}); raft::linalg::mapThenSumReduce(actual_count_scalar.data_handle(), @@ -370,11 +373,16 @@ void minibatch_update_centroids(raft::resources const& handle, raft::identity_op{}, stream, reassign_flags.data_handle()); - num_to_reassign = actual_count_scalar.value(stream); + auto actual_count_host = raft::make_host_scalar(0); + raft::copy(actual_count_host.data_handle(), actual_count_scalar.data_handle(), 1, stream); raft::resource::sync_stream(handle, stream); - - auto reassign_indices_host = raft::make_host_vector(handle, num_to_reassign); - raft::copy(reassign_indices_host.data_handle(), reassign_indices.data_handle(), num_to_reassign, stream); + num_to_reassign = actual_count_host.data_handle()[0]; + + auto reassign_indices_host = raft::make_host_vector(num_to_reassign); + raft::copy(reassign_indices_host.data_handle(), + reassign_indices.data_handle(), + num_to_reassign, + stream); raft::resource::sync_stream(handle, stream); // Pick random samples from current batch (without replacement) on host @@ -401,35 +409,36 @@ void minibatch_update_centroids(raft::resources const& handle, // Reset total_counts for reassigned clusters to min of non-reassigned clusters on device // Find min of non-reassigned clusters on device - auto masked_counts = raft::make_device_vector(handle, n_clusters); - auto total_counts_ptr = total_counts.data_handle(); + auto masked_counts = raft::make_device_vector(handle, n_clusters); + auto total_counts_ptr = total_counts.data_handle(); auto reassign_flags_ptr = reassign_flags.data_handle(); - // Use RAFT map_offset to create masked array (replaces thrust::make_counting_iterator) - raft::linalg::map_offset(handle, - masked_counts.view(), - [=] __device__(IdxT k) { - if (reassign_flags_ptr[k] == 0 && total_counts_ptr[k] > MathT{0}) { - return total_counts_ptr[k]; - } - return max_count; - }); + raft::linalg::map_offset(handle, masked_counts.view(), [=] __device__(IdxT k) { + if (reassign_flags_ptr[k] == 0 && total_counts_ptr[k] > MathT{0}) { + return total_counts_ptr[k]; + } + return max_count; + }); auto min_non_reassigned_scalar = raft::make_device_scalar(handle, max_count); - size_t min_temp_storage_bytes = 0; + size_t min_temp_storage_bytes = 0; cub::DeviceReduce::Min(nullptr, - min_temp_storage_bytes, - masked_counts.data_handle(), - min_non_reassigned_scalar.data_handle(), - n_clusters, - stream); + min_temp_storage_bytes, + masked_counts.data_handle(), + min_non_reassigned_scalar.data_handle(), + n_clusters, + stream); rmm::device_uvector min_temp_storage(min_temp_storage_bytes, stream); cub::DeviceReduce::Min(min_temp_storage.data(), - min_temp_storage_bytes, - masked_counts.data_handle(), - min_non_reassigned_scalar.data_handle(), - n_clusters, - stream); - MathT min_non_reassigned = min_non_reassigned_scalar.value(stream); + min_temp_storage_bytes, + masked_counts.data_handle(), + min_non_reassigned_scalar.data_handle(), + n_clusters, + stream); + auto min_non_reassigned_host = raft::make_host_scalar(0); + raft::copy( + min_non_reassigned_host.data_handle(), min_non_reassigned_scalar.data_handle(), 1, stream); + raft::resource::sync_stream(handle, stream); + MathT min_non_reassigned = min_non_reassigned_host.data_handle()[0]; if (min_non_reassigned == max_count) { min_non_reassigned = MathT{1}; // Fallback if all clusters were reassigned } @@ -438,13 +447,11 @@ void minibatch_update_centroids(raft::resources const& handle, // reassign_indices_host is already available from earlier for (IdxT i = 0; i < num_to_reassign; ++i) { IdxT cluster_idx = reassign_indices_host.data_handle()[i]; - raft::copy(total_counts.data_handle() + cluster_idx, - &min_non_reassigned, - 1, - stream); + raft::copy(total_counts.data_handle() + cluster_idx, &min_non_reassigned, 1, stream); } - RAFT_LOG_DEBUG("KMeans minibatch: Reassigned %zu cluster centers", static_cast(num_to_reassign)); + RAFT_LOG_DEBUG("KMeans minibatch: Reassigned %zu cluster centers", + static_cast(num_to_reassign)); } } } @@ -626,7 +633,10 @@ void fit(raft::resources const& handle, raft::make_device_scalar_view(clusterCostD.data()), raft::value_op{}, raft::add_op{}); - batch_inertia = clusterCostD.value(stream) / static_cast(current_batch_size); + auto clusterCost_host = raft::make_host_scalar(0); + raft::copy(clusterCost_host.data_handle(), clusterCostD.data(), 1, stream); + raft::resource::sync_stream(handle, stream); + batch_inertia = clusterCost_host.data_handle()[0] / static_cast(current_batch_size); raft::matrix::fill(handle, centroid_sums.view(), T{0}); raft::matrix::fill(handle, cluster_counts.view(), T{0}); @@ -790,7 +800,10 @@ void fit(raft::resources const& handle, raft::make_device_scalar_view(clusterCostD.data()), raft::value_op{}, raft::add_op{}); - total_cost += clusterCostD.value(stream); + auto clusterCost_host = raft::make_host_scalar(0); + raft::copy(clusterCost_host.data_handle(), clusterCostD.data(), 1, stream); + raft::resource::sync_stream(handle, stream); + total_cost += clusterCost_host.data_handle()[0]; } } @@ -880,7 +893,10 @@ void fit(raft::resources const& handle, raft::value_op{}, raft::add_op{}); - inertia[0] += clusterCostD.value(stream); + auto clusterCost_host = raft::make_host_scalar(0); + raft::copy(clusterCost_host.data_handle(), clusterCostD.data(), 1, stream); + raft::resource::sync_stream(handle, stream); + inertia[0] += clusterCost_host.data_handle()[0]; } RAFT_LOG_DEBUG("KMeans batched: Completed with inertia=%f", static_cast(inertia[0])); } else { From 7c3965cc16629ddd5efd10e64a1840b9ebc2655d Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 25 Feb 2026 13:29:15 -0800 Subject: [PATCH 32/81] simplify minibatch update step --- cpp/src/cluster/detail/kmeans_batched.cuh | 90 +++++++++-------------- cpp/src/cluster/detail/kmeans_common.cuh | 1 + 2 files changed, 34 insertions(+), 57 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 077accfa1f..97f2c0189e 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -196,11 +196,13 @@ void accumulate_batch_centroids( /** * @brief Update centroids using mini-batch online learning * - * Uses the online update formula: - * learning_rate[k] = batch_count[k] / (total_count[k] + batch_count[k]) - * centroid[k] = centroid[k] + learning_rate[k] * (batch_mean[k] - centroid[k]) + * Updates centroids using the following formula (matching scikit-learn's implementation): * - * This is equivalent to a weighted average where total_count tracks cumulative weight. + * centroid_new[k] = (centroid_old[k] * old_total_counts[k] + batch_sums[k]) / total_counts[k] + * + * This is equivalent to the learning rate formula: + * learning_rate[k] = batch_counts[k] / total_counts[k] + * centroid[k] = centroid[k] * (1 - learning_rate[k]) + batch_means[k] * learning_rate[k] * * Optionally reassigns low-count clusters to random samples from the current batch. */ @@ -219,39 +221,22 @@ void minibatch_update_centroids(raft::resources const& handle, auto n_features = centroids.extent(1); cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto batch_means = raft::make_device_matrix(handle, n_clusters, n_features); - raft::linalg::matrix_vector_op( - handle, batch_sums, batch_counts, batch_means.view(), raft::div_checkzero_op{}); + raft::linalg::matrix_vector_op(handle, + raft::make_const_mdspan(centroids), + raft::make_const_mdspan(total_counts), + centroids, + raft::mul_op{}); + + raft::linalg::add( + handle, raft::make_const_mdspan(centroids), raft::make_const_mdspan(batch_sums), centroids); raft::linalg::add(handle, raft::make_const_mdspan(total_counts), batch_counts, total_counts); - // lr[k] = batch_count[k] / total_count[k] (after update) - auto learning_rates = raft::make_device_vector(handle, n_clusters); - raft::linalg::map(handle, - learning_rates.view(), - raft::div_checkzero_op{}, - batch_counts, - raft::make_const_mdspan(total_counts)); - - // centroid = (1 - lr) * centroid + lr * batch_mean - raft::linalg::matrix_vector_op( - handle, - raft::make_const_mdspan(centroids), - raft::make_const_mdspan(learning_rates.view()), - centroids, - [] __device__(MathT centroid_val, MathT lr) { return (MathT{1} - lr) * centroid_val; }); - - raft::linalg::matrix_vector_op( - handle, - raft::make_const_mdspan(batch_means.view()), - raft::make_const_mdspan(learning_rates.view()), - batch_means.view(), - [] __device__(MathT mean_val, MathT lr) { return lr * mean_val; }); - - raft::linalg::add(handle, - raft::make_const_mdspan(centroids), - raft::make_const_mdspan(batch_means.view()), - centroids); + raft::linalg::matrix_vector_op(handle, + raft::make_const_mdspan(centroids), + raft::make_const_mdspan(total_counts), + centroids, + raft::div_checkzero_op{}); // Reassignment logic: reassign low-count clusters to random samples from current batch if (reassignment_ratio > 0.0) { @@ -275,8 +260,6 @@ void minibatch_update_centroids(raft::resources const& handle, raft::resource::sync_stream(handle, stream); MathT max_count = max_count_host.data_handle()[0]; - // Identify clusters to reassign on device: total_count < reassignment_ratio * max_count or - // total_count == 0 MathT threshold = static_cast(reassignment_ratio) * max_count; auto reassign_flags = raft::make_device_vector(handle, n_clusters); @@ -289,7 +272,6 @@ void minibatch_update_centroids(raft::resources const& handle, }, stream); - // Count how many clusters need reassignment using RAFT mapThenSumReduce auto num_reassign_scalar = raft::make_device_scalar(handle, IdxT{0}); raft::linalg::mapThenSumReduce(num_reassign_scalar.data_handle(), n_clusters, @@ -317,7 +299,6 @@ void minibatch_update_centroids(raft::resources const& handle, thrust::device_pointer_cast(all_reassign_indices.data_handle()), [] __device__(uint8_t flag) { return flag == 1; }); - // Get counts for these clusters using RAFT matrix::gather auto reassign_counts = raft::make_device_vector(handle, num_to_reassign); auto total_counts_matrix_view = raft::make_device_matrix_view(total_counts.data_handle(), n_clusters, 1); @@ -333,7 +314,6 @@ void minibatch_update_centroids(raft::resources const& handle, reassign_counts.data_handle() + num_to_reassign, all_reassign_indices.data_handle()); - // Reset all flags raft::matrix::fill(handle, reassign_flags.view(), uint8_t{0}); // Set flags only for worst max_reassign clusters @@ -341,7 +321,6 @@ void minibatch_update_centroids(raft::resources const& handle, raft::copy( worst_indices.data_handle(), all_reassign_indices.data_handle(), max_reassign, stream); - // Scatter flags: set reassign_flags[worst_indices[i]] = 1 auto flags_scatter = raft::make_device_vector(handle, max_reassign); raft::matrix::fill(handle, flags_scatter.view(), uint8_t{1}); thrust::scatter(raft::resource::get_thrust_policy(handle), @@ -354,7 +333,7 @@ void minibatch_update_centroids(raft::resources const& handle, } if (num_to_reassign > 0) { - // Get list of cluster indices to reassign using thrust::copy_if + // Get list of cluster indices to reassign auto reassign_indices = raft::make_device_vector(handle, num_to_reassign); auto counting_iter = thrust::counting_iterator(0); thrust::device_ptr flags_ptr(reassign_flags.data_handle()); @@ -366,7 +345,6 @@ void minibatch_update_centroids(raft::resources const& handle, thrust::device_pointer_cast(reassign_indices.data_handle()), [] __device__(uint8_t flag) { return flag == 1; }); - // Verify actual count using RAFT (in case flags were modified) auto actual_count_scalar = raft::make_device_scalar(handle, IdxT{0}); raft::linalg::mapThenSumReduce(actual_count_scalar.data_handle(), n_clusters, @@ -397,7 +375,6 @@ void minibatch_update_centroids(raft::resources const& handle, std::vector new_center_indices(selected_indices.begin(), selected_indices.end()); - // Update centroids for reassigned clusters (device-to-device copy from batch_data) for (IdxT i = 0; i < num_to_reassign; ++i) { IdxT cluster_idx = reassign_indices_host.data_handle()[i]; IdxT sample_idx = new_center_indices[i]; @@ -407,8 +384,8 @@ void minibatch_update_centroids(raft::resources const& handle, stream); } - // Reset total_counts for reassigned clusters to min of non-reassigned clusters on device - // Find min of non-reassigned clusters on device + // Reset total_counts for reassigned clusters to min of non-reassigned clusters. Note that + // this will affect the learning rate directly. auto masked_counts = raft::make_device_vector(handle, n_clusters); auto total_counts_ptr = total_counts.data_handle(); auto reassign_flags_ptr = reassign_flags.data_handle(); @@ -470,6 +447,10 @@ void minibatch_update_centroids(raft::resources const& handle, * @param[inout] centroids Initial/output cluster centers [n_clusters x n_features] * @param[out] inertia Sum of squared distances to nearest centroid * @param[out] n_iter Number of iterations run + * + * @note For mini-batch mode: When sample weights are provided, they are used as sampling + * probabilities (normalized) to select minibatch samples. Unit weights are then passed + * to the centroid update to avoid double weighting (matching scikit-learn's approach). */ template void fit(raft::resources const& handle, @@ -533,6 +514,9 @@ void fit(raft::resources const& handle, : raft::make_host_vector(0); std::mt19937 rng(params.rng_state.seed); std::uniform_int_distribution uniform_dist(0, n_samples - 1); + // Weighted sampling: if sample weights are provided, use them as sampling probabilities. + // Since the sampling is weight-aware, we pass unit weights to the centroid update + // to avoid accounting for the weights twice (matching scikit-learn's approach). std::discrete_distribution weighted_dist; bool use_weighted_sampling = false; if (use_minibatch && sample_weight) { @@ -552,6 +536,8 @@ void fit(raft::resources const& handle, if (use_minibatch) { raft::matrix::fill(handle, total_counts.view(), T{0}); + // Fill once before the loop since batch_size is constant. + raft::matrix::fill(handle, batch_weights.view(), T{1}); n_steps = (params.max_iter * n_samples) / batch_size; } @@ -583,10 +569,6 @@ void fit(raft::resources const& handle, current_batch_size * n_features, stream); - auto batch_weights_view = - raft::make_device_vector_view(batch_weights.data_handle(), current_batch_size); - raft::matrix::fill(handle, batch_weights_view, T{1}); - auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); auto batch_weights_view_const = raft::make_device_vector_view( @@ -734,7 +716,6 @@ void fit(raft::resources const& handle, raft::matrix::fill(handle, centroid_sums.view(), T{0}); raft::matrix::fill(handle, cluster_counts.view(), T{0}); - // Use centroids from start of iteration for all batches (must remain constant) auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); @@ -948,29 +929,24 @@ void predict(raft::resources const& handle, current_batch_size * n_features, stream); - std::optional> batch_weight_view; if (sample_weight) { raft::copy(batch_weights.data_handle(), sample_weight->data_handle() + batch_idx, current_batch_size, stream); - batch_weight_view = raft::make_device_vector_view(batch_weights.data_handle(), - current_batch_size); } auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); - auto batch_labels_view = - raft::make_device_vector_view(batch_labels.data_handle(), current_batch_size); T batch_inertia = 0; cuvs::cluster::kmeans::detail::kmeans_predict( handle, params, batch_data_view, - batch_weight_view, + batch_weights.view(), centroids, - batch_labels_view, + batch_labels.view(), normalize_weight, raft::make_host_scalar_view(&batch_inertia)); diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 7219db4bfc..fb8a4b615e 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include From c4902081348af34ff0e3dbf9ab677fe454857560 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 25 Feb 2026 16:55:25 -0800 Subject: [PATCH 33/81] fix oom --- cpp/src/cluster/detail/kmeans_batched.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 97f2c0189e..1d1ce2244a 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -264,8 +264,8 @@ void minibatch_update_centroids(raft::resources const& handle, auto reassign_flags = raft::make_device_vector(handle, n_clusters); raft::linalg::unaryOp( - total_counts.data_handle(), reassign_flags.data_handle(), + total_counts.data_handle(), n_clusters, [=] __device__(MathT count) { return (count < threshold || count == MathT{0}) ? uint8_t{1} : uint8_t{0}; From b09611afd97f527b220f876b2fd35704ea7cc900 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 25 Feb 2026 22:37:54 -0800 Subject: [PATCH 34/81] update tests --- python/cuvs/cuvs/tests/test_kmeans.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/python/cuvs/cuvs/tests/test_kmeans.py b/python/cuvs/cuvs/tests/test_kmeans.py index 963105d8c6..0c577adc1f 100644 --- a/python/cuvs/cuvs/tests/test_kmeans.py +++ b/python/cuvs/cuvs/tests/test_kmeans.py @@ -128,24 +128,27 @@ def test_fit_batched_matches_fit( @pytest.mark.parametrize("n_rows", [1000]) @pytest.mark.parametrize("n_cols", [10]) -@pytest.mark.parametrize("n_clusters", [8]) -@pytest.mark.parametrize("dtype", [np.float32]) +@pytest.mark.parametrize("n_clusters", [8, 16, 32]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_minibatch_sklearn(n_rows, n_cols, n_clusters, dtype): """ Test that fit_batched matches sklearn's KMeans implementation. """ rng = np.random.default_rng(99) X_host = rng.random((n_rows, n_cols)).astype(dtype) + norms = np.linalg.norm(X_host, ord=1, axis=1, keepdims=True) + norms = np.where(norms == 0, 1.0, norms) + X_host = X_host / norms initial_centroids_host = X_host[:n_clusters].copy() # Sklearn fit kmeans = MiniBatchKMeans( - n_clusters=8, + n_clusters=n_clusters, init=initial_centroids_host, max_iter=100, verbose=0, random_state=None, - tol=0.0, + tol=1e-4, max_no_improvement=10, init_size=None, n_init="auto", @@ -176,7 +179,7 @@ def test_minibatch_sklearn(n_rows, n_cols, n_clusters, dtype): centroids_cuvs = centroids_cuvs.copy_to_host() assert np.allclose( - centroids_sklearn, centroids_cuvs, rtol=0.3, atol=0.3 + centroids_sklearn, centroids_cuvs, rtol=0.1, atol=0.1 ), f"max diff: {np.max(np.abs(centroids_sklearn - centroids_cuvs))}" inertia_diff = abs(inertia_sklearn - inertia_cuvs) From 350ee82b874cbf654233c567c3b2aa5a1676d6ea Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 2 Mar 2026 14:54:46 -0800 Subject: [PATCH 35/81] update n_init use --- cpp/include/cuvs/cluster/kmeans.hpp | 4 +- cpp/src/cluster/detail/kmeans_batched.cuh | 1065 +++++++++++++-------- 2 files changed, 666 insertions(+), 403 deletions(-) diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 59178d994b..5f74ce92f4 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -103,7 +103,9 @@ struct params : base_params { raft::random::RngState rng_state{0}; /** - * Number of instance k-means algorithm will be run with different seeds. + * Number of instance k-means algorithm will be run with different seeds. For MiniBatch mode, + * this is the number of different initializations to try, but the algorithm is only run once with + * the best initialization. */ int n_init = 1; diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 1d1ce2244a..9438d59be9 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -41,6 +41,7 @@ #include #include +#include #include #include #include @@ -49,16 +50,22 @@ namespace cuvs::cluster::kmeans::detail { /** - * @brief Sample data from host to device for initialization, with optional type conversion + * @brief Sample data from host to device for initialization/validation. + * + * When sample weights are provided, the corresponding weights are gathered + * alongside the sampled rows and copied to device * * @tparam T Input data type * @tparam IdxT Index type */ template -void prepare_init_sample(raft::resources const& handle, - raft::host_matrix_view X, - raft::device_matrix_view X_sample, - uint64_t seed) +void prepare_init_sample( + raft::resources const& handle, + raft::host_matrix_view X, + raft::device_matrix_view X_sample, + uint64_t seed, + std::optional> weight_in = std::nullopt, + std::optional> weight_out = std::nullopt) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); @@ -74,12 +81,15 @@ void prepare_init_sample(raft::resources const& handle, selected_indices.reserve(n_samples_out); while (static_cast(selected_indices.size()) < n_samples_out) { - IdxT idx = dist(gen); - selected_indices.insert(idx); + selected_indices.insert(dist(gen)); } std::vector indices(selected_indices.begin(), selected_indices.end()); + bool copy_weights = weight_in.has_value() && weight_out.has_value(); + std::vector host_weights; + if (copy_weights) { host_weights.resize(n_samples_out); } + std::vector host_sample(n_samples_out * n_features); #pragma omp parallel for for (IdxT i = 0; i < static_cast(n_samples_out); i++) { @@ -87,9 +97,86 @@ void prepare_init_sample(raft::resources const& handle, std::memcpy(host_sample.data() + i * n_features, X.data_handle() + src_idx * n_features, n_features * sizeof(T)); + if (copy_weights) { host_weights[i] = weight_in->data_handle()[src_idx]; } } raft::copy(X_sample.data_handle(), host_sample.data(), host_sample.size(), stream); + if (copy_weights) { + raft::copy(weight_out->data_handle(), host_weights.data(), n_samples_out, stream); + } +} + +/** + * @brief Compute (optionally weighted) inertia. + * + * Used for scoring the validation set during n_init selection, where the data is + * small enough that no host→device batching is needed. + */ +template +T compute_inertia( + raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + rmm::device_uvector& workspace, + std::optional> sample_weight = std::nullopt) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto metric = params.metric; + + auto minClusterAndDistance = + raft::make_device_vector, IdxT>(handle, n_samples); + auto L2NormX = raft::make_device_vector(handle, n_samples); + rmm::device_uvector L2NormBuf(0, stream); + rmm::device_scalar cost(stream); + + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm( + L2NormX.data_handle(), X.data_handle(), n_features, n_samples, stream); + } + + auto mcd_view = minClusterAndDistance.view(); + + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + X, + centroids, + mcd_view, + raft::make_device_vector_view(L2NormX.data_handle(), n_samples), + L2NormBuf, + metric, + params.batch_samples, + params.batch_centroids, + workspace); + + if (sample_weight) { + raft::linalg::map( + handle, + mcd_view, + [=] __device__(const raft::KeyValuePair kvp, T wt) { + raft::KeyValuePair res; + res.value = kvp.value * wt; + res.key = kvp.key; + return res; + }, + raft::make_const_mdspan(mcd_view), + *sample_weight); + } + + cuvs::cluster::kmeans::detail::computeClusterCost(handle, + mcd_view, + workspace, + raft::make_device_scalar_view(cost.data()), + raft::value_op{}, + raft::add_op{}); + + T result = 0; + raft::copy(&result, cost.data(), 1, stream); + raft::resource::sync_stream(handle, stream); + return result; } /** @@ -110,9 +197,9 @@ void init_centroids_from_host_sample(raft::resources const& handle, auto n_features = X.extent(1); auto n_clusters = params.n_clusters; - size_t init_sample_size = - std::min(static_cast(n_samples), - std::max(static_cast(3 * n_clusters), static_cast(10000))); + size_t init_sample_size = 3 * params.batch_size; + if (init_sample_size < n_clusters) { init_sample_size = 3 * n_clusters; } + init_sample_size = std::min(init_sample_size, n_samples); RAFT_LOG_DEBUG("KMeans batched: sampling %zu points for initialization", init_sample_size); auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); @@ -433,6 +520,109 @@ void minibatch_update_centroids(raft::resources const& handle, } } +/** + * @brief Compute total inertia over host data using batched GPU processing. + * + * Iterates over the host data in batches, computing the (optionally weighted) + * sum of squared distances from each sample to its nearest centroid. + * + * @param[in] sample_weight Optional per-sample weights on host. When provided, + * each squared distance is multiplied by its weight + * before summing (matching sklearn's weighted inertia). + */ +template +T compute_batched_host_inertia( + raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + IdxT batch_size, + raft::device_matrix_view centroids, + rmm::device_uvector& workspace, + std::optional> sample_weight = std::nullopt) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto metric = params.metric; + + IdxT effective_batch = std::min(batch_size, static_cast(n_samples)); + auto minClusterAndDistance = + raft::make_device_vector, IdxT>(handle, effective_batch); + auto L2NormBatch = raft::make_device_vector(handle, effective_batch); + rmm::device_uvector L2NormBuf(0, stream); + rmm::device_scalar cost(stream); + + // Device buffer for per-batch weights (only used when sample_weight is provided) + auto batch_weights = + raft::make_device_vector(handle, sample_weight ? effective_batch : IdxT{0}); + + T total_inertia = 0; + using namespace cuvs::spatial::knn::detail::utils; + batch_load_iterator data_batches(X.data_handle(), n_samples, n_features, batch_size, stream); + + for (const auto& data_batch : data_batches) { + IdxT current_batch_size = static_cast(data_batch.size()); + auto batch_view = raft::make_device_matrix_view( + data_batch.data(), current_batch_size, n_features); + auto mcd_view = raft::make_device_vector_view, IdxT>( + minClusterAndDistance.data_handle(), current_batch_size); + + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm( + L2NormBatch.data_handle(), data_batch.data(), n_features, current_batch_size, stream); + } + + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + batch_view, + centroids, + mcd_view, + raft::make_device_vector_view(L2NormBatch.data_handle(), current_batch_size), + L2NormBuf, + metric, + params.batch_samples, + params.batch_centroids, + workspace); + + // Apply sample weights to distances before summing (matching sklearn weighted inertia) + if (sample_weight) { + auto weight_offset = static_cast(data_batch.offset()); + raft::copy(batch_weights.data_handle(), + sample_weight->data_handle() + weight_offset, + current_batch_size, + stream); + + raft::linalg::map( + handle, + mcd_view, + [=] __device__(const raft::KeyValuePair kvp, T wt) { + raft::KeyValuePair res; + res.value = kvp.value * wt; + res.key = kvp.key; + return res; + }, + raft::make_const_mdspan(mcd_view), + raft::make_device_vector_view(batch_weights.data_handle(), + current_batch_size)); + } + + cuvs::cluster::kmeans::detail::computeClusterCost(handle, + mcd_view, + workspace, + raft::make_device_scalar_view(cost.data()), + raft::value_op{}, + raft::add_op{}); + + T batch_cost = 0; + raft::copy(&batch_cost, cost.data(), 1, stream); + raft::resource::sync_stream(handle, stream); + total_inertia += batch_cost; + } + + return total_inertia; +} + /** * @brief Main fit function for batched k-means with host data * @@ -484,14 +674,27 @@ void fit(raft::resources const& handle, rmm::device_uvector workspace(0, stream); - if (params.init != cuvs::cluster::kmeans::params::InitMethod::Array) { - init_centroids_from_host_sample(handle, params, X, centroids, workspace); - } - bool use_minibatch = (params.batched.update_mode == cuvs::cluster::kmeans::params::CentroidUpdateMode::MiniBatch); RAFT_LOG_DEBUG("KMeans batched: update_mode=%s", use_minibatch ? "MiniBatch" : "FullBatch"); + auto n_init = params.n_init; + if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array && n_init != 1) { + RAFT_LOG_DEBUG( + "Explicit initial center position passed: performing only one init in " + "k-means instead of n_init=%d", + n_init); + n_init = 1; + } + + auto best_centroids = raft::make_device_matrix(handle, n_clusters, n_features); + T best_inertia = std::numeric_limits::max(); + IdxT best_n_iter = 0; + + std::mt19937 gen(params.rng_state.seed); + bool compute_final_inertia = (n_init > 1) || params.batched.final_inertia_check; + + // ----- Allocate reusable work buffers (shared across n_init iterations) ----- auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); auto batch_weights = raft::make_device_vector(handle, batch_size); auto minClusterAndDistance = @@ -504,19 +707,16 @@ void fit(raft::resources const& handle, auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); rmm::device_scalar clusterCostD(stream); - T priorClusteringCost = 0; - // Mini-batch only state + // Mini-batch only buffers auto total_counts = raft::make_device_vector(handle, use_minibatch ? n_clusters : 0); auto host_batch_buffer = use_minibatch ? raft::make_host_matrix(batch_size, n_features) : raft::make_host_matrix(0, n_features); auto batch_indices = use_minibatch ? raft::make_host_vector(batch_size) : raft::make_host_vector(0); - std::mt19937 rng(params.rng_state.seed); - std::uniform_int_distribution uniform_dist(0, n_samples - 1); - // Weighted sampling: if sample weights are provided, use them as sampling probabilities. - // Since the sampling is weight-aware, we pass unit weights to the centroid update - // to avoid accounting for the weights twice (matching scikit-learn's approach). + auto prev_centroids = raft::make_device_matrix(handle, n_clusters, n_features); + + // Weighted sampling (shared across n_init, weights are constant) std::discrete_distribution weighted_dist; bool use_weighted_sampling = false; if (use_minibatch && sample_weight) { @@ -525,230 +725,152 @@ void fit(raft::resources const& handle, weighted_dist = std::discrete_distribution(weights.begin(), weights.end()); use_weighted_sampling = true; } - IdxT n_steps = params.max_iter; - - // Mini-batch convergence tracking - T ewa_inertia = T{0}; - T ewa_inertia_min = T{0}; - int no_improvement = 0; - bool ewa_initialized = false; - auto prev_centroids = raft::make_device_matrix(handle, n_clusters, n_features); - - if (use_minibatch) { - raft::matrix::fill(handle, total_counts.view(), T{0}); - // Fill once before the loop since batch_size is constant. - raft::matrix::fill(handle, batch_weights.view(), T{1}); - n_steps = (params.max_iter * n_samples) / batch_size; - } - for (n_iter[0] = 1; n_iter[0] <= n_steps; ++n_iter[0]) { - RAFT_LOG_DEBUG("KMeans batched: Iteration %d", n_iter[0]); + // n_init only selects the best *initialization using a validation set. The full training loop + // runs once with the best init. + bool minibatch_init_done = false; + if (use_minibatch && n_init > 1) { + size_t valid_size = + std::min(static_cast(n_samples), + std::max(static_cast(3 * n_clusters), static_cast(10000))); + RAFT_LOG_DEBUG("KMeans minibatch: creating validation set of %zu samples for n_init selection", + valid_size); - raft::copy(new_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); + auto X_valid = raft::make_device_matrix(handle, valid_size, n_features); + auto valid_weights = raft::make_device_vector(handle, sample_weight ? valid_size : 0); + std::optional> valid_weight_view; - T total_cost = 0; + if (sample_weight) { + prepare_init_sample(handle, X, X_valid.view(), gen(), *sample_weight, valid_weights.view()); + valid_weight_view = raft::make_device_vector_view( + valid_weights.data_handle(), static_cast(valid_size)); + } else { + prepare_init_sample(handle, X, X_valid.view(), gen()); + } - if (use_minibatch) { - IdxT current_batch_size = batch_size; + auto X_valid_const = raft::make_const_mdspan(X_valid.view()); - for (IdxT i = 0; i < current_batch_size; ++i) { - batch_indices.data_handle()[i] = - use_weighted_sampling ? weighted_dist(rng) : uniform_dist(rng); - } + for (int seed_iter = 0; seed_iter < n_init; ++seed_iter) { + cuvs::cluster::kmeans::params iter_params = params; + iter_params.rng_state.seed = gen(); -#pragma omp parallel for - for (IdxT i = 0; i < current_batch_size; ++i) { - IdxT sample_idx = batch_indices.data_handle()[i]; - std::memcpy(host_batch_buffer.data_handle() + i * n_features, - X.data_handle() + sample_idx * n_features, - n_features * sizeof(T)); - } + RAFT_LOG_DEBUG("KMeans minibatch: n_init %d/%d init selection (seed=%llu)", + seed_iter + 1, + n_init, + (unsigned long long)iter_params.rng_state.seed); - raft::copy(batch_data.data_handle(), - host_batch_buffer.data_handle(), - current_batch_size * n_features, - stream); + init_centroids_from_host_sample(handle, iter_params, X, centroids, workspace); - auto batch_data_view = raft::make_device_matrix_view( - batch_data.data_handle(), current_batch_size, n_features); - auto batch_weights_view_const = raft::make_device_vector_view( - batch_weights.data_handle(), current_batch_size); - auto minClusterAndDistance_view = - raft::make_device_vector_view, IdxT>( - minClusterAndDistance.data_handle(), current_batch_size); - - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormBatch.data_handle(), - batch_data.data_handle(), - n_features, - current_batch_size, - stream); + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features); + T valid_inertia = compute_inertia( + handle, iter_params, X_valid_const, centroids_const, workspace, valid_weight_view); + + RAFT_LOG_DEBUG("KMeans minibatch: n_init %d/%d validation inertia=%f", + seed_iter + 1, + n_init, + static_cast(valid_inertia)); + + if (valid_inertia < best_inertia) { + best_inertia = valid_inertia; + raft::copy(best_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); } + } - // Save centroids before update for convergence check - raft::copy(prev_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); + raft::copy(centroids.data_handle(), best_centroids.data_handle(), centroids.size(), stream); - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), n_clusters, n_features); - auto L2NormBatch_const = - raft::make_device_vector_view(L2NormBatch.data_handle(), current_batch_size); + best_inertia = std::numeric_limits::max(); + n_init = 1; + force_inertia = false; - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( - handle, - batch_data_view, - centroids_const, - minClusterAndDistance_view, - L2NormBatch_const, - L2NormBuf_OR_DistBuf, - metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // Compute batch inertia (normalized by batch_size for comparison) - T batch_inertia = 0; - cuvs::cluster::kmeans::detail::computeClusterCost( - handle, - minClusterAndDistance_view, - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - auto clusterCost_host = raft::make_host_scalar(0); - raft::copy(clusterCost_host.data_handle(), clusterCostD.data(), 1, stream); - raft::resource::sync_stream(handle, stream); - batch_inertia = clusterCost_host.data_handle()[0] / static_cast(current_batch_size); - - raft::matrix::fill(handle, centroid_sums.view(), T{0}); - raft::matrix::fill(handle, cluster_counts.view(), T{0}); - - auto minClusterAndDistance_const = raft::make_const_mdspan(minClusterAndDistance_view); - - accumulate_batch_centroids(handle, - batch_data_view, - minClusterAndDistance_const, - batch_weights_view_const, - centroid_sums.view(), - cluster_counts.view()); - - auto centroid_sums_const = raft::make_device_matrix_view( - centroid_sums.data_handle(), n_clusters, n_features); - auto cluster_counts_const = - raft::make_device_vector_view(cluster_counts.data_handle(), n_clusters); - - minibatch_update_centroids(handle, - centroids, - centroid_sums_const, - cluster_counts_const, - total_counts.view(), - batch_data_view, - params.batched.minibatch.reassignment_ratio, - current_batch_size, - rng); - - // Compute squared difference of centers (for convergence check) - auto sqrdNorm = raft::make_device_scalar(handle, T{0}); - raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), - centroids.size(), - raft::sqdiff_op{}, - stream, - prev_centroids.data_handle(), - centroids.data_handle()); - T centers_squared_diff = 0; - raft::copy(¢ers_squared_diff, sqrdNorm.data_handle(), 1, stream); - raft::resource::sync_stream(handle, stream); + minibatch_init_done = true; + RAFT_LOG_DEBUG("KMeans minibatch: best initialization selected, proceeding with training"); + } - // Skip first step (inertia from initialization) - if (n_iter[0] > 1) { - // Update Exponentially Weighted Average of inertia - T alpha = static_cast(current_batch_size * 2.0) / static_cast(n_samples + 1); - alpha = std::min(alpha, T{1}); + // ---- Main n_init loop ---- + // For full-batch: runs full training n_init times, keeps best result. + // For minibatch: runs once (n_init was set to 1 above after init selection). + for (int seed_iter = 0; seed_iter < n_init; ++seed_iter) { + cuvs::cluster::kmeans::params iter_params = params; + iter_params.rng_state.seed = gen(); + + RAFT_LOG_DEBUG("KMeans batched fit: n_init iteration %d/%d (seed=%llu)", + seed_iter + 1, + n_init, + (unsigned long long)iter_params.rng_state.seed); + + if (!minibatch_init_done && + iter_params.init != cuvs::cluster::kmeans::params::InitMethod::Array) { + init_centroids_from_host_sample(handle, iter_params, X, centroids, workspace); + } - if (!ewa_initialized) { - ewa_inertia = batch_inertia; - ewa_inertia_min = batch_inertia; - ewa_initialized = true; - } else { - ewa_inertia = ewa_inertia * (T{1} - alpha) + batch_inertia * alpha; - } + // Reset per-iteration state + T priorClusteringCost = 0; + IdxT n_steps = iter_params.max_iter; - RAFT_LOG_DEBUG( - "KMeans minibatch step %d/%d: batch_inertia=%f, ewa_inertia=%f, centers_squared_diff=%f", - n_iter[0], - n_steps, - static_cast(batch_inertia), - static_cast(ewa_inertia), - static_cast(centers_squared_diff)); - - // Early stopping: absolute tolerance on squared change of centers - // Disabled if tol == 0.0 - if (params.tol > 0.0 && centers_squared_diff <= params.tol) { - RAFT_LOG_DEBUG( - "KMeans minibatch: Converged (small centers change) at step %d/%d", n_iter[0], n_steps); - break; - } + std::mt19937 rng(iter_params.rng_state.seed); + std::uniform_int_distribution uniform_dist(0, n_samples - 1); + T ewa_inertia = T{0}; + T ewa_inertia_min = T{0}; + int no_improvement = 0; + bool ewa_initialized = false; - // Early stopping: lack of improvement in smoothed inertia - // Disabled if max_no_improvement == 0 - if (params.batched.minibatch.max_no_improvement > 0) { - if (ewa_inertia < ewa_inertia_min) { - no_improvement = 0; - ewa_inertia_min = ewa_inertia; - } else { - no_improvement++; - } + if (use_minibatch) { + raft::matrix::fill(handle, total_counts.view(), T{0}); + raft::matrix::fill(handle, batch_weights.view(), T{1}); + n_steps = (iter_params.max_iter * n_samples) / batch_size; + } - if (no_improvement >= params.batched.minibatch.max_no_improvement) { - RAFT_LOG_DEBUG("KMeans minibatch: Converged (lack of improvement) at step %d/%d", - n_iter[0], - n_steps); - break; - } - } - } else { - RAFT_LOG_DEBUG("KMeans minibatch step %d/%d: mean batch inertia: %f", - n_iter[0], - n_steps, - static_cast(batch_inertia)); - } - } else { - raft::matrix::fill(handle, centroid_sums.view(), T{0}); - raft::matrix::fill(handle, cluster_counts.view(), T{0}); + for (n_iter[0] = 1; n_iter[0] <= n_steps; ++n_iter[0]) { + RAFT_LOG_DEBUG("KMeans batched: Iteration %d", n_iter[0]); - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), n_clusters, n_features); + raft::copy(new_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); - using namespace cuvs::spatial::knn::detail::utils; - batch_load_iterator data_batches( - X.data_handle(), n_samples, n_features, batch_size, stream); + T total_cost = 0; - for (const auto& data_batch : data_batches) { - IdxT current_batch_size = static_cast(data_batch.size()); + if (use_minibatch) { + IdxT current_batch_size = batch_size; - auto batch_data_view = raft::make_device_matrix_view( - data_batch.data(), current_batch_size, n_features); - - auto batch_weights_fill_view = - raft::make_device_vector_view(batch_weights.data_handle(), current_batch_size); - if (sample_weight) { - raft::copy(batch_weights.data_handle(), - sample_weight->data_handle() + data_batch.offset(), - current_batch_size, - stream); - } else { - raft::matrix::fill(handle, batch_weights_fill_view, T{1}); + for (IdxT i = 0; i < current_batch_size; ++i) { + batch_indices.data_handle()[i] = + use_weighted_sampling ? weighted_dist(rng) : uniform_dist(rng); + } + +#pragma omp parallel for + for (IdxT i = 0; i < current_batch_size; ++i) { + IdxT sample_idx = batch_indices.data_handle()[i]; + std::memcpy(host_batch_buffer.data_handle() + i * n_features, + X.data_handle() + sample_idx * n_features, + n_features * sizeof(T)); } - auto batch_weights_view = raft::make_device_vector_view( + raft::copy(batch_data.data_handle(), + host_batch_buffer.data_handle(), + current_batch_size * n_features, + stream); + + auto batch_data_view = raft::make_device_matrix_view( + batch_data.data_handle(), current_batch_size, n_features); + auto batch_weights_view_const = raft::make_device_vector_view( batch_weights.data_handle(), current_batch_size); + auto minClusterAndDistance_view = + raft::make_device_vector_view, IdxT>( + minClusterAndDistance.data_handle(), current_batch_size); if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormBatch.data_handle(), data_batch.data(), n_features, current_batch_size, stream); + raft::linalg::rowNorm(L2NormBatch.data_handle(), + batch_data.data_handle(), + n_features, + current_batch_size, + stream); } + // Save centroids before update for convergence check + raft::copy(prev_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); + + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features); auto L2NormBatch_const = raft::make_device_vector_view( L2NormBatch.data_handle(), current_batch_size); @@ -756,7 +878,7 @@ void fit(raft::resources const& handle, handle, batch_data_view, centroids_const, - minClusterAndDistance.view(), + minClusterAndDistance_view, L2NormBatch_const, L2NormBuf_OR_DistBuf, metric, @@ -764,230 +886,369 @@ void fit(raft::resources const& handle, params.batch_centroids, workspace); - auto minClusterAndDistance_const = raft::make_const_mdspan(minClusterAndDistance.view()); + // Compute batch inertia (normalized by batch_size for comparison) + T batch_inertia = 0; + cuvs::cluster::kmeans::detail::computeClusterCost( + handle, + minClusterAndDistance_view, + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + raft::value_op{}, + raft::add_op{}); + auto clusterCost_host = raft::make_host_scalar(0); + raft::copy(clusterCost_host.data_handle(), clusterCostD.data(), 1, stream); + raft::resource::sync_stream(handle, stream); + batch_inertia = clusterCost_host.data_handle()[0] / static_cast(current_batch_size); + + raft::matrix::fill(handle, centroid_sums.view(), T{0}); + raft::matrix::fill(handle, cluster_counts.view(), T{0}); + + auto minClusterAndDistance_const = raft::make_const_mdspan(minClusterAndDistance_view); accumulate_batch_centroids(handle, batch_data_view, minClusterAndDistance_const, - batch_weights_view, + batch_weights_view_const, centroid_sums.view(), cluster_counts.view()); - if (params.inertia_check) { - cuvs::cluster::kmeans::detail::computeClusterCost( - handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - auto clusterCost_host = raft::make_host_scalar(0); - raft::copy(clusterCost_host.data_handle(), clusterCostD.data(), 1, stream); - raft::resource::sync_stream(handle, stream); - total_cost += clusterCost_host.data_handle()[0]; + auto centroid_sums_const = raft::make_device_matrix_view( + centroid_sums.data_handle(), n_clusters, n_features); + auto cluster_counts_const = + raft::make_device_vector_view(cluster_counts.data_handle(), n_clusters); + + minibatch_update_centroids(handle, + centroids, + centroid_sums_const, + cluster_counts_const, + total_counts.view(), + batch_data_view, + params.batched.minibatch.reassignment_ratio, + current_batch_size, + rng); + + // Compute squared difference of centers (for convergence check) + auto sqrdNorm = raft::make_device_scalar(handle, T{0}); + raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), + centroids.size(), + raft::sqdiff_op{}, + stream, + prev_centroids.data_handle(), + centroids.data_handle()); + T centers_squared_diff = 0; + raft::copy(¢ers_squared_diff, sqrdNorm.data_handle(), 1, stream); + raft::resource::sync_stream(handle, stream); + + // Skip first step (inertia from initialization) + if (n_iter[0] > 1) { + // Update Exponentially Weighted Average of inertia + T alpha = static_cast(current_batch_size * 2.0) / static_cast(n_samples + 1); + alpha = std::min(alpha, T{1}); + + if (!ewa_initialized) { + ewa_inertia = batch_inertia; + ewa_inertia_min = batch_inertia; + ewa_initialized = true; + } else { + ewa_inertia = ewa_inertia * (T{1} - alpha) + batch_inertia * alpha; + } + + RAFT_LOG_DEBUG( + "KMeans minibatch step %d/%d: batch_inertia=%f, ewa_inertia=%f, " + "centers_squared_diff=%f", + n_iter[0], + n_steps, + static_cast(batch_inertia), + static_cast(ewa_inertia), + static_cast(centers_squared_diff)); + + // Early stopping: absolute tolerance on squared change of centers + // Disabled if tol == 0.0 + if (params.tol > 0.0 && centers_squared_diff <= params.tol) { + RAFT_LOG_DEBUG("KMeans minibatch: Converged (small centers change) at step %d/%d", + n_iter[0], + n_steps); + break; + } + + // Early stopping: lack of improvement in smoothed inertia + // Disabled if max_no_improvement == 0 + if (params.batched.minibatch.max_no_improvement > 0) { + if (ewa_inertia < ewa_inertia_min) { + no_improvement = 0; + ewa_inertia_min = ewa_inertia; + } else { + no_improvement++; + } + + if (no_improvement >= params.batched.minibatch.max_no_improvement) { + RAFT_LOG_DEBUG("KMeans minibatch: Converged (lack of improvement) at step %d/%d", + n_iter[0], + n_steps); + break; + } + } + } else { + RAFT_LOG_DEBUG("KMeans minibatch step %d/%d: mean batch inertia: %f", + n_iter[0], + n_steps, + static_cast(batch_inertia)); } - } + } else { + raft::matrix::fill(handle, centroid_sums.view(), T{0}); + raft::matrix::fill(handle, cluster_counts.view(), T{0}); - auto centroid_sums_const = raft::make_device_matrix_view( - centroid_sums.data_handle(), n_clusters, n_features); - auto cluster_counts_const = - raft::make_device_vector_view(cluster_counts.data_handle(), n_clusters); + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features); - finalize_centroids( - handle, centroid_sums_const, cluster_counts_const, centroids_const, new_centroids.view()); - } + using namespace cuvs::spatial::knn::detail::utils; + batch_load_iterator data_batches( + X.data_handle(), n_samples, n_features, batch_size, stream); - // Convergence check for full-batch mode only - if (!use_minibatch) { - auto sqrdNorm = raft::make_device_scalar(handle, T{0}); - raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), - centroids.size(), - raft::sqdiff_op{}, - stream, - new_centroids.data_handle(), - centroids.data_handle()); + for (const auto& data_batch : data_batches) { + IdxT current_batch_size = static_cast(data_batch.size()); + + auto batch_data_view = raft::make_device_matrix_view( + data_batch.data(), current_batch_size, n_features); - raft::copy(centroids.data_handle(), new_centroids.data_handle(), centroids.size(), stream); + auto batch_weights_fill_view = + raft::make_device_vector_view(batch_weights.data_handle(), current_batch_size); + if (sample_weight) { + raft::copy(batch_weights.data_handle(), + sample_weight->data_handle() + data_batch.offset(), + current_batch_size, + stream); + } else { + raft::matrix::fill(handle, batch_weights_fill_view, T{1}); + } - T sqrdNormError = 0; - raft::copy(&sqrdNormError, sqrdNorm.data_handle(), 1, stream); + auto batch_weights_view = raft::make_device_vector_view( + batch_weights.data_handle(), current_batch_size); - bool done = false; - if (params.inertia_check && n_iter[0] > 1) { - T delta = total_cost / priorClusteringCost; - if (delta > 1 - params.tol) done = true; - priorClusteringCost = total_cost; - } + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm( + L2NormBatch.data_handle(), data_batch.data(), n_features, current_batch_size, stream); + } - raft::resource::sync_stream(handle, stream); - if (sqrdNormError < params.tol) done = true; + auto L2NormBatch_const = raft::make_device_vector_view( + L2NormBatch.data_handle(), current_batch_size); - if (done) { - RAFT_LOG_DEBUG("KMeans batched: Converged after %d iterations", n_iter[0]); - break; + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + batch_data_view, + centroids_const, + minClusterAndDistance.view(), + L2NormBatch_const, + L2NormBuf_OR_DistBuf, + metric, + params.batch_samples, + params.batch_centroids, + workspace); + + auto minClusterAndDistance_const = raft::make_const_mdspan(minClusterAndDistance.view()); + + accumulate_batch_centroids(handle, + batch_data_view, + minClusterAndDistance_const, + batch_weights_view, + centroid_sums.view(), + cluster_counts.view()); + + if (params.inertia_check) { + cuvs::cluster::kmeans::detail::computeClusterCost( + handle, + minClusterAndDistance.view(), + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + raft::value_op{}, + raft::add_op{}); + auto clusterCost_host = raft::make_host_scalar(0); + raft::copy(clusterCost_host.data_handle(), clusterCostD.data(), 1, stream); + raft::resource::sync_stream(handle, stream); + total_cost += clusterCost_host.data_handle()[0]; + } + } + + auto centroid_sums_const = raft::make_device_matrix_view( + centroid_sums.data_handle(), n_clusters, n_features); + auto cluster_counts_const = + raft::make_device_vector_view(cluster_counts.data_handle(), n_clusters); + + finalize_centroids( + handle, centroid_sums_const, cluster_counts_const, centroids_const, new_centroids.view()); } - } - } - if (params.batched.final_inertia_check) { - inertia[0] = 0; - using namespace cuvs::spatial::knn::detail::utils; - batch_load_iterator data_batches(X.data_handle(), n_samples, n_features, batch_size, stream); + // Convergence check for full-batch mode only + if (!use_minibatch) { + auto sqrdNorm = raft::make_device_scalar(handle, T{0}); + raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), + centroids.size(), + raft::sqdiff_op{}, + stream, + new_centroids.data_handle(), + centroids.data_handle()); + + raft::copy(centroids.data_handle(), new_centroids.data_handle(), centroids.size(), stream); + + T sqrdNormError = 0; + raft::copy(&sqrdNormError, sqrdNorm.data_handle(), 1, stream); + + bool done = false; + if (params.inertia_check && n_iter[0] > 1) { + T delta = total_cost / priorClusteringCost; + if (delta > 1 - params.tol) done = true; + priorClusteringCost = total_cost; + } - for (const auto& data_batch : data_batches) { - IdxT current_batch_size = static_cast(data_batch.size()); + raft::resource::sync_stream(handle, stream); + if (sqrdNormError < params.tol) done = true; - auto batch_data_view = raft::make_device_matrix_view( - data_batch.data(), current_batch_size, n_features); - auto minClusterAndDistance_view = - raft::make_device_vector_view, IdxT>( - minClusterAndDistance.data_handle(), current_batch_size); - - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormBatch.data_handle(), data_batch.data(), n_features, current_batch_size, stream); + if (done) { + RAFT_LOG_DEBUG("KMeans batched: Converged after %d iterations", n_iter[0]); + break; + } } + } + if (compute_final_inertia) { auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); - auto L2NormBatch_const = - raft::make_device_vector_view(L2NormBatch.data_handle(), current_batch_size); + inertia[0] = compute_batched_host_inertia( + handle, iter_params, X, batch_size, centroids_const, workspace, sample_weight); + + RAFT_LOG_DEBUG("KMeans batched: n_init %d/%d completed with inertia=%f", + seed_iter + 1, + n_init, + static_cast(inertia[0])); + + if (inertia[0] < best_inertia) { + best_inertia = inertia[0]; + best_n_iter = n_iter[0]; + raft::copy(best_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); + raft::resource::sync_stream(handle, stream); + } + } else { + inertia[0] = 0; + RAFT_LOG_DEBUG("KMeans batched: n_init %d/%d completed (inertia computation skipped)", + seed_iter + 1, + n_init); + } - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( - handle, - batch_data_view, - centroids_const, - minClusterAndDistance_view, - L2NormBatch_const, - L2NormBuf_OR_DistBuf, - metric, - params.batch_samples, - params.batch_centroids, - workspace); - - cuvs::cluster::kmeans::detail::computeClusterCost( - handle, - minClusterAndDistance_view, - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - - auto clusterCost_host = raft::make_host_scalar(0); - raft::copy(clusterCost_host.data_handle(), clusterCostD.data(), 1, stream); + if (n_init > 1) { + raft::copy(centroids.data_handle(), best_centroids.data_handle(), centroids.size(), stream); raft::resource::sync_stream(handle, stream); - inertia[0] += clusterCost_host.data_handle()[0]; + inertia[0] = best_inertia; + n_iter[0] = best_n_iter; + RAFT_LOG_DEBUG("KMeans batched: Best of %d runs: inertia=%f, n_iter=%d", + n_init, + static_cast(best_inertia), + best_n_iter); } - RAFT_LOG_DEBUG("KMeans batched: Completed with inertia=%f", static_cast(inertia[0])); - } else { - inertia[0] = 0; - RAFT_LOG_DEBUG("KMeans batched: Completed (inertia computation skipped)"); } -} - -/** - * @brief Predict cluster labels for host data using batched processing. - * - * @tparam T Input data type (float, double) - * @tparam IdxT Index type (int, int64_t) - */ -template -void predict(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - IdxT batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - RAFT_EXPECTS(batch_size > 0, "batch_size must be positive"); - RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); - RAFT_EXPECTS(centroids.extent(0) == static_cast(n_clusters), - "centroids.extent(0) must equal n_clusters"); - RAFT_EXPECTS(centroids.extent(1) == n_features, "centroids.extent(1) must equal n_features"); - RAFT_EXPECTS(labels.extent(0) == n_samples, "labels.extent(0) must equal n_samples"); + /** + * @brief Predict cluster labels for host data using batched processing. + * + * @tparam T Input data type (float, double) + * @tparam IdxT Index type (int, int64_t) + */ + template + void predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + IdxT batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia) + { + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + + RAFT_EXPECTS(batch_size > 0, "batch_size must be positive"); + RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); + RAFT_EXPECTS(centroids.extent(0) == static_cast(n_clusters), + "centroids.extent(0) must equal n_clusters"); + RAFT_EXPECTS(centroids.extent(1) == n_features, "centroids.extent(1) must equal n_features"); + RAFT_EXPECTS(labels.extent(0) == n_samples, "labels.extent(0) must equal n_samples"); + + auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); + auto batch_weights = raft::make_device_vector(handle, batch_size); + auto batch_labels = raft::make_device_vector(handle, batch_size); - auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); - auto batch_weights = raft::make_device_vector(handle, batch_size); - auto batch_labels = raft::make_device_vector(handle, batch_size); + inertia[0] = 0; - inertia[0] = 0; + for (IdxT batch_idx = 0; batch_idx < n_samples; batch_idx += batch_size) { + IdxT current_batch_size = std::min(batch_size, n_samples - batch_idx); - for (IdxT batch_idx = 0; batch_idx < n_samples; batch_idx += batch_size) { - IdxT current_batch_size = std::min(batch_size, n_samples - batch_idx); + raft::copy(batch_data.data_handle(), + X.data_handle() + batch_idx * n_features, + current_batch_size * n_features, + stream); - raft::copy(batch_data.data_handle(), - X.data_handle() + batch_idx * n_features, - current_batch_size * n_features, - stream); + if (sample_weight) { + raft::copy(batch_weights.data_handle(), + sample_weight->data_handle() + batch_idx, + current_batch_size, + stream); + } - if (sample_weight) { - raft::copy(batch_weights.data_handle(), - sample_weight->data_handle() + batch_idx, - current_batch_size, - stream); - } + auto batch_data_view = raft::make_device_matrix_view( + batch_data.data_handle(), current_batch_size, n_features); - auto batch_data_view = raft::make_device_matrix_view( - batch_data.data_handle(), current_batch_size, n_features); + T batch_inertia = 0; + cuvs::cluster::kmeans::detail::kmeans_predict( + handle, + params, + batch_data_view, + batch_weights.view(), + centroids, + batch_labels.view(), + normalize_weight, + raft::make_host_scalar_view(&batch_inertia)); - T batch_inertia = 0; - cuvs::cluster::kmeans::detail::kmeans_predict( - handle, - params, - batch_data_view, - batch_weights.view(), - centroids, - batch_labels.view(), - normalize_weight, - raft::make_host_scalar_view(&batch_inertia)); + raft::copy( + labels.data_handle() + batch_idx, batch_labels.data_handle(), current_batch_size, stream); - raft::copy( - labels.data_handle() + batch_idx, batch_labels.data_handle(), current_batch_size, stream); + inertia[0] += batch_inertia; + } - inertia[0] += batch_inertia; + raft::resource::sync_stream(handle, stream); } - raft::resource::sync_stream(handle, stream); -} - -/** - * @brief Fit k-means and predict cluster labels using batched processing. - */ -template -void fit_predict(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - IdxT batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - T fit_inertia = 0; - fit(handle, - params, - X, - batch_size, - sample_weight, - centroids, - raft::make_host_scalar_view(&fit_inertia), - n_iter); - - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)); - - predict( - handle, params, X, batch_size, sample_weight, centroids_const, labels, false, inertia); -} + /** + * @brief Fit k-means and predict cluster labels using batched processing. + */ + template + void fit_predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + IdxT batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) + { + T fit_inertia = 0; + fit(handle, + params, + X, + batch_size, + sample_weight, + centroids, + raft::make_host_scalar_view(&fit_inertia), + n_iter); + + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)); + + predict( + handle, params, X, batch_size, sample_weight, centroids_const, labels, false, inertia); + } } // namespace cuvs::cluster::kmeans::detail From 63a34a3fc9436e8ba18337402ea0eb085ce06e15 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 2 Mar 2026 15:12:30 -0800 Subject: [PATCH 36/81] abstract away commonalities into helpers --- cpp/src/cluster/detail/kmeans.cuh | 65 ++--- cpp/src/cluster/detail/kmeans_batched.cuh | 309 ++++++++-------------- cpp/src/cluster/detail/kmeans_common.cuh | 100 +++++++ 3 files changed, 230 insertions(+), 244 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 181038e167..f46a9beec7 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -399,18 +399,13 @@ void kmeans_fit_main(raft::resources const& handle, newCentroids.view(), workspace); - // compute the squared norm between the newCentroids and the original - // centroids, destructor releases the resource - auto sqrdNorm = raft::make_device_scalar(handle, DataT(0)); - raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), - newCentroids.size(), - raft::sqdiff_op{}, - stream, - centroids.data_handle(), - newCentroids.data_handle()); - - DataT sqrdNormError = 0; - raft::copy(handle, raft::make_host_scalar_view(&sqrdNormError), sqrdNorm.view()); + // Compute how much centroids shifted + DataT sqrdNormError = + compute_centroid_shift(handle, + raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features), + raft::make_device_matrix_view( + newCentroids.data_handle(), n_clusters, n_features)); raft::copy(handle, raft::make_device_vector_view(centroidsRawData.data_handle(), newCentroids.size()), @@ -440,7 +435,6 @@ void kmeans_fit_main(raft::resources const& handle, priorClusteringCost = curClusteringCost; } - raft::resource::sync_stream(handle, stream); if (sqrdNormError < params.tol) done = true; if (done) { @@ -449,43 +443,14 @@ void kmeans_fit_main(raft::resources const& handle, } } - auto centroids = raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features); - - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( - handle, - X, - centroids, - minClusterAndDistance.view(), - l2normx_view, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - raft::linalg::map( - handle, - minClusterAndDistance.view(), - [=] __device__(const raft::KeyValuePair kvp, DataT wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }, - raft::make_const_mdspan(minClusterAndDistance.view()), - raft::make_const_mdspan(weight)); - - // calculate cluster cost phi_x(C) - cuvs::cluster::kmeans::detail::computeClusterCost( - handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - - inertia[0] = clusterCostD.value(stream); + inertia[0] = + compute_inertia(handle, + params, + X, + raft::make_device_matrix_view( + centroidsRawData.data_handle(), n_clusters, n_features), + workspace, + weight); RAFT_LOG_DEBUG("KMeans.fit: completed after %d iterations with %f inertia[0] ", n_iter[0] > params.max_iter ? n_iter[0] - 1 : n_iter[0], diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 9438d59be9..e8108fb7c8 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -106,79 +106,6 @@ void prepare_init_sample( } } -/** - * @brief Compute (optionally weighted) inertia. - * - * Used for scoring the validation set during n_init selection, where the data is - * small enough that no host→device batching is needed. - */ -template -T compute_inertia( - raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace, - std::optional> sample_weight = std::nullopt) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto metric = params.metric; - - auto minClusterAndDistance = - raft::make_device_vector, IdxT>(handle, n_samples); - auto L2NormX = raft::make_device_vector(handle, n_samples); - rmm::device_uvector L2NormBuf(0, stream); - rmm::device_scalar cost(stream); - - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormX.data_handle(), X.data_handle(), n_features, n_samples, stream); - } - - auto mcd_view = minClusterAndDistance.view(); - - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( - handle, - X, - centroids, - mcd_view, - raft::make_device_vector_view(L2NormX.data_handle(), n_samples), - L2NormBuf, - metric, - params.batch_samples, - params.batch_centroids, - workspace); - - if (sample_weight) { - raft::linalg::map( - handle, - mcd_view, - [=] __device__(const raft::KeyValuePair kvp, T wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }, - raft::make_const_mdspan(mcd_view), - *sample_weight); - } - - cuvs::cluster::kmeans::detail::computeClusterCost(handle, - mcd_view, - workspace, - raft::make_device_scalar_view(cost.data()), - raft::value_op{}, - raft::add_op{}); - - T result = 0; - raft::copy(&result, cost.data(), 1, stream); - raft::resource::sync_stream(handle, stream); - return result; -} - /** * @brief Initialize centroids using k-means++ on a sample of the host data * @@ -779,9 +706,9 @@ void fit(raft::resources const& handle, raft::copy(centroids.data_handle(), best_centroids.data_handle(), centroids.size(), stream); - best_inertia = std::numeric_limits::max(); - n_init = 1; - force_inertia = false; + best_inertia = std::numeric_limits::max(); + n_init = 1; + compute_final_inertia = params.batched.final_inertia_check; minibatch_init_done = true; RAFT_LOG_DEBUG("KMeans minibatch: best initialization selected, proceeding with training"); @@ -928,16 +855,12 @@ void fit(raft::resources const& handle, rng); // Compute squared difference of centers (for convergence check) - auto sqrdNorm = raft::make_device_scalar(handle, T{0}); - raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), - centroids.size(), - raft::sqdiff_op{}, - stream, - prev_centroids.data_handle(), - centroids.data_handle()); - T centers_squared_diff = 0; - raft::copy(¢ers_squared_diff, sqrdNorm.data_handle(), 1, stream); - raft::resource::sync_stream(handle, stream); + T centers_squared_diff = + compute_centroid_shift(handle, + raft::make_device_matrix_view( + prev_centroids.data_handle(), n_clusters, n_features), + raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features)); // Skip first step (inertia from initialization) if (n_iter[0] > 1) { @@ -1081,27 +1004,24 @@ void fit(raft::resources const& handle, // Convergence check for full-batch mode only if (!use_minibatch) { - auto sqrdNorm = raft::make_device_scalar(handle, T{0}); - raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), - centroids.size(), - raft::sqdiff_op{}, - stream, - new_centroids.data_handle(), - centroids.data_handle()); + T sqrdNormError = + compute_centroid_shift(handle, + raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features), + raft::make_device_matrix_view( + new_centroids.data_handle(), n_clusters, n_features)); raft::copy(centroids.data_handle(), new_centroids.data_handle(), centroids.size(), stream); - T sqrdNormError = 0; - raft::copy(&sqrdNormError, sqrdNorm.data_handle(), 1, stream); - bool done = false; - if (params.inertia_check && n_iter[0] > 1) { - T delta = total_cost / priorClusteringCost; - if (delta > 1 - params.tol) done = true; + if (params.inertia_check) { + if (n_iter[0] > 1) { + T delta = total_cost / priorClusteringCost; + if (delta > 1 - params.tol) done = true; + } priorClusteringCost = total_cost; } - raft::resource::sync_stream(handle, stream); if (sqrdNormError < params.tol) done = true; if (done) { @@ -1146,109 +1066,110 @@ void fit(raft::resources const& handle, best_n_iter); } } +} - /** - * @brief Predict cluster labels for host data using batched processing. - * - * @tparam T Input data type (float, double) - * @tparam IdxT Index type (int, int64_t) - */ - template - void predict(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - IdxT batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia) - { - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - - RAFT_EXPECTS(batch_size > 0, "batch_size must be positive"); - RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); - RAFT_EXPECTS(centroids.extent(0) == static_cast(n_clusters), - "centroids.extent(0) must equal n_clusters"); - RAFT_EXPECTS(centroids.extent(1) == n_features, "centroids.extent(1) must equal n_features"); - RAFT_EXPECTS(labels.extent(0) == n_samples, "labels.extent(0) must equal n_samples"); - - auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); - auto batch_weights = raft::make_device_vector(handle, batch_size); - auto batch_labels = raft::make_device_vector(handle, batch_size); - - inertia[0] = 0; - - for (IdxT batch_idx = 0; batch_idx < n_samples; batch_idx += batch_size) { - IdxT current_batch_size = std::min(batch_size, n_samples - batch_idx); - - raft::copy(batch_data.data_handle(), - X.data_handle() + batch_idx * n_features, - current_batch_size * n_features, - stream); +/** + * @brief Predict cluster labels for host data using batched processing. + * + * @tparam T Input data type (float, double) + * @tparam IdxT Index type (int, int64_t) + */ +template +void predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + IdxT batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; - if (sample_weight) { - raft::copy(batch_weights.data_handle(), - sample_weight->data_handle() + batch_idx, - current_batch_size, - stream); - } + RAFT_EXPECTS(batch_size > 0, "batch_size must be positive"); + RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); + RAFT_EXPECTS(centroids.extent(0) == static_cast(n_clusters), + "centroids.extent(0) must equal n_clusters"); + RAFT_EXPECTS(centroids.extent(1) == n_features, "centroids.extent(1) must equal n_features"); + RAFT_EXPECTS(labels.extent(0) == n_samples, "labels.extent(0) must equal n_samples"); - auto batch_data_view = raft::make_device_matrix_view( - batch_data.data_handle(), current_batch_size, n_features); + auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); + auto batch_weights = raft::make_device_vector(handle, batch_size); + auto batch_labels = raft::make_device_vector(handle, batch_size); - T batch_inertia = 0; - cuvs::cluster::kmeans::detail::kmeans_predict( - handle, - params, - batch_data_view, - batch_weights.view(), - centroids, - batch_labels.view(), - normalize_weight, - raft::make_host_scalar_view(&batch_inertia)); + inertia[0] = 0; - raft::copy( - labels.data_handle() + batch_idx, batch_labels.data_handle(), current_batch_size, stream); + for (IdxT batch_idx = 0; batch_idx < n_samples; batch_idx += batch_size) { + IdxT current_batch_size = std::min(batch_size, n_samples - batch_idx); + + raft::copy(batch_data.data_handle(), + X.data_handle() + batch_idx * n_features, + current_batch_size * n_features, + stream); - inertia[0] += batch_inertia; + if (sample_weight) { + raft::copy(batch_weights.data_handle(), + sample_weight->data_handle() + batch_idx, + current_batch_size, + stream); } - raft::resource::sync_stream(handle, stream); - } + auto batch_data_view = raft::make_device_matrix_view( + batch_data.data_handle(), current_batch_size, n_features); + + T batch_inertia = 0; + cuvs::cluster::kmeans::detail::kmeans_predict( + handle, + params, + batch_data_view, + batch_weights.view(), + centroids, + batch_labels.view(), + normalize_weight, + raft::make_host_scalar_view(&batch_inertia)); + + raft::copy( + labels.data_handle() + batch_idx, batch_labels.data_handle(), current_batch_size, stream); - /** - * @brief Fit k-means and predict cluster labels using batched processing. - */ - template - void fit_predict(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - IdxT batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) - { - T fit_inertia = 0; - fit(handle, - params, - X, - batch_size, - sample_weight, - centroids, - raft::make_host_scalar_view(&fit_inertia), - n_iter); - - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)); - - predict( - handle, params, X, batch_size, sample_weight, centroids_const, labels, false, inertia); + inertia[0] += batch_inertia; } + raft::resource::sync_stream(handle, stream); +} + +/** + * @brief Fit k-means and predict cluster labels using batched processing. + */ +template +void fit_predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + IdxT batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + T fit_inertia = 0; + fit(handle, + params, + X, + batch_size, + sample_weight, + centroids, + raft::make_host_scalar_view(&fit_inertia), + n_iter); + + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)); + + predict( + handle, params, X, batch_size, sample_weight, centroids_const, labels, false, inertia); +} + } // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index fb8a4b615e..5ae29d6bb2 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -576,4 +577,103 @@ void finalize_centroids(raft::resources const& handle, stream); } +/** + * @brief Compute the squared norm difference between two centroid sets. + * + * Returns sum((old_centroids - new_centroids)^2). + * Used for convergence checking in both full-batch and mini-batch modes. + */ +template +DataT compute_centroid_shift(raft::resources const& handle, + raft::device_matrix_view old_centroids, + raft::device_matrix_view new_centroids) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto sqrdNorm = raft::make_device_scalar(handle, DataT{0}); + raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), + old_centroids.size(), + raft::sqdiff_op{}, + stream, + old_centroids.data_handle(), + new_centroids.data_handle()); + DataT result = 0; + raft::copy(&result, sqrdNorm.data_handle(), 1, stream); + raft::resource::sync_stream(handle, stream); + return result; +} + +/** + * @brief Compute (optionally weighted) inertia for device-resident data. + * + * Computes the sum of (optionally weighted) squared distances from each sample + * to its nearest centroid. Used for final inertia reporting after training and + * for scoring validation sets during n_init selection. + */ +template +DataT compute_inertia( + raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + rmm::device_uvector& workspace, + std::optional> sample_weight = std::nullopt) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto metric = params.metric; + + auto minClusterAndDistance = + raft::make_device_vector, IndexT>(handle, n_samples); + auto L2NormX = raft::make_device_vector(handle, n_samples); + rmm::device_uvector L2NormBuf(0, stream); + rmm::device_scalar cost(stream); + + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm( + L2NormX.data_handle(), X.data_handle(), n_features, n_samples, stream); + } + + auto mcd_view = minClusterAndDistance.view(); + + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + X, + centroids, + mcd_view, + raft::make_device_vector_view(L2NormX.data_handle(), n_samples), + L2NormBuf, + metric, + params.batch_samples, + params.batch_centroids, + workspace); + + if (sample_weight) { + raft::linalg::map( + handle, + mcd_view, + [=] __device__(const raft::KeyValuePair kvp, DataT wt) { + raft::KeyValuePair res; + res.value = kvp.value * wt; + res.key = kvp.key; + return res; + }, + raft::make_const_mdspan(mcd_view), + *sample_weight); + } + + cuvs::cluster::kmeans::detail::computeClusterCost(handle, + mcd_view, + workspace, + raft::make_device_scalar_view(cost.data()), + raft::value_op{}, + raft::add_op{}); + + DataT result = 0; + raft::copy(&result, cost.data(), 1, stream); + raft::resource::sync_stream(handle, stream); + return result; +} + } // namespace cuvs::cluster::kmeans::detail From de34c93f4eccbd487d020c1c9d6c7743be40ba84 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 2 Mar 2026 15:40:17 -0800 Subject: [PATCH 37/81] fix compilation errors --- cpp/src/cluster/detail/kmeans_batched.cuh | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index e8108fb7c8..f23447c3b6 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -115,6 +115,7 @@ void prepare_init_sample( template void init_centroids_from_host_sample(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, + IdxT batch_size, raft::host_matrix_view X, raft::device_matrix_view centroids, rmm::device_uvector& workspace) @@ -124,10 +125,9 @@ void init_centroids_from_host_sample(raft::resources const& handle, auto n_features = X.extent(1); auto n_clusters = params.n_clusters; - size_t init_sample_size = 3 * params.batch_size; + IdxT init_sample_size = 3 * batch_size; if (init_sample_size < n_clusters) { init_sample_size = 3 * n_clusters; } init_sample_size = std::min(init_sample_size, n_samples); - RAFT_LOG_DEBUG("KMeans batched: sampling %zu points for initialization", init_sample_size); auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); prepare_init_sample(handle, X, init_sample.view(), params.rng_state.seed); @@ -668,7 +668,12 @@ void fit(raft::resources const& handle, std::optional> valid_weight_view; if (sample_weight) { - prepare_init_sample(handle, X, X_valid.view(), gen(), *sample_weight, valid_weights.view()); + prepare_init_sample(handle, + X, + X_valid.view(), + gen(), + std::optional>{*sample_weight}, + std::optional>{valid_weights.view()}); valid_weight_view = raft::make_device_vector_view( valid_weights.data_handle(), static_cast(valid_size)); } else { @@ -686,7 +691,7 @@ void fit(raft::resources const& handle, n_init, (unsigned long long)iter_params.rng_state.seed); - init_centroids_from_host_sample(handle, iter_params, X, centroids, workspace); + init_centroids_from_host_sample(handle, iter_params, batch_size, X, centroids, workspace); auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); @@ -728,7 +733,7 @@ void fit(raft::resources const& handle, if (!minibatch_init_done && iter_params.init != cuvs::cluster::kmeans::params::InitMethod::Array) { - init_centroids_from_host_sample(handle, iter_params, X, centroids, workspace); + init_centroids_from_host_sample(handle, iter_params, batch_size, X, centroids, workspace); } // Reset per-iteration state From 29a2358ea8a07afa339b53405bcfadf4169096e4 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 2 Mar 2026 17:07:27 -0800 Subject: [PATCH 38/81] fix bug, add cpp tests --- cpp/src/cluster/detail/kmeans_batched.cuh | 14 +- cpp/tests/cluster/kmeans.cu | 354 ++++++++++++++++++++++ 2 files changed, 366 insertions(+), 2 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index f23447c3b6..00866567d5 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -1126,14 +1126,24 @@ void predict(raft::resources const& handle, auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); + std::optional> batch_weights_view = std::nullopt; + if (sample_weight) { + batch_weights_view = std::make_optional( + raft::make_device_vector_view(batch_weights.data_handle(), + current_batch_size)); + } + + auto batch_labels_view = raft::make_device_vector_view( + batch_labels.data_handle(), current_batch_size); + T batch_inertia = 0; cuvs::cluster::kmeans::detail::kmeans_predict( handle, params, batch_data_view, - batch_weights.view(), + batch_weights_view, centroids, - batch_labels.view(), + batch_labels_view, normalize_weight, raft::make_host_scalar_view(&batch_inertia)); diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index 576e6c1a48..fb55a1ee4b 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -346,4 +346,358 @@ TEST_P(KmeansTestF, Result) { ASSERT_TRUE(score == 1.0); } INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestF, ::testing::ValuesIn(inputsf2)); +// ============================================================================ +// Batched KMeans Tests (fit_batched + predict_batched) +// ============================================================================ + +template +struct KmeansBatchedInputs { + int n_row; + int n_col; + int n_clusters; + T tol; + bool weighted; + bool minibatch; +}; + +template +class KmeansFitBatchedTest : public ::testing::TestWithParam> { + protected: + KmeansFitBatchedTest() + : d_labels(0, raft::resource::get_cuda_stream(handle)), + d_labels_ref(0, raft::resource::get_cuda_stream(handle)), + d_centroids(0, raft::resource::get_cuda_stream(handle)) + { + } + + void fitBatchedTest() + { + testparams = ::testing::TestWithParam>::GetParam(); + + int n_samples = testparams.n_row; + int n_features = testparams.n_col; + params.n_clusters = testparams.n_clusters; + params.tol = testparams.tol; + params.n_init = 5; + params.rng_state.seed = 1; + params.oversampling_factor = 0; + + if (testparams.minibatch) { + params.batched.update_mode = cuvs::cluster::kmeans::params::MiniBatch; + params.batched.final_inertia_check = true; + } else { + params.batched.update_mode = cuvs::cluster::kmeans::params::FullBatch; + } + + auto stream = raft::resource::get_cuda_stream(handle); + auto X = raft::make_device_matrix(handle, n_samples, n_features); + auto labels = raft::make_device_vector(handle, n_samples); + + raft::random::make_blobs(X.data_handle(), + labels.data_handle(), + n_samples, + n_features, + params.n_clusters, + stream, + true, + nullptr, + nullptr, + T(1.0), + false, + (T)-10.0f, + (T)10.0f, + (uint64_t)1234); + + // Copy X to host for batched API + std::vector h_X(n_samples * n_features); + raft::update_host(h_X.data(), X.data_handle(), n_samples * n_features, stream); + raft::resource::sync_stream(handle, stream); + + d_labels.resize(n_samples, stream); + d_labels_ref.resize(n_samples, stream); + d_centroids.resize(params.n_clusters * n_features, stream); + raft::copy(d_labels_ref.data(), labels.data_handle(), n_samples, stream); + + auto h_X_view = raft::make_host_matrix_view(h_X.data(), n_samples, n_features); + auto d_centroids_view = + raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features); + + std::optional> h_sw = std::nullopt; + std::vector h_sample_weight; + if (testparams.weighted) { + h_sample_weight.resize(n_samples, T(1)); + h_sw = std::make_optional( + raft::make_host_vector_view(h_sample_weight.data(), n_samples)); + } + + T inertia = 0; + int n_iter = 0; + int batch_size = std::min(n_samples, 256); + + cuvs::cluster::kmeans::fit_batched(handle, + params, + h_X_view, + batch_size, + h_sw, + d_centroids_view, + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter)); + + T pred_inertia = 0; + cuvs::cluster::kmeans::predict( + handle, + params, + raft::make_const_mdspan(X.view()), + std::optional>(std::nullopt), + raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features), + raft::make_device_vector_view(d_labels.data(), n_samples), + true, + raft::make_host_scalar_view(&pred_inertia)); + + raft::resource::sync_stream(handle, stream); + + score = raft::stats::adjusted_rand_index( + d_labels_ref.data(), d_labels.data(), n_samples, raft::resource::get_cuda_stream(handle)); + + if (score < 1.0) { + std::stringstream ss; + ss << "Expected: " << raft::arr2Str(d_labels_ref.data(), 25, "d_labels_ref", stream); + std::cout << (ss.str().c_str()) << '\n'; + ss.str(std::string()); + ss << "Actual: " << raft::arr2Str(d_labels.data(), 25, "d_labels", stream); + std::cout << (ss.str().c_str()) << '\n'; + std::cout << "Score = " << score << '\n'; + } + } + + void SetUp() override { fitBatchedTest(); } + + protected: + raft::resources handle; + KmeansBatchedInputs testparams; + rmm::device_uvector d_labels; + rmm::device_uvector d_labels_ref; + rmm::device_uvector d_centroids; + double score; + cuvs::cluster::kmeans::params params; +}; + +template +class KmeansPredictBatchedTest : public ::testing::TestWithParam> { + protected: + KmeansPredictBatchedTest() + : d_labels(0, raft::resource::get_cuda_stream(handle)), + d_labels_ref(0, raft::resource::get_cuda_stream(handle)), + d_centroids(0, raft::resource::get_cuda_stream(handle)), + d_sample_weight(0, raft::resource::get_cuda_stream(handle)) + { + } + + void predictBatchedTest() + { + testparams = ::testing::TestWithParam>::GetParam(); + + int n_samples = testparams.n_row; + int n_features = testparams.n_col; + params.n_clusters = testparams.n_clusters; + params.tol = testparams.tol; + params.n_init = 5; + params.rng_state.seed = 1; + params.oversampling_factor = 0; + + auto stream = raft::resource::get_cuda_stream(handle); + auto X = raft::make_device_matrix(handle, n_samples, n_features); + auto labels = raft::make_device_vector(handle, n_samples); + + raft::random::make_blobs(X.data_handle(), + labels.data_handle(), + n_samples, + n_features, + params.n_clusters, + stream, + true, + nullptr, + nullptr, + T(1.0), + false, + (T)-10.0f, + (T)10.0f, + (uint64_t)1234); + + d_labels.resize(n_samples, stream); + d_labels_ref.resize(n_samples, stream); + d_centroids.resize(params.n_clusters * n_features, stream); + raft::copy(d_labels_ref.data(), labels.data_handle(), n_samples, stream); + + auto d_centroids_view = + raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features); + + std::optional> d_sw = std::nullopt; + if (testparams.weighted) { + d_sample_weight.resize(n_samples, stream); + d_sw = std::make_optional( + raft::make_device_vector_view(d_sample_weight.data(), n_samples)); + thrust::fill(thrust::cuda::par.on(stream), + d_sample_weight.data(), + d_sample_weight.data() + n_samples, + T(1)); + } + + T fit_inertia = 0; + int fit_n_iter = 0; + cuvs::cluster::kmeans::fit(handle, + params, + raft::make_const_mdspan(X.view()), + d_sw, + d_centroids_view, + raft::make_host_scalar_view(&fit_inertia), + raft::make_host_scalar_view(&fit_n_iter)); + + std::vector h_X(n_samples * n_features); + raft::update_host(h_X.data(), X.data_handle(), n_samples * n_features, stream); + raft::resource::sync_stream(handle, stream); + + auto h_X_view = raft::make_host_matrix_view( + h_X.data(), (int64_t)n_samples, (int64_t)n_features); + auto centroids_const_view = raft::make_device_matrix_view( + d_centroids.data(), (int64_t)params.n_clusters, (int64_t)n_features); + + std::vector h_labels(n_samples); + auto h_labels_view = + raft::make_host_vector_view(h_labels.data(), (int64_t)n_samples); + + T pred_inertia = 0; + int64_t batch_size = std::min((int64_t)n_samples, (int64_t)256); + + cuvs::cluster::kmeans::predict_batched( + handle, + params, + h_X_view, + batch_size, + std::optional>(std::nullopt), + centroids_const_view, + h_labels_view, + true, + raft::make_host_scalar_view(&pred_inertia)); + + raft::resource::sync_stream(handle, stream); + + std::vector h_labels_int(n_samples); + for (int i = 0; i < n_samples; ++i) { + h_labels_int[i] = static_cast(h_labels[i]); + } + raft::update_device(d_labels.data(), h_labels_int.data(), n_samples, stream); + + score = raft::stats::adjusted_rand_index( + d_labels_ref.data(), d_labels.data(), n_samples, raft::resource::get_cuda_stream(handle)); + + if (score < 1.0) { + std::stringstream ss; + ss << "Expected: " << raft::arr2Str(d_labels_ref.data(), 25, "d_labels_ref", stream); + std::cout << (ss.str().c_str()) << '\n'; + ss.str(std::string()); + ss << "Actual: " << raft::arr2Str(d_labels.data(), 25, "d_labels", stream); + std::cout << (ss.str().c_str()) << '\n'; + std::cout << "Score = " << score << '\n'; + } + } + + void SetUp() override { predictBatchedTest(); } + + protected: + raft::resources handle; + KmeansInputs testparams; + rmm::device_uvector d_labels; + rmm::device_uvector d_labels_ref; + rmm::device_uvector d_centroids; + rmm::device_uvector d_sample_weight; + double score; + cuvs::cluster::kmeans::params params; +}; + +// ============================================================================ +// Test inputs for batched tests +// ============================================================================ + +const std::vector> batched_inputsf2 = { + // FullBatch mode + {1000, 32, 5, 0.0001f, true, false}, + {1000, 32, 5, 0.0001f, false, false}, + {1000, 100, 20, 0.0001f, true, false}, + {1000, 100, 20, 0.0001f, false, false}, + {10000, 32, 10, 0.0001f, true, false}, + {10000, 32, 10, 0.0001f, false, false}, + // MiniBatch mode + {1000, 32, 5, 0.0001f, true, true}, + {1000, 32, 5, 0.0001f, false, true}, + {1000, 100, 20, 0.0001f, true, true}, + {1000, 100, 20, 0.0001f, false, true}, + {10000, 32, 10, 0.0001f, true, true}, + {10000, 32, 10, 0.0001f, false, true}, +}; + +const std::vector> batched_inputsd2 = { + // FullBatch mode + {1000, 32, 5, 0.0001, true, false}, + {1000, 32, 5, 0.0001, false, false}, + {1000, 100, 20, 0.0001, true, false}, + {1000, 100, 20, 0.0001, false, false}, + {10000, 32, 10, 0.0001, true, false}, + {10000, 32, 10, 0.0001, false, false}, + // MiniBatch mode + {1000, 32, 5, 0.0001, true, true}, + {1000, 32, 5, 0.0001, false, true}, + {1000, 100, 20, 0.0001, true, true}, + {1000, 100, 20, 0.0001, false, true}, + {10000, 32, 10, 0.0001, true, true}, + {10000, 32, 10, 0.0001, false, true}, +}; + +// ============================================================================ +// fit_batched tests +// ============================================================================ +typedef KmeansFitBatchedTest KmeansFitBatchedTestF; +typedef KmeansFitBatchedTest KmeansFitBatchedTestD; + +TEST_P(KmeansFitBatchedTestF, Result) +{ + if (testparams.minibatch) { + ASSERT_TRUE(score >= 0.9); + } else { + ASSERT_TRUE(score == 1.0); + } +} + +TEST_P(KmeansFitBatchedTestD, Result) +{ + if (testparams.minibatch) { + ASSERT_TRUE(score >= 0.9); + } else { + ASSERT_TRUE(score == 1.0); + } +} + +INSTANTIATE_TEST_CASE_P(KmeansFitBatchedTests, + KmeansFitBatchedTestF, + ::testing::ValuesIn(batched_inputsf2)); +INSTANTIATE_TEST_CASE_P(KmeansFitBatchedTests, + KmeansFitBatchedTestD, + ::testing::ValuesIn(batched_inputsd2)); + +// ============================================================================ +// predict_batched tests +// ============================================================================ +typedef KmeansPredictBatchedTest KmeansPredictBatchedTestF; +typedef KmeansPredictBatchedTest KmeansPredictBatchedTestD; + +TEST_P(KmeansPredictBatchedTestF, Result) { ASSERT_TRUE(score == 1.0); } +TEST_P(KmeansPredictBatchedTestD, Result) { ASSERT_TRUE(score == 1.0); } + +INSTANTIATE_TEST_CASE_P(KmeansPredictBatchedTests, + KmeansPredictBatchedTestF, + ::testing::ValuesIn(inputsf2)); +INSTANTIATE_TEST_CASE_P(KmeansPredictBatchedTests, + KmeansPredictBatchedTestD, + ::testing::ValuesIn(inputsd2)); + } // namespace cuvs From 4f119bac199882b852edb5359f399dee043d69ff Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 2 Mar 2026 17:11:58 -0800 Subject: [PATCH 39/81] style --- cpp/src/cluster/detail/kmeans_batched.cuh | 9 ++++---- cpp/tests/cluster/kmeans.cu | 25 ++++++++++++----------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 00866567d5..94e27475f8 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -1128,13 +1128,12 @@ void predict(raft::resources const& handle, std::optional> batch_weights_view = std::nullopt; if (sample_weight) { - batch_weights_view = std::make_optional( - raft::make_device_vector_view(batch_weights.data_handle(), - current_batch_size)); + batch_weights_view = std::make_optional(raft::make_device_vector_view( + batch_weights.data_handle(), current_batch_size)); } - auto batch_labels_view = raft::make_device_vector_view( - batch_labels.data_handle(), current_batch_size); + auto batch_labels_view = + raft::make_device_vector_view(batch_labels.data_handle(), current_batch_size); T batch_inertia = 0; cuvs::cluster::kmeans::detail::kmeans_predict( diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index fb55a1ee4b..3dbfc09a4c 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -383,7 +383,7 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam>(std::nullopt), - raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features), + raft::make_device_matrix_view( + d_centroids.data(), params.n_clusters, n_features), raft::make_device_vector_view(d_labels.data(), n_samples), true, raft::make_host_scalar_view(&pred_inertia)); @@ -566,8 +567,8 @@ class KmeansPredictBatchedTest : public ::testing::TestWithParam auto h_labels_view = raft::make_host_vector_view(h_labels.data(), (int64_t)n_samples); - T pred_inertia = 0; - int64_t batch_size = std::min((int64_t)n_samples, (int64_t)256); + T pred_inertia = 0; + int64_t batch_size = std::min((int64_t)n_samples, (int64_t)256); cuvs::cluster::kmeans::predict_batched( handle, @@ -678,11 +679,11 @@ TEST_P(KmeansFitBatchedTestD, Result) } INSTANTIATE_TEST_CASE_P(KmeansFitBatchedTests, - KmeansFitBatchedTestF, - ::testing::ValuesIn(batched_inputsf2)); + KmeansFitBatchedTestF, + ::testing::ValuesIn(batched_inputsf2)); INSTANTIATE_TEST_CASE_P(KmeansFitBatchedTests, - KmeansFitBatchedTestD, - ::testing::ValuesIn(batched_inputsd2)); + KmeansFitBatchedTestD, + ::testing::ValuesIn(batched_inputsd2)); // ============================================================================ // predict_batched tests @@ -694,10 +695,10 @@ TEST_P(KmeansPredictBatchedTestF, Result) { ASSERT_TRUE(score == 1.0); } TEST_P(KmeansPredictBatchedTestD, Result) { ASSERT_TRUE(score == 1.0); } INSTANTIATE_TEST_CASE_P(KmeansPredictBatchedTests, - KmeansPredictBatchedTestF, - ::testing::ValuesIn(inputsf2)); + KmeansPredictBatchedTestF, + ::testing::ValuesIn(inputsf2)); INSTANTIATE_TEST_CASE_P(KmeansPredictBatchedTests, - KmeansPredictBatchedTestD, - ::testing::ValuesIn(inputsd2)); + KmeansPredictBatchedTestD, + ::testing::ValuesIn(inputsd2)); } // namespace cuvs From bf5726b621e32f76802482b911a05d0787d54dc6 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 2 Mar 2026 17:25:57 -0800 Subject: [PATCH 40/81] make cpp tests more rigorous --- cpp/src/cluster/detail/kmeans_batched.cuh | 2 - cpp/tests/cluster/kmeans.cu | 93 ++++++++++++++++++----- 2 files changed, 73 insertions(+), 22 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 94e27475f8..943c6fad89 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -33,7 +33,6 @@ #include #include -#include #include #include #include @@ -42,7 +41,6 @@ #include #include #include -#include #include #include #include diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index 3dbfc09a4c..ea0af9f14c 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -366,7 +366,8 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(h_X.data(), n_samples, n_features); auto d_centroids_view = raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features); + // Run device fit to get reference centroids + std::optional> d_sw = std::nullopt; + rmm::device_uvector d_sample_weight(0, stream); + if (testparams.weighted) { + d_sample_weight.resize(n_samples, stream); + d_sw = std::make_optional( + raft::make_device_vector_view(d_sample_weight.data(), n_samples)); + thrust::fill(thrust::cuda::par.on(stream), + d_sample_weight.data(), + d_sample_weight.data() + n_samples, + T(1)); + } + + auto d_centroids_ref_view = + raft::make_device_matrix_view(d_centroids_ref.data(), params.n_clusters, n_features); + T ref_inertia = 0; + int ref_n_iter = 0; + cuvs::cluster::kmeans::fit(handle, + params, + raft::make_const_mdspan(X.view()), + d_sw, + d_centroids_ref_view, + raft::make_host_scalar_view(&ref_inertia), + raft::make_host_scalar_view(&ref_n_iter)); + + raft::copy( + d_centroids.data(), d_centroids_ref.data(), params.n_clusters * n_features, stream); + + cuvs::cluster::kmeans::params batched_params = params; + batched_params.init = cuvs::cluster::kmeans::params::Array; + batched_params.n_init = 1; + std::optional> h_sw = std::nullopt; std::vector h_sample_weight; if (testparams.weighted) { @@ -435,7 +469,7 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(&inertia), raft::make_host_scalar_view(&n_iter)); + raft::resource::sync_stream(handle, stream); + + if (!testparams.minibatch) { + // FullBatch: centroids should match the device fit reference + centroids_match = devArrMatch(d_centroids_ref.data(), + d_centroids.data(), + params.n_clusters, + n_features, + CompareApprox(T(1e-3)), + stream); + } + + // Also check label quality via ARI T pred_inertia = 0; cuvs::cluster::kmeans::predict( handle, params, raft::make_const_mdspan(X.view()), std::optional>(std::nullopt), - raft::make_device_matrix_view( - d_centroids.data(), params.n_clusters, n_features), + raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features), raft::make_device_vector_view(d_labels.data(), n_samples), true, raft::make_host_scalar_view(&pred_inertia)); @@ -479,7 +525,9 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam d_labels; rmm::device_uvector d_labels_ref; rmm::device_uvector d_centroids; + rmm::device_uvector d_centroids_ref; double score; + testing::AssertionResult centroids_match = testing::AssertionSuccess(); cuvs::cluster::kmeans::params params; }; @@ -528,8 +576,8 @@ class KmeansPredictBatchedTest : public ::testing::TestWithParam d_labels.resize(n_samples, stream); d_labels_ref.resize(n_samples, stream); d_centroids.resize(params.n_clusters * n_features, stream); - raft::copy(d_labels_ref.data(), labels.data_handle(), n_samples, stream); + // Fit on device to get centroids auto d_centroids_view = raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features); @@ -554,6 +602,17 @@ class KmeansPredictBatchedTest : public ::testing::TestWithParam raft::make_host_scalar_view(&fit_inertia), raft::make_host_scalar_view(&fit_n_iter)); + T ref_inertia = 0; + cuvs::cluster::kmeans::predict( + handle, + params, + raft::make_const_mdspan(X.view()), + std::optional>(std::nullopt), + raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features), + raft::make_device_vector_view(d_labels_ref.data(), n_samples), + true, + raft::make_host_scalar_view(&ref_inertia)); + std::vector h_X(n_samples * n_features); raft::update_host(h_X.data(), X.data_handle(), n_samples * n_features, stream); raft::resource::sync_stream(handle, stream); @@ -589,18 +648,10 @@ class KmeansPredictBatchedTest : public ::testing::TestWithParam } raft::update_device(d_labels.data(), h_labels_int.data(), n_samples, stream); - score = raft::stats::adjusted_rand_index( - d_labels_ref.data(), d_labels.data(), n_samples, raft::resource::get_cuda_stream(handle)); - - if (score < 1.0) { - std::stringstream ss; - ss << "Expected: " << raft::arr2Str(d_labels_ref.data(), 25, "d_labels_ref", stream); - std::cout << (ss.str().c_str()) << '\n'; - ss.str(std::string()); - ss << "Actual: " << raft::arr2Str(d_labels.data(), 25, "d_labels", stream); - std::cout << (ss.str().c_str()) << '\n'; - std::cout << "Score = " << score << '\n'; - } + // Compare labels directly: predict_batched should produce exact same labels + // as device predict given the same centroids + labels_match = devArrMatch( + d_labels_ref.data(), d_labels.data(), n_samples, Compare(), stream); } void SetUp() override { predictBatchedTest(); } @@ -612,7 +663,7 @@ class KmeansPredictBatchedTest : public ::testing::TestWithParam rmm::device_uvector d_labels_ref; rmm::device_uvector d_centroids; rmm::device_uvector d_sample_weight; - double score; + testing::AssertionResult labels_match = testing::AssertionSuccess(); cuvs::cluster::kmeans::params params; }; @@ -665,6 +716,7 @@ TEST_P(KmeansFitBatchedTestF, Result) if (testparams.minibatch) { ASSERT_TRUE(score >= 0.9); } else { + ASSERT_TRUE(centroids_match); ASSERT_TRUE(score == 1.0); } } @@ -674,6 +726,7 @@ TEST_P(KmeansFitBatchedTestD, Result) if (testparams.minibatch) { ASSERT_TRUE(score >= 0.9); } else { + ASSERT_TRUE(centroids_match); ASSERT_TRUE(score == 1.0); } } @@ -691,8 +744,8 @@ INSTANTIATE_TEST_CASE_P(KmeansFitBatchedTests, typedef KmeansPredictBatchedTest KmeansPredictBatchedTestF; typedef KmeansPredictBatchedTest KmeansPredictBatchedTestD; -TEST_P(KmeansPredictBatchedTestF, Result) { ASSERT_TRUE(score == 1.0); } -TEST_P(KmeansPredictBatchedTestD, Result) { ASSERT_TRUE(score == 1.0); } +TEST_P(KmeansPredictBatchedTestF, Result) { ASSERT_TRUE(labels_match); } +TEST_P(KmeansPredictBatchedTestD, Result) { ASSERT_TRUE(labels_match); } INSTANTIATE_TEST_CASE_P(KmeansPredictBatchedTests, KmeansPredictBatchedTestF, From 74ec7284cdeb04f2dffb72988b69b43839343223 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 2 Mar 2026 18:32:55 -0800 Subject: [PATCH 41/81] style --- cpp/tests/cluster/kmeans.cu | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index ea0af9f14c..b95fceaedc 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -449,8 +449,7 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(&ref_inertia), raft::make_host_scalar_view(&ref_n_iter)); - raft::copy( - d_centroids.data(), d_centroids_ref.data(), params.n_clusters * n_features, stream); + raft::copy(d_centroids.data(), d_centroids_ref.data(), params.n_clusters * n_features, stream); cuvs::cluster::kmeans::params batched_params = params; batched_params.init = cuvs::cluster::kmeans::params::Array; @@ -496,7 +495,8 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam>(std::nullopt), - raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features), + raft::make_device_matrix_view( + d_centroids.data(), params.n_clusters, n_features), raft::make_device_vector_view(d_labels.data(), n_samples), true, raft::make_host_scalar_view(&pred_inertia)); @@ -608,7 +608,8 @@ class KmeansPredictBatchedTest : public ::testing::TestWithParam params, raft::make_const_mdspan(X.view()), std::optional>(std::nullopt), - raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features), + raft::make_device_matrix_view( + d_centroids.data(), params.n_clusters, n_features), raft::make_device_vector_view(d_labels_ref.data(), n_samples), true, raft::make_host_scalar_view(&ref_inertia)); @@ -650,8 +651,8 @@ class KmeansPredictBatchedTest : public ::testing::TestWithParam // Compare labels directly: predict_batched should produce exact same labels // as device predict given the same centroids - labels_match = devArrMatch( - d_labels_ref.data(), d_labels.data(), n_samples, Compare(), stream); + labels_match = + devArrMatch(d_labels_ref.data(), d_labels.data(), n_samples, Compare(), stream); } void SetUp() override { predictBatchedTest(); } From 568904d6704e4f29d61b9c39fa29f81c5092122d Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 3 Mar 2026 10:45:34 -0800 Subject: [PATCH 42/81] fix learning rate bug --- cpp/src/cluster/detail/kmeans_batched.cuh | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 943c6fad89..662cd2c4fb 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -867,9 +867,12 @@ void fit(raft::resources const& handle, // Skip first step (inertia from initialization) if (n_iter[0] > 1) { - // Update Exponentially Weighted Average of inertia - T alpha = static_cast(current_batch_size * 2.0) / static_cast(n_samples + 1); - alpha = std::min(alpha, T{1}); + // Update Exponentially Weighted Average of inertia. + // alpha = 2 * batch_size / (n_samples_seen + batch_size) + int64_t n_samples_seen = static_cast(n_iter[0]) * current_batch_size; + T alpha = static_cast(current_batch_size * 2.0) / + static_cast(n_samples_seen + current_batch_size); + alpha = std::min(alpha, T{1}); if (!ewa_initialized) { ewa_inertia = batch_inertia; @@ -881,12 +884,13 @@ void fit(raft::resources const& handle, RAFT_LOG_DEBUG( "KMeans minibatch step %d/%d: batch_inertia=%f, ewa_inertia=%f, " - "centers_squared_diff=%f", + "centers_squared_diff=%f, alpha=%f", n_iter[0], n_steps, static_cast(batch_inertia), static_cast(ewa_inertia), - static_cast(centers_squared_diff)); + static_cast(centers_squared_diff), + static_cast(alpha)); // Early stopping: absolute tolerance on squared change of centers // Disabled if tol == 0.0 From 48a776bf4754fa0298482e5693f81a2223137876 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 3 Mar 2026 11:10:53 -0800 Subject: [PATCH 43/81] revert --- cpp/src/cluster/detail/kmeans_batched.cuh | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 662cd2c4fb..2c23cee91d 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -868,10 +868,8 @@ void fit(raft::resources const& handle, // Skip first step (inertia from initialization) if (n_iter[0] > 1) { // Update Exponentially Weighted Average of inertia. - // alpha = 2 * batch_size / (n_samples_seen + batch_size) - int64_t n_samples_seen = static_cast(n_iter[0]) * current_batch_size; T alpha = static_cast(current_batch_size * 2.0) / - static_cast(n_samples_seen + current_batch_size); + static_cast(n_samples + 1); alpha = std::min(alpha, T{1}); if (!ewa_initialized) { From b2c8a65f4da83f603e3a4c401d11ba988adb3266 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 4 Mar 2026 15:35:01 -0800 Subject: [PATCH 44/81] add sample weights --- c/src/cluster/kmeans.cpp | 12 ++++------ cpp/include/cuvs/cluster/kmeans.hpp | 27 ++++++++++++++-------- cpp/src/cluster/detail/kmeans_balanced.cuh | 9 ++++++++ cpp/src/cluster/kmeans_cluster_cost.cu | 12 ++++++---- 4 files changed, 38 insertions(+), 22 deletions(-) diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index 698f387616..622f52ff44 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -63,12 +63,10 @@ void _fit(cuvsResources_t res, RAFT_FAIL("float64 is an unsupported dtype for hierarchical kmeans"); } else { auto kmeans_params = convert_balanced_params(params); - cuvs::cluster::kmeans::fit(*res_ptr, - kmeans_params, - cuvs::core::from_dlpack(X_tensor), - cuvs::core::from_dlpack(centroids_tensor)); - - *inertia = 0; + T inertia_temp; + auto inertia_view = raft::make_host_scalar_view(&inertia_temp); + cuvs::cluster::kmeans::fit(*res_ptr, kmeans_params, X_view, inertia_view); + *inertia = inertia_temp; *n_iter = params.hierarchical_n_iters; } } else { diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 22c8e056ec..4ede31243b 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -1380,7 +1380,7 @@ void transform(raft::resources const& handle, raft::device_matrix_view X_new); /** - * @brief Compute cluster cost + * @brief Compute (optionally weighted) cluster cost * * @param[in] handle The raft handle * @param[in] X Training instances to cluster. The data must @@ -1390,12 +1390,16 @@ void transform(raft::resources const& handle, * row-major format. * [dim = n_clusters x n_features] * @param[out] cost Resulting cluster cost + * @param[in] sample_weight Optional per-sample weights. + * [len = n_samples] * */ -void cluster_cost(const raft::resources& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::host_scalar_view cost); +void cluster_cost( + const raft::resources& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::host_scalar_view cost, + std::optional> sample_weight = std::nullopt); /** * @brief Compute cluster cost @@ -1408,12 +1412,15 @@ void cluster_cost(const raft::resources& handle, * row-major format. * [dim = n_clusters x n_features] * @param[out] cost Resulting cluster cost - * + * @param[in] sample_weight Optional per-sample weights. + * [len = n_samples] */ -void cluster_cost(const raft::resources& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::host_scalar_view cost); +void cluster_cost( + const raft::resources& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::host_scalar_view cost, + std::optional> sample_weight = std::nullopt); /** * @} diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 7582ec900e..1894a9a6d2 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -1164,6 +1164,15 @@ void build_hierarchical(const raft::resources& handle, MathT{0.2}, mapping_op, device_memory); + + // Compute inertia if requested + if (inertia != nullptr) { + auto X_view = raft::make_device_matrix_view(dataset, n_rows, dim); + auto centroids_view = + raft::make_device_matrix_view(cluster_centers, n_clusters, dim); + cuvs::cluster::kmeans::detail::cluster_cost( + handle, X_view, centroids_view, *inertia, std::nullopt, mapping_op); + } } } // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/kmeans_cluster_cost.cu b/cpp/src/cluster/kmeans_cluster_cost.cu index a806cba4a4..3c877f4913 100644 --- a/cpp/src/cluster/kmeans_cluster_cost.cu +++ b/cpp/src/cluster/kmeans_cluster_cost.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -11,16 +11,18 @@ namespace cuvs::cluster::kmeans { void cluster_cost(const raft::resources& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, - raft::host_scalar_view cost) + raft::host_scalar_view cost, + std::optional> sample_weight) { - cuvs::cluster::kmeans::cluster_cost(handle, X, centroids, cost); + cuvs::cluster::kmeans::cluster_cost(handle, X, centroids, cost, sample_weight); } void cluster_cost(const raft::resources& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, - raft::host_scalar_view cost) + raft::host_scalar_view cost, + std::optional> sample_weight) { - cuvs::cluster::kmeans::cluster_cost(handle, X, centroids, cost); + cuvs::cluster::kmeans::cluster_cost(handle, X, centroids, cost, sample_weight); } } // namespace cuvs::cluster::kmeans From 4e8f2e47a26e75c11bf1c4f4924d495498c28959 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 4 Mar 2026 16:14:01 -0800 Subject: [PATCH 45/81] update impl --- cpp/src/cluster/detail/kmeans_balanced.cuh | 14 +++++---- cpp/src/cluster/kmeans.cuh | 35 ++++++++++++++++++---- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 1894a9a6d2..a45493ffb9 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -1165,13 +1165,15 @@ void build_hierarchical(const raft::resources& handle, mapping_op, device_memory); - // Compute inertia if requested + // Compute inertia if requested (only supported when T == MathT) if (inertia != nullptr) { - auto X_view = raft::make_device_matrix_view(dataset, n_rows, dim); - auto centroids_view = - raft::make_device_matrix_view(cluster_centers, n_clusters, dim); - cuvs::cluster::kmeans::detail::cluster_cost( - handle, X_view, centroids_view, *inertia, std::nullopt, mapping_op); + if constexpr (std::is_same_v) { + auto X_view = raft::make_device_matrix_view( + reinterpret_cast(dataset), n_rows, dim); + auto centroids_view = + raft::make_device_matrix_view(cluster_centers, n_clusters, dim); + cuvs::cluster::kmeans::detail::cluster_cost(handle, X_view, centroids_view, *inertia); + } } } diff --git a/cpp/src/cluster/kmeans.cuh b/cpp/src/cluster/kmeans.cuh index 7d37b9cf80..633c8f346d 100644 --- a/cpp/src/cluster/kmeans.cuh +++ b/cpp/src/cluster/kmeans.cuh @@ -404,14 +404,27 @@ void min_cluster_distance(raft::resources const& handle, workspace); } +/** + * @brief Compute (optionally weighted) cluster cost (inertia). + * + * @tparam DataT float or double + * @tparam IndexT Index type + * + * @param[in] handle The raft handle + * @param[in] X Input data [n_samples x n_features] + * @param[in] centroids Cluster centroids [n_clusters x n_features] + * @param[out] cost Sum of squared distances to nearest centroid + * @param[in] sample_weight Optional per-sample weights [n_samples] + */ template -void cluster_cost(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::host_scalar_view cost) +void cluster_cost( + raft::resources const& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::host_scalar_view cost, + std::optional> sample_weight = std::nullopt) { - auto stream = raft::resource::get_cuda_stream(handle); - + auto stream = raft::resource::get_cuda_stream(handle); auto n_clusters = centroids.extent(0); auto n_samples = X.extent(0); auto n_features = X.extent(1); @@ -440,6 +453,16 @@ void cluster_cost(raft::resources const& handle, n_clusters, workspace); + // Apply sample weights if provided + if (sample_weight.has_value()) { + raft::linalg::map( + handle, + min_dist.view(), + [] __device__(DataT d, DataT w) { return d * w; }, + raft::make_const_mdspan(min_dist.view()), + sample_weight.value()); + } + auto device_cost = raft::make_device_scalar(handle, DataT(0)); cuvs::cluster::kmeans::cluster_cost( From d6f45247ea70838fa6f779c431853a7aa51eb8d0 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 4 Mar 2026 16:20:37 -0800 Subject: [PATCH 46/81] fix min_cluster_dist --- cpp/src/cluster/kmeans.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/cluster/kmeans.cuh b/cpp/src/cluster/kmeans.cuh index 633c8f346d..8e26372b46 100644 --- a/cpp/src/cluster/kmeans.cuh +++ b/cpp/src/cluster/kmeans.cuh @@ -457,9 +457,9 @@ void cluster_cost( if (sample_weight.has_value()) { raft::linalg::map( handle, - min_dist.view(), + min_cluster_distance.view(), [] __device__(DataT d, DataT w) { return d * w; }, - raft::make_const_mdspan(min_dist.view()), + raft::make_const_mdspan(min_cluster_distance.view()), sample_weight.value()); } From 1fa9013f1534d7da131b7ae032ef55478d74d05f Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 4 Mar 2026 18:02:37 -0800 Subject: [PATCH 47/81] update instantiations --- c/src/cluster/kmeans.cpp | 4 +- cpp/include/cuvs/cluster/kmeans.hpp | 55 ++++++++++++++++++-- cpp/src/cluster/detail/kmeans_balanced.cuh | 6 ++- cpp/src/cluster/kmeans_balanced.cuh | 11 ++-- cpp/src/cluster/kmeans_balanced_fit_float.cu | 7 +-- cpp/src/cluster/kmeans_balanced_fit_half.cu | 7 +-- cpp/src/cluster/kmeans_balanced_fit_int8.cu | 7 +-- cpp/src/cluster/kmeans_balanced_fit_uint8.cu | 7 +-- cpp/src/cluster/kmeans_cluster_cost.cu | 18 +++++++ 9 files changed, 99 insertions(+), 23 deletions(-) diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index 622f52ff44..57b6282c20 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -64,8 +64,8 @@ void _fit(cuvsResources_t res, } else { auto kmeans_params = convert_balanced_params(params); T inertia_temp; - auto inertia_view = raft::make_host_scalar_view(&inertia_temp); - cuvs::cluster::kmeans::fit(*res_ptr, kmeans_params, X_view, inertia_view); + auto inertia_view = raft::make_host_scalar_view(&inertia_temp); + cuvs::cluster::kmeans::fit(*res_ptr, kmeans_params, cuvs::core::from_dlpack(X_tensor), cuvs::core::from_dlpack(centroids_tensor), std::make_optional(inertia_view)); *inertia = inertia_temp; *n_iter = params.hierarchical_n_iters; } diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 4ede31243b..ac7a9ddba7 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -11,6 +11,8 @@ #include #include +#include + namespace cuvs::cluster::kmeans { /** Base structure for parameters that are common to all k-means algorithms */ @@ -424,7 +426,8 @@ void fit(raft::resources const& handle, void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, - raft::device_matrix_view centroids); + raft::device_matrix_view centroids, + std::optional> inertia = std::nullopt); /** * @brief Find balanced clusters with k-means algorithm. @@ -458,7 +461,8 @@ void fit(const raft::resources& handle, void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, - raft::device_matrix_view centroids); + raft::device_matrix_view centroids, + std::optional> inertia = std::nullopt); /** * @brief Find balanced clusters with k-means algorithm. @@ -492,7 +496,8 @@ void fit(const raft::resources& handle, void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, - raft::device_matrix_view centroids); + raft::device_matrix_view centroids, + std::optional> inertia = std::nullopt); /** * @brief Find balanced clusters with k-means algorithm. @@ -526,7 +531,8 @@ void fit(const raft::resources& handle, void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, - raft::device_matrix_view centroids); + raft::device_matrix_view centroids, + std::optional> inertia = std::nullopt); /** * @brief Predict the closest cluster each sample in X belongs to. @@ -1422,6 +1428,47 @@ void cluster_cost( raft::host_scalar_view cost, std::optional> sample_weight = std::nullopt); +/** + * @brief Compute (optionally weighted) cluster cost + * + * @param[in] handle The raft handle + * @param[in] X Training instances to cluster. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] centroids Cluster centroids. The data must be in + * row-major format. + * [dim = n_clusters x n_features] + * @param[out] cost Resulting cluster cost + * @param[in] sample_weight Optional per-sample weights. + * [len = n_samples] + */ +void cluster_cost( + const raft::resources& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::host_scalar_view cost, + std::optional> sample_weight = std::nullopt); + +/** + * @brief Compute (optionally weighted) cluster cost + * + * @param[in] handle The raft handle + * @param[in] X Training instances to cluster. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] centroids Cluster centroids. The data must be in + * row-major format. + * [dim = n_clusters x n_features] + * @param[out] cost Resulting cluster cost + * @param[in] sample_weight Optional per-sample weights. + * [len = n_samples] + */ +void cluster_cost( + const raft::resources& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::host_scalar_view cost, + std::optional> sample_weight = std::nullopt); /** * @} */ diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index a45493ffb9..6fd2bfe844 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -1025,7 +1025,8 @@ void build_hierarchical(const raft::resources& handle, MathT* cluster_centers, IdxT n_clusters, MappingOpT mapping_op, - const MathT* dataset_norm = nullptr) + const MathT* dataset_norm = nullptr, + MathT* inertia = nullptr) { auto stream = raft::resource::get_cuda_stream(handle); using LabelT = uint32_t; @@ -1172,7 +1173,8 @@ void build_hierarchical(const raft::resources& handle, reinterpret_cast(dataset), n_rows, dim); auto centroids_view = raft::make_device_matrix_view(cluster_centers, n_clusters, dim); - cuvs::cluster::kmeans::detail::cluster_cost(handle, X_view, centroids_view, *inertia); + cuvs::cluster::kmeans::cluster_cost( + handle, X_view, centroids_view, raft::make_host_scalar_view(inertia)); } } } diff --git a/cpp/src/cluster/kmeans_balanced.cuh b/cpp/src/cluster/kmeans_balanced.cuh index d3a85b1d94..28bb9a73dd 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -67,7 +67,8 @@ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - MappingOpT mapping_op = raft::identity_op()) + MappingOpT mapping_op = raft::identity_op(), + std::optional> inertia_out = std::nullopt) { RAFT_EXPECTS(X.extent(1) == centroids.extent(1), "Number of features in dataset and centroids are different"); @@ -78,6 +79,8 @@ void fit(const raft::resources& handle, "The number of centroids must be strictly positive and cannot exceed the number of " "points in the training dataset."); + MathT* inertia_ptr = inertia_out.has_value() ? inertia_out.value().data_handle() : nullptr; + cuvs::cluster::kmeans::detail::build_hierarchical(handle, params, X.extent(1), @@ -85,7 +88,9 @@ void fit(const raft::resources& handle, X.extent(0), centroids.data_handle(), centroids.extent(0), - mapping_op); + mapping_op, + static_cast(nullptr), + inertia_ptr); } /** diff --git a/cpp/src/cluster/kmeans_balanced_fit_float.cu b/cpp/src/cluster/kmeans_balanced_fit_float.cu index 37a6bb127c..f3ef94b7be 100644 --- a/cpp/src/cluster/kmeans_balanced_fit_float.cu +++ b/cpp/src/cluster/kmeans_balanced_fit_float.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -14,9 +14,10 @@ namespace cuvs::cluster::kmeans { void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, - raft::device_matrix_view centroids) + raft::device_matrix_view centroids, + std::optional> inertia) { cuvs::cluster::kmeans_balanced::fit( - handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}); + handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}, inertia); } } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_balanced_fit_half.cu b/cpp/src/cluster/kmeans_balanced_fit_half.cu index e554930293..7272e6087a 100644 --- a/cpp/src/cluster/kmeans_balanced_fit_half.cu +++ b/cpp/src/cluster/kmeans_balanced_fit_half.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -14,9 +14,10 @@ namespace cuvs::cluster::kmeans { void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, - raft::device_matrix_view centroids) + raft::device_matrix_view centroids, + std::optional> inertia) { cuvs::cluster::kmeans_balanced::fit( - handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}); + handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}, inertia); } } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_balanced_fit_int8.cu b/cpp/src/cluster/kmeans_balanced_fit_int8.cu index a46f7bcc94..3615c4675b 100644 --- a/cpp/src/cluster/kmeans_balanced_fit_int8.cu +++ b/cpp/src/cluster/kmeans_balanced_fit_int8.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -14,9 +14,10 @@ namespace cuvs::cluster::kmeans { void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, - raft::device_matrix_view centroids) + raft::device_matrix_view centroids, + std::optional> inertia) { cuvs::cluster::kmeans_balanced::fit( - handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}); + handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}, inertia); } } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_balanced_fit_uint8.cu b/cpp/src/cluster/kmeans_balanced_fit_uint8.cu index 8395dd107f..2a7211e48e 100644 --- a/cpp/src/cluster/kmeans_balanced_fit_uint8.cu +++ b/cpp/src/cluster/kmeans_balanced_fit_uint8.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -14,9 +14,10 @@ namespace cuvs::cluster::kmeans { void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, - raft::device_matrix_view centroids) + raft::device_matrix_view centroids, + std::optional> inertia) { cuvs::cluster::kmeans_balanced::fit( - handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}); + handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}, inertia); } } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_cluster_cost.cu b/cpp/src/cluster/kmeans_cluster_cost.cu index 3c877f4913..0cdc182fb9 100644 --- a/cpp/src/cluster/kmeans_cluster_cost.cu +++ b/cpp/src/cluster/kmeans_cluster_cost.cu @@ -25,4 +25,22 @@ void cluster_cost(const raft::resources& handle, { cuvs::cluster::kmeans::cluster_cost(handle, X, centroids, cost, sample_weight); } + +void cluster_cost(const raft::resources& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::host_scalar_view cost, + std::optional> sample_weight) +{ + cuvs::cluster::kmeans::cluster_cost(handle, X, centroids, cost, sample_weight); +} + +void cluster_cost(const raft::resources& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::host_scalar_view cost, + std::optional> sample_weight) +{ + cuvs::cluster::kmeans::cluster_cost(handle, X, centroids, cost, sample_weight); +} } // namespace cuvs::cluster::kmeans From 4ccce8314df63db36099e7d1c2732d3f5603dc0e Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 5 Mar 2026 12:40:13 -0800 Subject: [PATCH 48/81] fix all the docs --- cpp/include/cuvs/cluster/kmeans.hpp | 8 ++++++ cpp/src/cluster/detail/kmeans_balanced.cuh | 29 +++++++++++----------- cpp/src/cluster/kmeans_balanced.cuh | 6 +++-- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index ac7a9ddba7..a839cecf56 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -422,6 +422,8 @@ void fit(raft::resources const& handle, * kmeans algorithm are stored at the address * pointed by 'centroids'. * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. */ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, @@ -457,6 +459,8 @@ void fit(const raft::resources& handle, * kmeans algorithm are stored at the address * pointed by 'centroids'. * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. */ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, @@ -492,6 +496,8 @@ void fit(const raft::resources& handle, * kmeans algorithm are stored at the address * pointed by 'centroids'. * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. */ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, @@ -527,6 +533,8 @@ void fit(const raft::resources& handle, * kmeans algorithm are stored at the address * pointed by 'centroids'. * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. */ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 6fd2bfe844..cab01221e7 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -999,22 +999,23 @@ auto build_fine_clusters(const raft::resources& handle, /** * @brief Hierarchical balanced k-means * - * @tparam T element type - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type + * @tparam T element type + * @tparam MathT type of the centroids and mapped data + * @tparam IdxT index type * @tparam MappingOpT type of the mapping operation * - * @param[in] handle The raft handle. - * @param[in] params Structure containing the hyper-parameters - * @param dim number of columns in `centers` and `dataset` - * @param[in] dataset a device pointer to the source dataset [n_rows, dim] - * @param n_rows number of rows in the input - * @param[out] cluster_centers a device pointer to the found cluster centers [n_cluster, dim] - * @param n_cluster - * @param metric the distance type - * @param mapping_op Mapping operation from T to MathT - * @param stream + * @param[in] handle The raft handle. + * @param[in] params Structure containing the hyper-parameters + * @param[in] dim Number of columns in `cluster_centers` and `dataset` + * @param[in] dataset A device pointer to the source dataset [n_rows, dim] + * @param[in] n_rows Number of rows in the input + * @param[out] cluster_centers A device pointer to the found cluster centers [n_clusters, dim] + * @param[in] n_clusters Requested number of clusters + * @param[in] mapping_op Mapping operation from T to MathT + * @param[in] dataset_norm (optional) Pre-computed L2 norms of each row in the dataset [n_rows] + * @param[out] inertia (optional) If non-null, the sum of squared distances of samples to + * their closest cluster center is written here. + * Only supported when T == MathT (float/double). */ template void build_hierarchical(const raft::resources& handle, diff --git a/cpp/src/cluster/kmeans_balanced.cuh b/cpp/src/cluster/kmeans_balanced.cuh index 28bb9a73dd..f1a2f39e21 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -61,6 +61,8 @@ namespace cuvs::cluster::kmeans_balanced { * @param[out] centroids The generated centroids [dim = n_clusters x n_features] * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic * datatype. If DataT == MathT, this must be the identity. + * @param[out] inertia (optional) Sum of squared distances of samples to their + * closest cluster center. */ template void fit(const raft::resources& handle, @@ -68,7 +70,7 @@ void fit(const raft::resources& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, MappingOpT mapping_op = raft::identity_op(), - std::optional> inertia_out = std::nullopt) + std::optional> inertia = std::nullopt) { RAFT_EXPECTS(X.extent(1) == centroids.extent(1), "Number of features in dataset and centroids are different"); @@ -79,7 +81,7 @@ void fit(const raft::resources& handle, "The number of centroids must be strictly positive and cannot exceed the number of " "points in the training dataset."); - MathT* inertia_ptr = inertia_out.has_value() ? inertia_out.value().data_handle() : nullptr; + MathT* inertia_ptr = inertia.has_value() ? inertia.value().data_handle() : nullptr; cuvs::cluster::kmeans::detail::build_hierarchical(handle, params, From c3ca46b0d736690e5102d0f38faf083c93a46255 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 5 Mar 2026 14:17:13 -0800 Subject: [PATCH 49/81] style --- cpp/src/cluster/kmeans_balanced.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/cluster/kmeans_balanced.cuh b/cpp/src/cluster/kmeans_balanced.cuh index f1a2f39e21..2b250b92cf 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -69,7 +69,7 @@ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - MappingOpT mapping_op = raft::identity_op(), + MappingOpT mapping_op = raft::identity_op(), std::optional> inertia = std::nullopt) { RAFT_EXPECTS(X.extent(1) == centroids.extent(1), From 7d6bed8cefef13c68e25eb0e464e1f8dca0a0e62 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 5 Mar 2026 14:41:03 -0800 Subject: [PATCH 50/81] rm compute_inertia --- cpp/src/cluster/detail/kmeans.cuh | 15 +++-- cpp/src/cluster/detail/kmeans_batched.cuh | 9 ++- cpp/src/cluster/detail/kmeans_common.cuh | 74 ----------------------- 3 files changed, 14 insertions(+), 84 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index f46a9beec7..cd2f89e3f5 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -443,14 +443,13 @@ void kmeans_fit_main(raft::resources const& handle, } } - inertia[0] = - compute_inertia(handle, - params, - X, - raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features), - workspace, - weight); + cuvs::cluster::kmeans::cluster_cost( + handle, + X, + raft::make_device_matrix_view( + centroidsRawData.data_handle(), n_clusters, n_features), + inertia, + std::make_optional(weight)); RAFT_LOG_DEBUG("KMeans.fit: completed after %d iterations with %f inertia[0] ", n_iter[0] > params.max_iter ? n_iter[0] - 1 : n_iter[0], diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 2c23cee91d..c56cdd583d 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -693,8 +693,13 @@ void fit(raft::resources const& handle, auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); - T valid_inertia = compute_inertia( - handle, iter_params, X_valid_const, centroids_const, workspace, valid_weight_view); + T valid_inertia; + cuvs::cluster::kmeans::cluster_cost( + handle, + X_valid_const, + centroids_const, + raft::make_host_scalar_view(&valid_inertia), + valid_weight_view); RAFT_LOG_DEBUG("KMeans minibatch: n_init %d/%d validation inertia=%f", seed_iter + 1, diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 5ae29d6bb2..433fa2d496 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -602,78 +602,4 @@ DataT compute_centroid_shift(raft::resources const& handle, return result; } -/** - * @brief Compute (optionally weighted) inertia for device-resident data. - * - * Computes the sum of (optionally weighted) squared distances from each sample - * to its nearest centroid. Used for final inertia reporting after training and - * for scoring validation sets during n_init selection. - */ -template -DataT compute_inertia( - raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace, - std::optional> sample_weight = std::nullopt) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto metric = params.metric; - - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); - auto L2NormX = raft::make_device_vector(handle, n_samples); - rmm::device_uvector L2NormBuf(0, stream); - rmm::device_scalar cost(stream); - - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormX.data_handle(), X.data_handle(), n_features, n_samples, stream); - } - - auto mcd_view = minClusterAndDistance.view(); - - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( - handle, - X, - centroids, - mcd_view, - raft::make_device_vector_view(L2NormX.data_handle(), n_samples), - L2NormBuf, - metric, - params.batch_samples, - params.batch_centroids, - workspace); - - if (sample_weight) { - raft::linalg::map( - handle, - mcd_view, - [=] __device__(const raft::KeyValuePair kvp, DataT wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }, - raft::make_const_mdspan(mcd_view), - *sample_weight); - } - - cuvs::cluster::kmeans::detail::computeClusterCost(handle, - mcd_view, - workspace, - raft::make_device_scalar_view(cost.data()), - raft::value_op{}, - raft::add_op{}); - - DataT result = 0; - raft::copy(&result, cost.data(), 1, stream); - raft::resource::sync_stream(handle, stream); - return result; -} - } // namespace cuvs::cluster::kmeans::detail From 6d072f61f7877f39adbdc2da0517d99a9c2cd151 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 5 Mar 2026 14:57:37 -0800 Subject: [PATCH 51/81] fix compute_batched_host_inertia --- cpp/src/cluster/detail/kmeans_batched.cuh | 64 +++++------------------ 1 file changed, 12 insertions(+), 52 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index c56cdd583d..402cb317bc 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -458,24 +458,16 @@ void minibatch_update_centroids(raft::resources const& handle, template T compute_batched_host_inertia( raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, raft::host_matrix_view X, IdxT batch_size, raft::device_matrix_view centroids, - rmm::device_uvector& workspace, std::optional> sample_weight = std::nullopt) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); auto n_features = X.extent(1); - auto metric = params.metric; IdxT effective_batch = std::min(batch_size, static_cast(n_samples)); - auto minClusterAndDistance = - raft::make_device_vector, IdxT>(handle, effective_batch); - auto L2NormBatch = raft::make_device_vector(handle, effective_batch); - rmm::device_uvector L2NormBuf(0, stream); - rmm::device_scalar cost(stream); // Device buffer for per-batch weights (only used when sample_weight is provided) auto batch_weights = @@ -489,59 +481,27 @@ T compute_batched_host_inertia( IdxT current_batch_size = static_cast(data_batch.size()); auto batch_view = raft::make_device_matrix_view( data_batch.data(), current_batch_size, n_features); - auto mcd_view = raft::make_device_vector_view, IdxT>( - minClusterAndDistance.data_handle(), current_batch_size); - - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormBatch.data_handle(), data_batch.data(), n_features, current_batch_size, stream); - } - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( - handle, - batch_view, - centroids, - mcd_view, - raft::make_device_vector_view(L2NormBatch.data_handle(), current_batch_size), - L2NormBuf, - metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // Apply sample weights to distances before summing (matching sklearn weighted inertia) + // Build optional device weight view for this batch + std::optional> batch_weight_view; if (sample_weight) { auto weight_offset = static_cast(data_batch.offset()); raft::copy(batch_weights.data_handle(), sample_weight->data_handle() + weight_offset, current_batch_size, stream); - - raft::linalg::map( - handle, - mcd_view, - [=] __device__(const raft::KeyValuePair kvp, T wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }, - raft::make_const_mdspan(mcd_view), - raft::make_device_vector_view(batch_weights.data_handle(), - current_batch_size)); + batch_weight_view = raft::make_device_vector_view( + batch_weights.data_handle(), current_batch_size); } - cuvs::cluster::kmeans::detail::computeClusterCost(handle, - mcd_view, - workspace, - raft::make_device_scalar_view(cost.data()), - raft::value_op{}, - raft::add_op{}); + T batch_cost; + cuvs::cluster::kmeans::cluster_cost( + handle, + batch_view, + centroids, + raft::make_host_scalar_view(&batch_cost), + batch_weight_view); - T batch_cost = 0; - raft::copy(&batch_cost, cost.data(), 1, stream); - raft::resource::sync_stream(handle, stream); total_inertia += batch_cost; } @@ -1045,7 +1005,7 @@ void fit(raft::resources const& handle, auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); inertia[0] = compute_batched_host_inertia( - handle, iter_params, X, batch_size, centroids_const, workspace, sample_weight); + handle, X, batch_size, centroids_const, sample_weight); RAFT_LOG_DEBUG("KMeans batched: n_init %d/%d completed with inertia=%f", seed_iter + 1, From 30f5ac40e2aba3837fcf1b987991789eb5e00e1e Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 5 Mar 2026 15:25:49 -0800 Subject: [PATCH 52/81] fix style --- cpp/src/cluster/detail/kmeans.cuh | 13 +++++----- cpp/src/cluster/detail/kmeans_batched.cuh | 31 ++++++++++------------- 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index cd2f89e3f5..7d876f2e05 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -443,13 +443,12 @@ void kmeans_fit_main(raft::resources const& handle, } } - cuvs::cluster::kmeans::cluster_cost( - handle, - X, - raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features), - inertia, - std::make_optional(weight)); + cuvs::cluster::kmeans::cluster_cost(handle, + X, + raft::make_device_matrix_view( + centroidsRawData.data_handle(), n_clusters, n_features), + inertia, + std::make_optional(weight)); RAFT_LOG_DEBUG("KMeans.fit: completed after %d iterations with %f inertia[0] ", n_iter[0] > params.max_iter ? n_iter[0] - 1 : n_iter[0], diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 402cb317bc..5fab5ad2df 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -490,17 +490,16 @@ T compute_batched_host_inertia( sample_weight->data_handle() + weight_offset, current_batch_size, stream); - batch_weight_view = raft::make_device_vector_view( - batch_weights.data_handle(), current_batch_size); + batch_weight_view = raft::make_device_vector_view(batch_weights.data_handle(), + current_batch_size); } T batch_cost; - cuvs::cluster::kmeans::cluster_cost( - handle, - batch_view, - centroids, - raft::make_host_scalar_view(&batch_cost), - batch_weight_view); + cuvs::cluster::kmeans::cluster_cost(handle, + batch_view, + centroids, + raft::make_host_scalar_view(&batch_cost), + batch_weight_view); total_inertia += batch_cost; } @@ -654,12 +653,11 @@ void fit(raft::resources const& handle, auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); T valid_inertia; - cuvs::cluster::kmeans::cluster_cost( - handle, - X_valid_const, - centroids_const, - raft::make_host_scalar_view(&valid_inertia), - valid_weight_view); + cuvs::cluster::kmeans::cluster_cost(handle, + X_valid_const, + centroids_const, + raft::make_host_scalar_view(&valid_inertia), + valid_weight_view); RAFT_LOG_DEBUG("KMeans minibatch: n_init %d/%d validation inertia=%f", seed_iter + 1, @@ -833,9 +831,8 @@ void fit(raft::resources const& handle, // Skip first step (inertia from initialization) if (n_iter[0] > 1) { // Update Exponentially Weighted Average of inertia. - T alpha = static_cast(current_batch_size * 2.0) / - static_cast(n_samples + 1); - alpha = std::min(alpha, T{1}); + T alpha = static_cast(current_batch_size * 2.0) / static_cast(n_samples + 1); + alpha = std::min(alpha, T{1}); if (!ewa_initialized) { ewa_inertia = batch_inertia; From aa9a9e7f12653c61c6ddd27ff8b50574f3cd5666 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 5 Mar 2026 16:16:50 -0800 Subject: [PATCH 53/81] rm minibatch --- c/include/cuvs/cluster/kmeans.h | 37 - c/src/cluster/kmeans.cpp | 7 - cpp/include/cuvs/cluster/kmeans.hpp | 68 +- .../cuvs/neighbors/graph_build_types.hpp | 163 +++++ cpp/src/cluster/detail/kmeans_batched.cuh | 668 ++---------------- cpp/src/cluster/detail/kmeans_common.cuh | 2 +- cpp/tests/cluster/kmeans.cu | 80 +-- python/cuvs/cuvs/cluster/kmeans/kmeans.pxd | 7 - python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 43 -- python/cuvs/cuvs/tests/test_kmeans.py | 62 -- 10 files changed, 264 insertions(+), 873 deletions(-) create mode 100644 cpp/include/cuvs/neighbors/graph_build_types.hpp diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index f01395aaa4..d7b1bb9a4c 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -36,21 +36,6 @@ typedef enum { Array = 2 } cuvsKMeansInitMethod; -/** - * @brief Centroid update mode for k-means algorithm - */ -typedef enum { - /** - * Standard k-means (Lloyd's algorithm): accumulate assignments over the - * entire dataset, then update centroids once per iteration. - */ - CUVS_KMEANS_UPDATE_FULL_BATCH = 0, - - /** - * Mini-batch k-means: update centroids after each randomly sampled batch. - */ - CUVS_KMEANS_UPDATE_MINI_BATCH = 1 -} cuvsKMeansCentroidUpdateMode; /** * @brief Hyper-parameters for the kmeans algorithm @@ -106,13 +91,6 @@ struct cuvsKMeansParams { */ int batch_centroids; - /** - * Centroid update mode: - * - CUVS_KMEANS_UPDATE_FULL_BATCH: Standard Lloyd's algorithm, update after full dataset pass - * - CUVS_KMEANS_UPDATE_MINI_BATCH: Mini-batch k-means, update after each batch - */ - cuvsKMeansCentroidUpdateMode update_mode; - /** Check inertia during iterations for early convergence. */ bool inertia_check; @@ -122,21 +100,6 @@ struct cuvsKMeansParams { */ bool final_inertia_check; - /** - * Maximum number of consecutive mini-batch steps without improvement in smoothed inertia - * before early stopping. Only used when update_mode is CUVS_KMEANS_UPDATE_MINI_BATCH. - * If 0, this convergence criterion is disabled. - */ - int max_no_improvement; - - /** - * Control the fraction of the maximum number of counts for a center to be reassigned. - * Centers with count < reassignment_ratio * max(counts) are randomly reassigned to - * observations from the current batch. Only used when update_mode is CUVS_KMEANS_UPDATE_MINI_BATCH. - * If 0.0, reassignment is disabled. Default: 0.01 - */ - double reassignment_ratio; - /** * Whether to use hierarchical (balanced) kmeans or not */ diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index 22d65ceefc..9c976f1c28 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -29,10 +29,6 @@ cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params) kmeans_params.batch_centroids = params.batch_centroids; kmeans_params.inertia_check = params.inertia_check; kmeans_params.batched.final_inertia_check = params.final_inertia_check; - kmeans_params.batched.minibatch.max_no_improvement = params.max_no_improvement; - kmeans_params.batched.minibatch.reassignment_ratio = params.reassignment_ratio; - kmeans_params.batched.update_mode = - static_cast(params.update_mode); return kmeans_params; } @@ -257,11 +253,8 @@ extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params) .oversampling_factor = cpp_params.oversampling_factor, .batch_samples = cpp_params.batch_samples, .batch_centroids = cpp_params.batch_centroids, - .update_mode = static_cast(cpp_params.batched.update_mode), .inertia_check = cpp_params.inertia_check, .final_inertia_check = cpp_params.batched.final_inertia_check, - .max_no_improvement = cpp_params.batched.minibatch.max_no_improvement, - .reassignment_ratio = cpp_params.batched.minibatch.reassignment_ratio, .hierarchical = false, .hierarchical_n_iters = static_cast(cpp_balanced_params.n_iters)}; }); diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 1d20608068..00836d1002 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -50,24 +50,6 @@ struct params : base_params { Array }; - /** - * Centroid update mode determines when centroids are updated during training. - * This is primarily used with fit_batched() for out-of-core / host data processing. - */ - enum CentroidUpdateMode { - /** - * Standard k-means (Lloyd's algorithm): accumulate partial sums over the - * entire dataset, then update centroids once per iteration. - */ - FullBatch, - - /** - * Mini-batch k-means: update centroids incrementally after each randomly - * sampled batch using an online learning rule. Converges faster but may - * produce slightly different results each run. - */ - MiniBatch - }; /** * The number of clusters to form as well as the number of centroids to generate (default:8). @@ -105,9 +87,7 @@ struct params : base_params { raft::random::RngState rng_state{0}; /** - * Number of instance k-means algorithm will be run with different seeds. For MiniBatch mode, - * this is the number of different initializations to try, but the algorithm is only run once with - * the best initialization. + * Number of instance k-means algorithm will be run with different seeds. */ int n_init = 1; @@ -140,16 +120,6 @@ struct params : base_params { * These parameters are only used when calling fit_batched() and are ignored by regular fit(). */ struct batched_params { - /** - * Centroid update mode for fit_batched(): - * - FullBatch (default): Standard Lloyd's algorithm. Accumulate partial sums - * across all batches, update centroids once per iteration. Deterministic and - * mathematically equivalent to standard k-means. - * - MiniBatch: Online mini-batch k-means. Update centroids incrementally after - * each randomly sampled batch. Faster convergence but non-deterministic. - */ - CentroidUpdateMode update_mode = FullBatch; - /** * If true, compute the final inertia after fit_batched completes. This requires an additional * full pass over all the host data, which can be expensive for large datasets. @@ -157,32 +127,6 @@ struct params : base_params { * Default: false (skip final inertia computation for performance). */ bool final_inertia_check = false; - - /** - * Parameters specific to mini-batch k-means mode. - * These parameters are only used when update_mode is MiniBatch. - */ - struct minibatch_params { - /** - * Maximum number of consecutive mini-batch steps without improvement in smoothed inertia - * before early stopping. Only used when update_mode is MiniBatch. - * If 0, this convergence criterion is disabled. - * Default: 10 - */ - int max_no_improvement = 10; - - /** - * Control the fraction of the maximum number of counts for a center to be reassigned. - * Centers with count < reassignment_ratio * max(counts) are randomly reassigned to - * observations from the current batch. A higher value means that low count centers are - * more likely to be reassigned, which means that the model will take longer to converge, - * but should converge in a better clustering. - * Only used when update_mode is MiniBatch. - * If 0.0, reassignment is disabled. - * Default: 0.01 (matching scikit-learn) - */ - double reassignment_ratio = 0.01; - } minibatch; } batched; }; @@ -235,7 +179,6 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * raft::resources handle; * cuvs::cluster::kmeans::params params; * params.n_clusters = 100; - * // params.batched.update_mode = kmeans::params::MiniBatch; // for mini-batch mode * int n_features = 15; * float inertia; * int n_iter; @@ -258,8 +201,7 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * @endcode * * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. Use params.batched.update_mode - * to select FullBatch or MiniBatch mode. + * @param[in] params Parameters for KMeans model. * @param[in] X Training instances on HOST memory. The data must * be in row-major format. * [dim = n_samples x n_features] @@ -289,7 +231,7 @@ void fit_batched(raft::resources const& handle, * @brief Find clusters with k-means algorithm using batched processing. * * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model (including update_mode). + * @param[in] params Parameters for KMeans model. * @param[in] X Training instances on HOST memory. * [dim = n_samples x n_features] * @param[in] batch_size Number of samples to process per batch. @@ -312,7 +254,7 @@ void fit_batched(raft::resources const& handle, * @brief Find clusters with k-means algorithm using batched processing. * * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model (including update_mode). + * @param[in] params Parameters for KMeans model. * @param[in] X Training instances on HOST memory. * [dim = n_samples x n_features] * @param[in] batch_size Number of samples to process per batch. @@ -335,7 +277,7 @@ void fit_batched(raft::resources const& handle, * @brief Find clusters with k-means algorithm using batched processing. * * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model (including update_mode). + * @param[in] params Parameters for KMeans model. * @param[in] X Training instances on HOST memory. * [dim = n_samples x n_features] * @param[in] batch_size Number of samples to process per batch. diff --git a/cpp/include/cuvs/neighbors/graph_build_types.hpp b/cpp/include/cuvs/neighbors/graph_build_types.hpp new file mode 100644 index 0000000000..5faa08c1d1 --- /dev/null +++ b/cpp/include/cuvs/neighbors/graph_build_types.hpp @@ -0,0 +1,163 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include + +namespace cuvs::neighbors { + +/** + * @defgroup neighbors_build_algo Graph build algorithm types + * @{ + */ + +enum GRAPH_BUILD_ALGO { BRUTE_FORCE = 0, IVF_PQ = 1, NN_DESCENT = 2, ACE = 3 }; + +namespace graph_build_params { + +/** Specialized parameters utilizing IVF-PQ to build knn graph */ +struct ivf_pq_params { + cuvs::neighbors::ivf_pq::index_params build_params; + cuvs::neighbors::ivf_pq::search_params search_params; + float refinement_rate = 1.0; + + ivf_pq_params() = default; + + /** + * Set default parameters based on shape of the input dataset. + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * raft::resources res; + * // create index_params for a [N. D] dataset + * auto dataset = raft::make_device_matrix(res, N, D); + * auto pq_params = + * graph_build_params::ivf_pq_params(dataset.extents()); + * // modify/update index_params as needed + * pq_params.kmeans_trainset_fraction = 0.1; + * @endcode + */ + ivf_pq_params(raft::matrix_extent dataset_extents, + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded) + { + build_params = ivf_pq::index_params::from_dataset(dataset_extents, metric); + auto n_rows = dataset_extents.extent(0); + auto n_features = dataset_extents.extent(1); + if (n_features <= 32) { + build_params.pq_dim = 16; + build_params.pq_bits = 8; + } else { + build_params.pq_bits = 4; + if (n_features <= 64) { + build_params.pq_dim = 32; + } else if (n_features <= 128) { + build_params.pq_dim = 64; + } else if (n_features <= 192) { + build_params.pq_dim = 96; + } else { + build_params.pq_dim = raft::round_up_safe(n_features / 2, 128); + } + } + + build_params.n_lists = std::max(1, n_rows / 2000); + build_params.kmeans_n_iters = 10; + + const double kMinPointsPerCluster = 32; + const double min_kmeans_trainset_points = kMinPointsPerCluster * build_params.n_lists; + const double max_kmeans_trainset_fraction = 1.0; + const double min_kmeans_trainset_fraction = + std::min(max_kmeans_trainset_fraction, min_kmeans_trainset_points / n_rows); + build_params.kmeans_trainset_fraction = std::clamp( + 1.0 / std::sqrt(n_rows * 1e-5), min_kmeans_trainset_fraction, max_kmeans_trainset_fraction); + build_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; + + search_params = cuvs::neighbors::ivf_pq::search_params{}; + search_params.n_probes = std::round(std::sqrt(build_params.n_lists) / 20 + 4); + search_params.lut_dtype = CUDA_R_16F; + search_params.internal_distance_dtype = CUDA_R_16F; + search_params.coarse_search_dtype = CUDA_R_16F; + search_params.max_internal_batch_size = 128 * 1024; + + refinement_rate = 1; + } +}; + +using nn_descent_params = cuvs::neighbors::nn_descent::index_params; + +struct brute_force_params { + cuvs::neighbors::brute_force::index_params build_params; + cuvs::neighbors::brute_force::search_params search_params; +}; + +/** Specialized parameters for ACE (Augmented Core Extraction) graph build */ +struct ace_params { + /** + * Number of partitions for ACE (Augmented Core Extraction) partitioned build. + * + * When set to 0 (default), the number of partitions is automatically derived + * based on available host and GPU memory to maximize partition size while + * ensuring the build fits in memory. + * + * Small values might improve recall but potentially degrade performance and + * increase memory usage. Partitions should not be too small to prevent issues + * in KNN graph construction. The partition size is on average 2 * (n_rows / npartitions) * dim * + * sizeof(T). 2 is because of the core and augmented vectors. Please account for imbalance in the + * partition sizes (up to 3x in our tests). + * + * If the specified number of partitions results in partitions that exceed + * available memory, the value will be automatically increased to fit memory + * constraints and a warning will be issued. + */ + size_t npartitions = 0; + /** + * The index quality for the ACE build. + * + * Bigger values increase the index quality. At some point, increasing this will no longer improve + * the quality. + */ + size_t ef_construction = 120; + /** + * Directory to store ACE build artifacts (e.g., KNN graph, optimized graph). + * + * Used when `use_disk` is true or when the graph does not fit in host and GPU + * memory. This should be the fastest disk in the system and hold enough space + * for twice the dataset, final graph, and label mapping. + */ + std::string build_dir = "/tmp/ace_build"; + /** + * Whether to use disk-based storage for ACE build. + * + * When true, enables disk-based operations for memory-efficient graph construction. + */ + bool use_disk = false; + /** + * Maximum host memory to use for ACE build in GiB. + * + * When set to 0 (default), uses available host memory. + * When set to a positive value, limits host memory usage to the specified amount. + * Useful for testing or when running alongside other memory-intensive processes. + */ + double max_host_memory_gb = 0; + /** + * Maximum GPU memory to use for ACE build in GiB. + * + * When set to 0 (default), uses available GPU memory. + * When set to a positive value, limits GPU memory usage to the specified amount. + * Useful for testing or when running alongside other memory-intensive processes. + */ + double max_gpu_memory_gb = 0; + + ace_params() = default; +}; + +// **** Experimental **** +using iterative_search_params = cuvs::neighbors::search_params; +} // namespace graph_build_params + +/** @} */ // end group neighbors_build_algo +} // namespace cuvs::neighbors diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 5fab5ad2df..3b9cd7fd4b 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -23,21 +23,12 @@ #include #include #include -#include -#include #include #include #include #include -#include -#include -#include -#include -#include -#include - #include #include #include @@ -205,246 +196,6 @@ void accumulate_batch_centroids( stream); } -/** - * @brief Update centroids using mini-batch online learning - * - * Updates centroids using the following formula (matching scikit-learn's implementation): - * - * centroid_new[k] = (centroid_old[k] * old_total_counts[k] + batch_sums[k]) / total_counts[k] - * - * This is equivalent to the learning rate formula: - * learning_rate[k] = batch_counts[k] / total_counts[k] - * centroid[k] = centroid[k] * (1 - learning_rate[k]) + batch_means[k] * learning_rate[k] - * - * Optionally reassigns low-count clusters to random samples from the current batch. - */ -template -void minibatch_update_centroids(raft::resources const& handle, - raft::device_matrix_view centroids, - raft::device_matrix_view batch_sums, - raft::device_vector_view batch_counts, - raft::device_vector_view total_counts, - raft::device_matrix_view batch_data, - double reassignment_ratio, - IdxT current_batch_size, - std::mt19937& rng) -{ - auto n_clusters = centroids.extent(0); - auto n_features = centroids.extent(1); - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - raft::linalg::matrix_vector_op(handle, - raft::make_const_mdspan(centroids), - raft::make_const_mdspan(total_counts), - centroids, - raft::mul_op{}); - - raft::linalg::add( - handle, raft::make_const_mdspan(centroids), raft::make_const_mdspan(batch_sums), centroids); - - raft::linalg::add(handle, raft::make_const_mdspan(total_counts), batch_counts, total_counts); - - raft::linalg::matrix_vector_op(handle, - raft::make_const_mdspan(centroids), - raft::make_const_mdspan(total_counts), - centroids, - raft::div_checkzero_op{}); - - // Reassignment logic: reassign low-count clusters to random samples from current batch - if (reassignment_ratio > 0.0) { - auto max_count_scalar = raft::make_device_scalar(handle, MathT{0}); - size_t temp_storage_bytes = 0; - cub::DeviceReduce::Max(nullptr, - temp_storage_bytes, - total_counts.data_handle(), - max_count_scalar.data_handle(), - n_clusters, - stream); - rmm::device_uvector temp_storage(temp_storage_bytes, stream); - cub::DeviceReduce::Max(temp_storage.data(), - temp_storage_bytes, - total_counts.data_handle(), - max_count_scalar.data_handle(), - n_clusters, - stream); - auto max_count_host = raft::make_host_scalar(0); - raft::copy(max_count_host.data_handle(), max_count_scalar.data_handle(), 1, stream); - raft::resource::sync_stream(handle, stream); - MathT max_count = max_count_host.data_handle()[0]; - - MathT threshold = static_cast(reassignment_ratio) * max_count; - auto reassign_flags = raft::make_device_vector(handle, n_clusters); - - raft::linalg::unaryOp( - reassign_flags.data_handle(), - total_counts.data_handle(), - n_clusters, - [=] __device__(MathT count) { - return (count < threshold || count == MathT{0}) ? uint8_t{1} : uint8_t{0}; - }, - stream); - - auto num_reassign_scalar = raft::make_device_scalar(handle, IdxT{0}); - raft::linalg::mapThenSumReduce(num_reassign_scalar.data_handle(), - n_clusters, - raft::identity_op{}, - stream, - reassign_flags.data_handle()); - auto num_reassign_host = raft::make_host_scalar(0); - raft::copy(num_reassign_host.data_handle(), num_reassign_scalar.data_handle(), 1, stream); - raft::resource::sync_stream(handle, stream); - IdxT num_to_reassign = num_reassign_host.data_handle()[0]; - - // Limit to 50% of batch size - IdxT max_reassign = static_cast(0.5 * current_batch_size); - if (num_to_reassign > max_reassign) { - // Need to select only the worst ones - do sorting on device - // First, get all cluster indices that need reassignment - auto all_reassign_indices = raft::make_device_vector(handle, num_to_reassign); - auto counting_iter = thrust::counting_iterator(0); - thrust::device_ptr flags_ptr(reassign_flags.data_handle()); - - thrust::copy_if(raft::resource::get_thrust_policy(handle), - counting_iter, - counting_iter + n_clusters, - flags_ptr, - thrust::device_pointer_cast(all_reassign_indices.data_handle()), - [] __device__(uint8_t flag) { return flag == 1; }); - - auto reassign_counts = raft::make_device_vector(handle, num_to_reassign); - auto total_counts_matrix_view = - raft::make_device_matrix_view(total_counts.data_handle(), n_clusters, 1); - auto reassign_indices_view = raft::make_device_vector_view( - all_reassign_indices.data_handle(), num_to_reassign); - auto reassign_counts_matrix_view = raft::make_device_matrix_view( - reassign_counts.data_handle(), num_to_reassign, 1); - raft::matrix::gather( - handle, total_counts_matrix_view, reassign_indices_view, reassign_counts_matrix_view); - - thrust::sort_by_key(raft::resource::get_thrust_policy(handle), - reassign_counts.data_handle(), - reassign_counts.data_handle() + num_to_reassign, - all_reassign_indices.data_handle()); - - raft::matrix::fill(handle, reassign_flags.view(), uint8_t{0}); - - // Set flags only for worst max_reassign clusters - auto worst_indices = raft::make_device_vector(handle, max_reassign); - raft::copy( - worst_indices.data_handle(), all_reassign_indices.data_handle(), max_reassign, stream); - - auto flags_scatter = raft::make_device_vector(handle, max_reassign); - raft::matrix::fill(handle, flags_scatter.view(), uint8_t{1}); - thrust::scatter(raft::resource::get_thrust_policy(handle), - flags_scatter.data_handle(), - flags_scatter.data_handle() + max_reassign, - worst_indices.data_handle(), - reassign_flags.data_handle()); - - num_to_reassign = max_reassign; - } - - if (num_to_reassign > 0) { - // Get list of cluster indices to reassign - auto reassign_indices = raft::make_device_vector(handle, num_to_reassign); - auto counting_iter = thrust::counting_iterator(0); - thrust::device_ptr flags_ptr(reassign_flags.data_handle()); - - thrust::copy_if(raft::resource::get_thrust_policy(handle), - counting_iter, - counting_iter + n_clusters, - flags_ptr, - thrust::device_pointer_cast(reassign_indices.data_handle()), - [] __device__(uint8_t flag) { return flag == 1; }); - - auto actual_count_scalar = raft::make_device_scalar(handle, IdxT{0}); - raft::linalg::mapThenSumReduce(actual_count_scalar.data_handle(), - n_clusters, - raft::identity_op{}, - stream, - reassign_flags.data_handle()); - auto actual_count_host = raft::make_host_scalar(0); - raft::copy(actual_count_host.data_handle(), actual_count_scalar.data_handle(), 1, stream); - raft::resource::sync_stream(handle, stream); - num_to_reassign = actual_count_host.data_handle()[0]; - - auto reassign_indices_host = raft::make_host_vector(num_to_reassign); - raft::copy(reassign_indices_host.data_handle(), - reassign_indices.data_handle(), - num_to_reassign, - stream); - raft::resource::sync_stream(handle, stream); - - // Pick random samples from current batch (without replacement) on host - std::uniform_int_distribution batch_dist(0, current_batch_size - 1); - std::unordered_set selected_indices; - selected_indices.reserve(num_to_reassign); - - while (static_cast(selected_indices.size()) < num_to_reassign) { - IdxT idx = batch_dist(rng); - selected_indices.insert(idx); - } - - std::vector new_center_indices(selected_indices.begin(), selected_indices.end()); - - for (IdxT i = 0; i < num_to_reassign; ++i) { - IdxT cluster_idx = reassign_indices_host.data_handle()[i]; - IdxT sample_idx = new_center_indices[i]; - raft::copy(centroids.data_handle() + cluster_idx * n_features, - batch_data.data_handle() + sample_idx * n_features, - n_features, - stream); - } - - // Reset total_counts for reassigned clusters to min of non-reassigned clusters. Note that - // this will affect the learning rate directly. - auto masked_counts = raft::make_device_vector(handle, n_clusters); - auto total_counts_ptr = total_counts.data_handle(); - auto reassign_flags_ptr = reassign_flags.data_handle(); - raft::linalg::map_offset(handle, masked_counts.view(), [=] __device__(IdxT k) { - if (reassign_flags_ptr[k] == 0 && total_counts_ptr[k] > MathT{0}) { - return total_counts_ptr[k]; - } - return max_count; - }); - - auto min_non_reassigned_scalar = raft::make_device_scalar(handle, max_count); - size_t min_temp_storage_bytes = 0; - cub::DeviceReduce::Min(nullptr, - min_temp_storage_bytes, - masked_counts.data_handle(), - min_non_reassigned_scalar.data_handle(), - n_clusters, - stream); - rmm::device_uvector min_temp_storage(min_temp_storage_bytes, stream); - cub::DeviceReduce::Min(min_temp_storage.data(), - min_temp_storage_bytes, - masked_counts.data_handle(), - min_non_reassigned_scalar.data_handle(), - n_clusters, - stream); - auto min_non_reassigned_host = raft::make_host_scalar(0); - raft::copy( - min_non_reassigned_host.data_handle(), min_non_reassigned_scalar.data_handle(), 1, stream); - raft::resource::sync_stream(handle, stream); - MathT min_non_reassigned = min_non_reassigned_host.data_handle()[0]; - if (min_non_reassigned == max_count) { - min_non_reassigned = MathT{1}; // Fallback if all clusters were reassigned - } - - // Update total_counts on device for reassigned clusters - // reassign_indices_host is already available from earlier - for (IdxT i = 0; i < num_to_reassign; ++i) { - IdxT cluster_idx = reassign_indices_host.data_handle()[i]; - raft::copy(total_counts.data_handle() + cluster_idx, &min_non_reassigned, 1, stream); - } - - RAFT_LOG_DEBUG("KMeans minibatch: Reassigned %zu cluster centers", - static_cast(num_to_reassign)); - } - } -} - /** * @brief Compute total inertia over host data using batched GPU processing. * @@ -508,7 +259,10 @@ T compute_batched_host_inertia( } /** - * @brief Main fit function for batched k-means with host data + * @brief Main fit function for batched k-means with host data (full-batch / Lloyd's algorithm). + * + * Processes host data in GPU-sized batches per iteration, accumulating partial centroid + * sums and counts, then finalizes centroids at the end of each iteration. * * @tparam T Input data type (float, double) * @tparam IdxT Index type (int, int64_t) @@ -516,15 +270,11 @@ T compute_batched_host_inertia( * @param[in] handle RAFT resources handle * @param[in] params K-means parameters * @param[in] X Input data on HOST [n_samples x n_features] - * @param[in] batch_size Number of samples to process per batch + * @param[in] batch_size Number of samples to process per GPU batch * @param[in] sample_weight Optional weights per sample (on host) * @param[inout] centroids Initial/output cluster centers [n_clusters x n_features] * @param[out] inertia Sum of squared distances to nearest centroid * @param[out] n_iter Number of iterations run - * - * @note For mini-batch mode: When sample weights are provided, they are used as sampling - * probabilities (normalized) to select minibatch samples. Unit weights are then passed - * to the centroid update to avoid double weighting (matching scikit-learn's approach). */ template void fit(raft::resources const& handle, @@ -558,10 +308,6 @@ void fit(raft::resources const& handle, rmm::device_uvector workspace(0, stream); - bool use_minibatch = - (params.batched.update_mode == cuvs::cluster::kmeans::params::CentroidUpdateMode::MiniBatch); - RAFT_LOG_DEBUG("KMeans batched: update_mode=%s", use_minibatch ? "MiniBatch" : "FullBatch"); - auto n_init = params.n_init; if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array && n_init != 1) { RAFT_LOG_DEBUG( @@ -592,97 +338,7 @@ void fit(raft::resources const& handle, rmm::device_scalar clusterCostD(stream); - // Mini-batch only buffers - auto total_counts = raft::make_device_vector(handle, use_minibatch ? n_clusters : 0); - auto host_batch_buffer = use_minibatch ? raft::make_host_matrix(batch_size, n_features) - : raft::make_host_matrix(0, n_features); - auto batch_indices = use_minibatch ? raft::make_host_vector(batch_size) - : raft::make_host_vector(0); - auto prev_centroids = raft::make_device_matrix(handle, n_clusters, n_features); - - // Weighted sampling (shared across n_init, weights are constant) - std::discrete_distribution weighted_dist; - bool use_weighted_sampling = false; - if (use_minibatch && sample_weight) { - std::vector weights(sample_weight->data_handle(), - sample_weight->data_handle() + n_samples); - weighted_dist = std::discrete_distribution(weights.begin(), weights.end()); - use_weighted_sampling = true; - } - - // n_init only selects the best *initialization using a validation set. The full training loop - // runs once with the best init. - bool minibatch_init_done = false; - if (use_minibatch && n_init > 1) { - size_t valid_size = - std::min(static_cast(n_samples), - std::max(static_cast(3 * n_clusters), static_cast(10000))); - RAFT_LOG_DEBUG("KMeans minibatch: creating validation set of %zu samples for n_init selection", - valid_size); - - auto X_valid = raft::make_device_matrix(handle, valid_size, n_features); - auto valid_weights = raft::make_device_vector(handle, sample_weight ? valid_size : 0); - std::optional> valid_weight_view; - - if (sample_weight) { - prepare_init_sample(handle, - X, - X_valid.view(), - gen(), - std::optional>{*sample_weight}, - std::optional>{valid_weights.view()}); - valid_weight_view = raft::make_device_vector_view( - valid_weights.data_handle(), static_cast(valid_size)); - } else { - prepare_init_sample(handle, X, X_valid.view(), gen()); - } - - auto X_valid_const = raft::make_const_mdspan(X_valid.view()); - - for (int seed_iter = 0; seed_iter < n_init; ++seed_iter) { - cuvs::cluster::kmeans::params iter_params = params; - iter_params.rng_state.seed = gen(); - - RAFT_LOG_DEBUG("KMeans minibatch: n_init %d/%d init selection (seed=%llu)", - seed_iter + 1, - n_init, - (unsigned long long)iter_params.rng_state.seed); - - init_centroids_from_host_sample(handle, iter_params, batch_size, X, centroids, workspace); - - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), n_clusters, n_features); - T valid_inertia; - cuvs::cluster::kmeans::cluster_cost(handle, - X_valid_const, - centroids_const, - raft::make_host_scalar_view(&valid_inertia), - valid_weight_view); - - RAFT_LOG_DEBUG("KMeans minibatch: n_init %d/%d validation inertia=%f", - seed_iter + 1, - n_init, - static_cast(valid_inertia)); - - if (valid_inertia < best_inertia) { - best_inertia = valid_inertia; - raft::copy(best_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); - } - } - - raft::copy(centroids.data_handle(), best_centroids.data_handle(), centroids.size(), stream); - - best_inertia = std::numeric_limits::max(); - n_init = 1; - compute_final_inertia = params.batched.final_inertia_check; - - minibatch_init_done = true; - RAFT_LOG_DEBUG("KMeans minibatch: best initialization selected, proceeding with training"); - } - // ---- Main n_init loop ---- - // For full-batch: runs full training n_init times, keeps best result. - // For minibatch: runs once (n_init was set to 1 above after init selection). for (int seed_iter = 0; seed_iter < n_init; ++seed_iter) { cuvs::cluster::kmeans::params iter_params = params; iter_params.rng_state.seed = gen(); @@ -692,78 +348,56 @@ void fit(raft::resources const& handle, n_init, (unsigned long long)iter_params.rng_state.seed); - if (!minibatch_init_done && - iter_params.init != cuvs::cluster::kmeans::params::InitMethod::Array) { + if (iter_params.init != cuvs::cluster::kmeans::params::InitMethod::Array) { init_centroids_from_host_sample(handle, iter_params, batch_size, X, centroids, workspace); } // Reset per-iteration state T priorClusteringCost = 0; - IdxT n_steps = iter_params.max_iter; - - std::mt19937 rng(iter_params.rng_state.seed); - std::uniform_int_distribution uniform_dist(0, n_samples - 1); - T ewa_inertia = T{0}; - T ewa_inertia_min = T{0}; - int no_improvement = 0; - bool ewa_initialized = false; - - if (use_minibatch) { - raft::matrix::fill(handle, total_counts.view(), T{0}); - raft::matrix::fill(handle, batch_weights.view(), T{1}); - n_steps = (iter_params.max_iter * n_samples) / batch_size; - } - for (n_iter[0] = 1; n_iter[0] <= n_steps; ++n_iter[0]) { + for (n_iter[0] = 1; n_iter[0] <= iter_params.max_iter; ++n_iter[0]) { RAFT_LOG_DEBUG("KMeans batched: Iteration %d", n_iter[0]); raft::copy(new_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); T total_cost = 0; - if (use_minibatch) { - IdxT current_batch_size = batch_size; + raft::matrix::fill(handle, centroid_sums.view(), T{0}); + raft::matrix::fill(handle, cluster_counts.view(), T{0}); - for (IdxT i = 0; i < current_batch_size; ++i) { - batch_indices.data_handle()[i] = - use_weighted_sampling ? weighted_dist(rng) : uniform_dist(rng); - } + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features); -#pragma omp parallel for - for (IdxT i = 0; i < current_batch_size; ++i) { - IdxT sample_idx = batch_indices.data_handle()[i]; - std::memcpy(host_batch_buffer.data_handle() + i * n_features, - X.data_handle() + sample_idx * n_features, - n_features * sizeof(T)); - } + using namespace cuvs::spatial::knn::detail::utils; + batch_load_iterator data_batches( + X.data_handle(), n_samples, n_features, batch_size, stream); - raft::copy(batch_data.data_handle(), - host_batch_buffer.data_handle(), - current_batch_size * n_features, - stream); + for (const auto& data_batch : data_batches) { + IdxT current_batch_size = static_cast(data_batch.size()); auto batch_data_view = raft::make_device_matrix_view( - batch_data.data_handle(), current_batch_size, n_features); - auto batch_weights_view_const = raft::make_device_vector_view( + data_batch.data(), current_batch_size, n_features); + + auto batch_weights_fill_view = + raft::make_device_vector_view(batch_weights.data_handle(), current_batch_size); + if (sample_weight) { + raft::copy(batch_weights.data_handle(), + sample_weight->data_handle() + data_batch.offset(), + current_batch_size, + stream); + } else { + raft::matrix::fill(handle, batch_weights_fill_view, T{1}); + } + + auto batch_weights_view = raft::make_device_vector_view( batch_weights.data_handle(), current_batch_size); - auto minClusterAndDistance_view = - raft::make_device_vector_view, IdxT>( - minClusterAndDistance.data_handle(), current_batch_size); if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormBatch.data_handle(), - batch_data.data_handle(), - n_features, - current_batch_size, - stream); + raft::linalg::rowNorm( + L2NormBatch.data_handle(), data_batch.data(), n_features, current_batch_size, stream); } - // Save centroids before update for convergence check - raft::copy(prev_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); - - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), n_clusters, n_features); auto L2NormBatch_const = raft::make_device_vector_view( L2NormBatch.data_handle(), current_batch_size); @@ -771,7 +405,7 @@ void fit(raft::resources const& handle, handle, batch_data_view, centroids_const, - minClusterAndDistance_view, + minClusterAndDistance.view(), L2NormBatch_const, L2NormBuf_OR_DistBuf, metric, @@ -779,222 +413,62 @@ void fit(raft::resources const& handle, params.batch_centroids, workspace); - // Compute batch inertia (normalized by batch_size for comparison) - T batch_inertia = 0; - cuvs::cluster::kmeans::detail::computeClusterCost( - handle, - minClusterAndDistance_view, - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - auto clusterCost_host = raft::make_host_scalar(0); - raft::copy(clusterCost_host.data_handle(), clusterCostD.data(), 1, stream); - raft::resource::sync_stream(handle, stream); - batch_inertia = clusterCost_host.data_handle()[0] / static_cast(current_batch_size); - - raft::matrix::fill(handle, centroid_sums.view(), T{0}); - raft::matrix::fill(handle, cluster_counts.view(), T{0}); - - auto minClusterAndDistance_const = raft::make_const_mdspan(minClusterAndDistance_view); + auto minClusterAndDistance_const = raft::make_const_mdspan(minClusterAndDistance.view()); accumulate_batch_centroids(handle, batch_data_view, minClusterAndDistance_const, - batch_weights_view_const, + batch_weights_view, centroid_sums.view(), cluster_counts.view()); - auto centroid_sums_const = raft::make_device_matrix_view( - centroid_sums.data_handle(), n_clusters, n_features); - auto cluster_counts_const = - raft::make_device_vector_view(cluster_counts.data_handle(), n_clusters); - - minibatch_update_centroids(handle, - centroids, - centroid_sums_const, - cluster_counts_const, - total_counts.view(), - batch_data_view, - params.batched.minibatch.reassignment_ratio, - current_batch_size, - rng); - - // Compute squared difference of centers (for convergence check) - T centers_squared_diff = - compute_centroid_shift(handle, - raft::make_device_matrix_view( - prev_centroids.data_handle(), n_clusters, n_features), - raft::make_device_matrix_view( - centroids.data_handle(), n_clusters, n_features)); - - // Skip first step (inertia from initialization) - if (n_iter[0] > 1) { - // Update Exponentially Weighted Average of inertia. - T alpha = static_cast(current_batch_size * 2.0) / static_cast(n_samples + 1); - alpha = std::min(alpha, T{1}); - - if (!ewa_initialized) { - ewa_inertia = batch_inertia; - ewa_inertia_min = batch_inertia; - ewa_initialized = true; - } else { - ewa_inertia = ewa_inertia * (T{1} - alpha) + batch_inertia * alpha; - } - - RAFT_LOG_DEBUG( - "KMeans minibatch step %d/%d: batch_inertia=%f, ewa_inertia=%f, " - "centers_squared_diff=%f, alpha=%f", - n_iter[0], - n_steps, - static_cast(batch_inertia), - static_cast(ewa_inertia), - static_cast(centers_squared_diff), - static_cast(alpha)); - - // Early stopping: absolute tolerance on squared change of centers - // Disabled if tol == 0.0 - if (params.tol > 0.0 && centers_squared_diff <= params.tol) { - RAFT_LOG_DEBUG("KMeans minibatch: Converged (small centers change) at step %d/%d", - n_iter[0], - n_steps); - break; - } - - // Early stopping: lack of improvement in smoothed inertia - // Disabled if max_no_improvement == 0 - if (params.batched.minibatch.max_no_improvement > 0) { - if (ewa_inertia < ewa_inertia_min) { - no_improvement = 0; - ewa_inertia_min = ewa_inertia; - } else { - no_improvement++; - } - - if (no_improvement >= params.batched.minibatch.max_no_improvement) { - RAFT_LOG_DEBUG("KMeans minibatch: Converged (lack of improvement) at step %d/%d", - n_iter[0], - n_steps); - break; - } - } - } else { - RAFT_LOG_DEBUG("KMeans minibatch step %d/%d: mean batch inertia: %f", - n_iter[0], - n_steps, - static_cast(batch_inertia)); - } - } else { - raft::matrix::fill(handle, centroid_sums.view(), T{0}); - raft::matrix::fill(handle, cluster_counts.view(), T{0}); - - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), n_clusters, n_features); - - using namespace cuvs::spatial::knn::detail::utils; - batch_load_iterator data_batches( - X.data_handle(), n_samples, n_features, batch_size, stream); - - for (const auto& data_batch : data_batches) { - IdxT current_batch_size = static_cast(data_batch.size()); - - auto batch_data_view = raft::make_device_matrix_view( - data_batch.data(), current_batch_size, n_features); - - auto batch_weights_fill_view = - raft::make_device_vector_view(batch_weights.data_handle(), current_batch_size); - if (sample_weight) { - raft::copy(batch_weights.data_handle(), - sample_weight->data_handle() + data_batch.offset(), - current_batch_size, - stream); - } else { - raft::matrix::fill(handle, batch_weights_fill_view, T{1}); - } - - auto batch_weights_view = raft::make_device_vector_view( - batch_weights.data_handle(), current_batch_size); - - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormBatch.data_handle(), data_batch.data(), n_features, current_batch_size, stream); - } - - auto L2NormBatch_const = raft::make_device_vector_view( - L2NormBatch.data_handle(), current_batch_size); - - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + if (params.inertia_check) { + cuvs::cluster::kmeans::detail::computeClusterCost( handle, - batch_data_view, - centroids_const, minClusterAndDistance.view(), - L2NormBatch_const, - L2NormBuf_OR_DistBuf, - metric, - params.batch_samples, - params.batch_centroids, - workspace); - - auto minClusterAndDistance_const = raft::make_const_mdspan(minClusterAndDistance.view()); - - accumulate_batch_centroids(handle, - batch_data_view, - minClusterAndDistance_const, - batch_weights_view, - centroid_sums.view(), - cluster_counts.view()); - - if (params.inertia_check) { - cuvs::cluster::kmeans::detail::computeClusterCost( - handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - auto clusterCost_host = raft::make_host_scalar(0); - raft::copy(clusterCost_host.data_handle(), clusterCostD.data(), 1, stream); - raft::resource::sync_stream(handle, stream); - total_cost += clusterCost_host.data_handle()[0]; - } + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + raft::value_op{}, + raft::add_op{}); + auto clusterCost_host = raft::make_host_scalar(0); + raft::copy(clusterCost_host.data_handle(), clusterCostD.data(), 1, stream); + raft::resource::sync_stream(handle, stream); + total_cost += clusterCost_host.data_handle()[0]; } + } - auto centroid_sums_const = raft::make_device_matrix_view( - centroid_sums.data_handle(), n_clusters, n_features); - auto cluster_counts_const = - raft::make_device_vector_view(cluster_counts.data_handle(), n_clusters); + auto centroid_sums_const = raft::make_device_matrix_view( + centroid_sums.data_handle(), n_clusters, n_features); + auto cluster_counts_const = + raft::make_device_vector_view(cluster_counts.data_handle(), n_clusters); - finalize_centroids( - handle, centroid_sums_const, cluster_counts_const, centroids_const, new_centroids.view()); - } + finalize_centroids( + handle, centroid_sums_const, cluster_counts_const, centroids_const, new_centroids.view()); - // Convergence check for full-batch mode only - if (!use_minibatch) { - T sqrdNormError = - compute_centroid_shift(handle, - raft::make_device_matrix_view( - centroids.data_handle(), n_clusters, n_features), - raft::make_device_matrix_view( - new_centroids.data_handle(), n_clusters, n_features)); + // Convergence check + T sqrdNormError = + compute_centroid_shift(handle, + raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features), + raft::make_device_matrix_view( + new_centroids.data_handle(), n_clusters, n_features)); - raft::copy(centroids.data_handle(), new_centroids.data_handle(), centroids.size(), stream); + raft::copy(centroids.data_handle(), new_centroids.data_handle(), centroids.size(), stream); - bool done = false; - if (params.inertia_check) { - if (n_iter[0] > 1) { - T delta = total_cost / priorClusteringCost; - if (delta > 1 - params.tol) done = true; - } - priorClusteringCost = total_cost; + bool done = false; + if (params.inertia_check) { + if (n_iter[0] > 1) { + T delta = total_cost / priorClusteringCost; + if (delta > 1 - params.tol) done = true; } + priorClusteringCost = total_cost; + } - if (sqrdNormError < params.tol) done = true; + if (sqrdNormError < params.tol) done = true; - if (done) { - RAFT_LOG_DEBUG("KMeans batched: Converged after %d iterations", n_iter[0]); - break; - } + if (done) { + RAFT_LOG_DEBUG("KMeans batched: Converged after %d iterations", n_iter[0]); + break; } } diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 433fa2d496..24524b7207 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -581,7 +581,7 @@ void finalize_centroids(raft::resources const& handle, * @brief Compute the squared norm difference between two centroid sets. * * Returns sum((old_centroids - new_centroids)^2). - * Used for convergence checking in both full-batch and mini-batch modes. + * Used for convergence checking. */ template DataT compute_centroid_shift(raft::resources const& handle, diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index b95fceaedc..15f929e090 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -357,7 +357,6 @@ struct KmeansBatchedInputs { int n_clusters; T tol; bool weighted; - bool minibatch; }; template @@ -383,12 +382,7 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(handle, n_samples, n_features); @@ -478,15 +472,13 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(T(1e-3)), - stream); - } + // FullBatch: centroids should match the device fit reference + centroids_match = devArrMatch(d_centroids_ref.data(), + d_centroids.data(), + params.n_clusters, + n_features, + CompareApprox(T(1e-3)), + stream); // Also check label quality via ARI T pred_inertia = 0; @@ -673,37 +665,21 @@ class KmeansPredictBatchedTest : public ::testing::TestWithParam // ============================================================================ const std::vector> batched_inputsf2 = { - // FullBatch mode - {1000, 32, 5, 0.0001f, true, false}, - {1000, 32, 5, 0.0001f, false, false}, - {1000, 100, 20, 0.0001f, true, false}, - {1000, 100, 20, 0.0001f, false, false}, - {10000, 32, 10, 0.0001f, true, false}, - {10000, 32, 10, 0.0001f, false, false}, - // MiniBatch mode - {1000, 32, 5, 0.0001f, true, true}, - {1000, 32, 5, 0.0001f, false, true}, - {1000, 100, 20, 0.0001f, true, true}, - {1000, 100, 20, 0.0001f, false, true}, - {10000, 32, 10, 0.0001f, true, true}, - {10000, 32, 10, 0.0001f, false, true}, + {1000, 32, 5, 0.0001f, true}, + {1000, 32, 5, 0.0001f, false}, + {1000, 100, 20, 0.0001f, true}, + {1000, 100, 20, 0.0001f, false}, + {10000, 32, 10, 0.0001f, true}, + {10000, 32, 10, 0.0001f, false}, }; const std::vector> batched_inputsd2 = { - // FullBatch mode - {1000, 32, 5, 0.0001, true, false}, - {1000, 32, 5, 0.0001, false, false}, - {1000, 100, 20, 0.0001, true, false}, - {1000, 100, 20, 0.0001, false, false}, - {10000, 32, 10, 0.0001, true, false}, - {10000, 32, 10, 0.0001, false, false}, - // MiniBatch mode - {1000, 32, 5, 0.0001, true, true}, - {1000, 32, 5, 0.0001, false, true}, - {1000, 100, 20, 0.0001, true, true}, - {1000, 100, 20, 0.0001, false, true}, - {10000, 32, 10, 0.0001, true, true}, - {10000, 32, 10, 0.0001, false, true}, + {1000, 32, 5, 0.0001, true}, + {1000, 32, 5, 0.0001, false}, + {1000, 100, 20, 0.0001, true}, + {1000, 100, 20, 0.0001, false}, + {10000, 32, 10, 0.0001, true}, + {10000, 32, 10, 0.0001, false}, }; // ============================================================================ @@ -714,22 +690,14 @@ typedef KmeansFitBatchedTest KmeansFitBatchedTestD; TEST_P(KmeansFitBatchedTestF, Result) { - if (testparams.minibatch) { - ASSERT_TRUE(score >= 0.9); - } else { - ASSERT_TRUE(centroids_match); - ASSERT_TRUE(score == 1.0); - } + ASSERT_TRUE(centroids_match); + ASSERT_TRUE(score == 1.0); } TEST_P(KmeansFitBatchedTestD, Result) { - if (testparams.minibatch) { - ASSERT_TRUE(score >= 0.9); - } else { - ASSERT_TRUE(centroids_match); - ASSERT_TRUE(score == 1.0); - } + ASSERT_TRUE(centroids_match); + ASSERT_TRUE(score == 1.0); } INSTANTIATE_TEST_CASE_P(KmeansFitBatchedTests, diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd index f8ffc02acf..40ac632b23 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -18,10 +18,6 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: Random Array - ctypedef enum cuvsKMeansCentroidUpdateMode: - CUVS_KMEANS_UPDATE_FULL_BATCH - CUVS_KMEANS_UPDATE_MINI_BATCH - ctypedef enum cuvsKMeansType: CUVS_KMEANS_TYPE_KMEANS CUVS_KMEANS_TYPE_KMEANS_BALANCED @@ -36,11 +32,8 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: double oversampling_factor, int batch_samples, int batch_centroids, - cuvsKMeansCentroidUpdateMode update_mode, bool inertia_check, bool final_inertia_check, - int max_no_improvement, - double reassignment_ratio, bool hierarchical, int hierarchical_n_iters diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index c48d36a18a..b8a28957fa 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -45,12 +45,6 @@ INIT_METHOD_TYPES = { INIT_METHOD_NAMES = {v: k for k, v in INIT_METHOD_TYPES.items()} -UPDATE_MODE_TYPES = { - "full_batch": cuvsKMeansCentroidUpdateMode.CUVS_KMEANS_UPDATE_FULL_BATCH, - "mini_batch": cuvsKMeansCentroidUpdateMode.CUVS_KMEANS_UPDATE_MINI_BATCH} - -UPDATE_MODE_NAMES = {v: k for k, v in UPDATE_MODE_TYPES.items()} - cdef class KMeansParams: """ Hyper-parameters for the kmeans algorithm @@ -83,11 +77,6 @@ cdef class KMeansParams: [batch_samples x n_clusters]. batch_centroids : int Number of centroids to process in each batch. If 0, uses n_clusters. - update_mode : str - Centroid update strategy. One of: - "full_batch" : Standard Lloyd's algorithm - accumulate assignments over - the entire dataset, then update centroids once per iteration. - "mini_batch" : Mini-batch k-means - update centroids after each batch. inertia_check : bool If True, check inertia during iterations for early convergence. final_inertia_check : bool @@ -95,16 +84,6 @@ cdef class KMeansParams: This requires an additional full pass over all the host data. Only used by fit_batched(); regular fit() always computes final inertia. Default: False (skip final inertia computation for performance). - max_no_improvement : int - Maximum number of consecutive mini-batch steps without improvement - in smoothed inertia before early stopping. Only used when update_mode - is "mini_batch". If 0, this convergence criterion is disabled. - Default: 10 (matches sklearn's default). - reassignment_ratio : float - Control the fraction of the maximum number of counts for a center to be reassigned. - Centers with count < reassignment_ratio * max(counts) are randomly reassigned to - observations from the current batch. Only used when update_mode is "mini_batch". - If 0.0, reassignment is disabled. Default: 0.01 (matches sklearn's default). hierarchical : bool Whether to use hierarchical (balanced) kmeans or not hierarchical_n_iters : int @@ -129,11 +108,8 @@ cdef class KMeansParams: oversampling_factor=None, batch_samples=None, batch_centroids=None, - update_mode=None, inertia_check=None, final_inertia_check=None, - max_no_improvement=None, - reassignment_ratio=None, hierarchical=None, hierarchical_n_iters=None): if metric is not None: @@ -155,17 +131,10 @@ cdef class KMeansParams: self.params.batch_samples = batch_samples if batch_centroids is not None: self.params.batch_centroids = batch_centroids - if update_mode is not None: - c_mode = UPDATE_MODE_TYPES[update_mode] - self.params.update_mode = c_mode if inertia_check is not None: self.params.inertia_check = inertia_check if final_inertia_check is not None: self.params.final_inertia_check = final_inertia_check - if max_no_improvement is not None: - self.params.max_no_improvement = max_no_improvement - if reassignment_ratio is not None: - self.params.reassignment_ratio = reassignment_ratio if hierarchical is not None: self.params.hierarchical = hierarchical if hierarchical_n_iters is not None: @@ -210,10 +179,6 @@ cdef class KMeansParams: def batch_centroids(self): return self.params.batch_centroids - @property - def update_mode(self): - return UPDATE_MODE_NAMES[self.params.update_mode] - @property def inertia_check(self): return self.params.inertia_check @@ -222,14 +187,6 @@ cdef class KMeansParams: def final_inertia_check(self): return self.params.final_inertia_check - @property - def max_no_improvement(self): - return self.params.max_no_improvement - - @property - def reassignment_ratio(self): - return self.params.reassignment_ratio - @property def hierarchical(self): return self.params.hierarchical diff --git a/python/cuvs/cuvs/tests/test_kmeans.py b/python/cuvs/cuvs/tests/test_kmeans.py index 0c577adc1f..c464666901 100644 --- a/python/cuvs/cuvs/tests/test_kmeans.py +++ b/python/cuvs/cuvs/tests/test_kmeans.py @@ -15,8 +15,6 @@ ) from cuvs.distance import pairwise_distance -from sklearn.cluster import MiniBatchKMeans - @pytest.mark.parametrize("n_rows", [100]) @pytest.mark.parametrize("n_cols", [5, 25]) @@ -126,63 +124,3 @@ def test_fit_batched_matches_fit( ), f"max diff: {np.max(np.abs(centroids_regular - centroids_batched))}" -@pytest.mark.parametrize("n_rows", [1000]) -@pytest.mark.parametrize("n_cols", [10]) -@pytest.mark.parametrize("n_clusters", [8, 16, 32]) -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_minibatch_sklearn(n_rows, n_cols, n_clusters, dtype): - """ - Test that fit_batched matches sklearn's KMeans implementation. - """ - rng = np.random.default_rng(99) - X_host = rng.random((n_rows, n_cols)).astype(dtype) - norms = np.linalg.norm(X_host, ord=1, axis=1, keepdims=True) - norms = np.where(norms == 0, 1.0, norms) - X_host = X_host / norms - initial_centroids_host = X_host[:n_clusters].copy() - - # Sklearn fit - kmeans = MiniBatchKMeans( - n_clusters=n_clusters, - init=initial_centroids_host, - max_iter=100, - verbose=0, - random_state=None, - tol=1e-4, - max_no_improvement=10, - init_size=None, - n_init="auto", - reassignment_ratio=0.01, - batch_size=256, - ) - kmeans.fit(X_host) - - centroids_sklearn = kmeans.cluster_centers_ - inertia_sklearn = kmeans.inertia_ - - # cuvs fit - params = KMeansParams( - n_clusters=n_clusters, - init_method="Array", - max_iter=100, - tol=1e-4, - update_mode="mini_batch", - final_inertia_check=True, - max_no_improvement=10, - ) - centroids_cuvs, inertia_cuvs, _ = fit_batched( - params, - X_host, - batch_size=256, - centroids=device_ndarray(initial_centroids_host.copy()), - ) - centroids_cuvs = centroids_cuvs.copy_to_host() - - assert np.allclose( - centroids_sklearn, centroids_cuvs, rtol=0.1, atol=0.1 - ), f"max diff: {np.max(np.abs(centroids_sklearn - centroids_cuvs))}" - - inertia_diff = abs(inertia_sklearn - inertia_cuvs) - assert np.allclose(inertia_sklearn, inertia_cuvs, rtol=0.1, atol=0.1), ( - f"inertia diff: sklearn={inertia_sklearn}, cuvs={inertia_cuvs}, diff={inertia_diff}" - ) From 64b05848abcfbab059df26e4e9e16f760ac1ba58 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 5 Mar 2026 16:21:10 -0800 Subject: [PATCH 54/81] rm extra file --- .../cuvs/neighbors/graph_build_types.hpp | 163 ------------------ 1 file changed, 163 deletions(-) delete mode 100644 cpp/include/cuvs/neighbors/graph_build_types.hpp diff --git a/cpp/include/cuvs/neighbors/graph_build_types.hpp b/cpp/include/cuvs/neighbors/graph_build_types.hpp deleted file mode 100644 index 5faa08c1d1..0000000000 --- a/cpp/include/cuvs/neighbors/graph_build_types.hpp +++ /dev/null @@ -1,163 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -#include -#include -#include - -namespace cuvs::neighbors { - -/** - * @defgroup neighbors_build_algo Graph build algorithm types - * @{ - */ - -enum GRAPH_BUILD_ALGO { BRUTE_FORCE = 0, IVF_PQ = 1, NN_DESCENT = 2, ACE = 3 }; - -namespace graph_build_params { - -/** Specialized parameters utilizing IVF-PQ to build knn graph */ -struct ivf_pq_params { - cuvs::neighbors::ivf_pq::index_params build_params; - cuvs::neighbors::ivf_pq::search_params search_params; - float refinement_rate = 1.0; - - ivf_pq_params() = default; - - /** - * Set default parameters based on shape of the input dataset. - * Usage example: - * @code{.cpp} - * using namespace cuvs::neighbors; - * raft::resources res; - * // create index_params for a [N. D] dataset - * auto dataset = raft::make_device_matrix(res, N, D); - * auto pq_params = - * graph_build_params::ivf_pq_params(dataset.extents()); - * // modify/update index_params as needed - * pq_params.kmeans_trainset_fraction = 0.1; - * @endcode - */ - ivf_pq_params(raft::matrix_extent dataset_extents, - cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded) - { - build_params = ivf_pq::index_params::from_dataset(dataset_extents, metric); - auto n_rows = dataset_extents.extent(0); - auto n_features = dataset_extents.extent(1); - if (n_features <= 32) { - build_params.pq_dim = 16; - build_params.pq_bits = 8; - } else { - build_params.pq_bits = 4; - if (n_features <= 64) { - build_params.pq_dim = 32; - } else if (n_features <= 128) { - build_params.pq_dim = 64; - } else if (n_features <= 192) { - build_params.pq_dim = 96; - } else { - build_params.pq_dim = raft::round_up_safe(n_features / 2, 128); - } - } - - build_params.n_lists = std::max(1, n_rows / 2000); - build_params.kmeans_n_iters = 10; - - const double kMinPointsPerCluster = 32; - const double min_kmeans_trainset_points = kMinPointsPerCluster * build_params.n_lists; - const double max_kmeans_trainset_fraction = 1.0; - const double min_kmeans_trainset_fraction = - std::min(max_kmeans_trainset_fraction, min_kmeans_trainset_points / n_rows); - build_params.kmeans_trainset_fraction = std::clamp( - 1.0 / std::sqrt(n_rows * 1e-5), min_kmeans_trainset_fraction, max_kmeans_trainset_fraction); - build_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; - - search_params = cuvs::neighbors::ivf_pq::search_params{}; - search_params.n_probes = std::round(std::sqrt(build_params.n_lists) / 20 + 4); - search_params.lut_dtype = CUDA_R_16F; - search_params.internal_distance_dtype = CUDA_R_16F; - search_params.coarse_search_dtype = CUDA_R_16F; - search_params.max_internal_batch_size = 128 * 1024; - - refinement_rate = 1; - } -}; - -using nn_descent_params = cuvs::neighbors::nn_descent::index_params; - -struct brute_force_params { - cuvs::neighbors::brute_force::index_params build_params; - cuvs::neighbors::brute_force::search_params search_params; -}; - -/** Specialized parameters for ACE (Augmented Core Extraction) graph build */ -struct ace_params { - /** - * Number of partitions for ACE (Augmented Core Extraction) partitioned build. - * - * When set to 0 (default), the number of partitions is automatically derived - * based on available host and GPU memory to maximize partition size while - * ensuring the build fits in memory. - * - * Small values might improve recall but potentially degrade performance and - * increase memory usage. Partitions should not be too small to prevent issues - * in KNN graph construction. The partition size is on average 2 * (n_rows / npartitions) * dim * - * sizeof(T). 2 is because of the core and augmented vectors. Please account for imbalance in the - * partition sizes (up to 3x in our tests). - * - * If the specified number of partitions results in partitions that exceed - * available memory, the value will be automatically increased to fit memory - * constraints and a warning will be issued. - */ - size_t npartitions = 0; - /** - * The index quality for the ACE build. - * - * Bigger values increase the index quality. At some point, increasing this will no longer improve - * the quality. - */ - size_t ef_construction = 120; - /** - * Directory to store ACE build artifacts (e.g., KNN graph, optimized graph). - * - * Used when `use_disk` is true or when the graph does not fit in host and GPU - * memory. This should be the fastest disk in the system and hold enough space - * for twice the dataset, final graph, and label mapping. - */ - std::string build_dir = "/tmp/ace_build"; - /** - * Whether to use disk-based storage for ACE build. - * - * When true, enables disk-based operations for memory-efficient graph construction. - */ - bool use_disk = false; - /** - * Maximum host memory to use for ACE build in GiB. - * - * When set to 0 (default), uses available host memory. - * When set to a positive value, limits host memory usage to the specified amount. - * Useful for testing or when running alongside other memory-intensive processes. - */ - double max_host_memory_gb = 0; - /** - * Maximum GPU memory to use for ACE build in GiB. - * - * When set to 0 (default), uses available GPU memory. - * When set to a positive value, limits GPU memory usage to the specified amount. - * Useful for testing or when running alongside other memory-intensive processes. - */ - double max_gpu_memory_gb = 0; - - ace_params() = default; -}; - -// **** Experimental **** -using iterative_search_params = cuvs::neighbors::search_params; -} // namespace graph_build_params - -/** @} */ // end group neighbors_build_algo -} // namespace cuvs::neighbors From ec48753d06f79640e2e25ed0ea669c36d870ce7e Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 5 Mar 2026 16:37:17 -0800 Subject: [PATCH 55/81] fix header includes --- cpp/src/cluster/detail/kmeans_batched.cuh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 3b9cd7fd4b..921a70851d 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -17,11 +17,8 @@ #include #include #include -#include #include #include -#include -#include #include #include #include @@ -29,6 +26,8 @@ #include #include +#include + #include #include #include From 738eea7d71d119e88c5b68b762fb6aeb0938d974 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 5 Mar 2026 17:09:36 -0800 Subject: [PATCH 56/81] address pr reviews --- c/include/cuvs/cluster/kmeans.h | 50 ++--- c/src/cluster/kmeans.cpp | 145 ++++++------- cpp/include/cuvs/cluster/kmeans.hpp | 222 ++++++++------------ cpp/src/cluster/detail/kmeans_batched.cuh | 26 ++- cpp/src/cluster/kmeans_fit_double.cu | 72 +++---- cpp/src/cluster/kmeans_fit_float.cu | 72 +++---- cpp/tests/cluster/kmeans.cu | 38 ++-- python/cuvs/cuvs/cluster/kmeans/__init__.py | 4 +- python/cuvs/cuvs/cluster/kmeans/kmeans.pxd | 13 +- python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 183 +++++----------- 10 files changed, 323 insertions(+), 502 deletions(-) diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index d7b1bb9a4c..62ab09ec59 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -95,10 +95,10 @@ struct cuvsKMeansParams { bool inertia_check; /** - * Compute final inertia after fit_batched completes (requires extra data pass). - * Only used by fit_batched; regular fit always computes final inertia. + * Number of samples to process per GPU batch for the batched (host-data) API. + * When set to 0, defaults to n_samples (process all at once). */ - bool final_inertia_check; + int64_t batch_size; /** * Whether to use hierarchical (balanced) kmeans or not @@ -150,18 +150,24 @@ typedef enum { CUVS_KMEANS_TYPE_KMEANS = 0, CUVS_KMEANS_TYPE_KMEANS_BALANCED = 1 * clusters are reinitialized by choosing new centroids with * k-means++ algorithm. * + * X may reside on either host (CPU) or device (GPU) memory. + * When X is on the host the data is streamed to the GPU in + * batches controlled by params->batch_size. + * * @param[in] res opaque C handle * @param[in] params Parameters for KMeans model. * @param[in] X Training instances to cluster. The data must - * be in row-major format. + * be in row-major format. May be on host or + * device memory. * [dim = n_samples x n_features] * @param[in] sample_weight Optional weights for each observation in X. + * Must be on the same memory space as X. * [len = n_samples] * @param[inout] centroids [in] When init is InitMethod::Array, use * centroids as the initial cluster centers. * [out] The generated centroids from the * kmeans algorithm are stored at the address - * pointed by 'centroids'. + * pointed by 'centroids'. Must be on device. * [dim = n_clusters x n_features] * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. @@ -173,7 +179,7 @@ cuvsError_t cuvsKMeansFit(cuvsResources_t res, DLManagedTensor* sample_weight, DLManagedTensor* centroids, double* inertia, - int* n_iter); + int64_t* n_iter); /** * @brief Predict the closest cluster each sample in X belongs to. @@ -221,38 +227,6 @@ cuvsError_t cuvsKMeansClusterCost(cuvsResources_t res, DLManagedTensor* centroids, double* cost); -/** - * @brief Find clusters with k-means algorithm using batched processing. - * - * This function processes data from HOST memory in batches, streaming - * to the GPU. Useful when the dataset is too large to fit in GPU memory. - * - * @param[in] res opaque C handle - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances on HOST memory. The data must - * be in row-major format. - * [dim = n_samples x n_features] - * @param[in] batch_size Number of samples to process per batch. - * @param[in] sample_weight Optional weights for each observation in X (on host). - * [len = n_samples] - * @param[inout] centroids [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers. - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. Must be on DEVICE memory. - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -cuvsError_t cuvsKMeansFitBatched(cuvsResources_t res, - cuvsKMeansParams_t params, - DLManagedTensor* X, - int64_t batch_size, - DLManagedTensor* sample_weight, - DLManagedTensor* centroids, - double* inertia, - int64_t* n_iter); /** * @} */ diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index 9c976f1c28..a6c912af61 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -28,7 +28,7 @@ cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params) kmeans_params.batch_samples = params.batch_samples; kmeans_params.batch_centroids = params.batch_centroids; kmeans_params.inertia_check = params.inertia_check; - kmeans_params.batched.final_inertia_check = params.final_inertia_check; + kmeans_params.batched.batch_size = params.batch_size; return kmeans_params; } @@ -41,59 +41,54 @@ cuvs::cluster::kmeans::balanced_params convert_balanced_params(const cuvsKMeansP } template -void _fit(cuvsResources_t res, - const cuvsKMeansParams& params, - DLManagedTensor* X_tensor, - DLManagedTensor* sample_weight_tensor, - DLManagedTensor* centroids_tensor, - double* inertia, - int* n_iter) +void _fit_device(cuvsResources_t res, + const cuvsKMeansParams& params, + DLManagedTensor* X_tensor, + DLManagedTensor* sample_weight_tensor, + DLManagedTensor* centroids_tensor, + double* inertia, + int64_t* n_iter) { - auto X = X_tensor->dl_tensor; auto res_ptr = reinterpret_cast(res); - if (cuvs::core::is_dlpack_device_compatible(X)) { - using const_mdspan_type = raft::device_matrix_view; - using mdspan_type = raft::device_matrix_view; + using const_mdspan_type = raft::device_matrix_view; + using mdspan_type = raft::device_matrix_view; - if (params.hierarchical) { - if (sample_weight_tensor != NULL) { - RAFT_FAIL("sample_weight cannot be used with hierarchical kmeans"); - } + if (params.hierarchical) { + if (sample_weight_tensor != NULL) { + RAFT_FAIL("sample_weight cannot be used with hierarchical kmeans"); + } - if constexpr (std::is_same_v) { - RAFT_FAIL("float64 is an unsupported dtype for hierarchical kmeans"); - } else { - auto kmeans_params = convert_balanced_params(params); - T inertia_temp; - auto inertia_view = raft::make_host_scalar_view(&inertia_temp); - cuvs::cluster::kmeans::fit(*res_ptr, kmeans_params, cuvs::core::from_dlpack(X_tensor), cuvs::core::from_dlpack(centroids_tensor), std::make_optional(inertia_view)); - *inertia = inertia_temp; - *n_iter = params.hierarchical_n_iters; - } + if constexpr (std::is_same_v) { + RAFT_FAIL("float64 is an unsupported dtype for hierarchical kmeans"); } else { + auto kmeans_params = convert_balanced_params(params); T inertia_temp; - IdxT n_iter_temp; - - std::optional> sample_weight; - if (sample_weight_tensor != NULL) { - sample_weight = - cuvs::core::from_dlpack>(sample_weight_tensor); - } - - auto kmeans_params = convert_params(params); - cuvs::cluster::kmeans::fit(*res_ptr, - kmeans_params, - cuvs::core::from_dlpack(X_tensor), - sample_weight, - cuvs::core::from_dlpack(centroids_tensor), - raft::make_host_scalar_view(&inertia_temp), - raft::make_host_scalar_view(&n_iter_temp)); + auto inertia_view = raft::make_host_scalar_view(&inertia_temp); + cuvs::cluster::kmeans::fit(*res_ptr, kmeans_params, cuvs::core::from_dlpack(X_tensor), cuvs::core::from_dlpack(centroids_tensor), std::make_optional(inertia_view)); *inertia = inertia_temp; - *n_iter = n_iter_temp; + *n_iter = params.hierarchical_n_iters; } } else { - RAFT_FAIL("X dataset must be accessible on device memory"); + T inertia_temp; + IdxT n_iter_temp; + + std::optional> sample_weight; + if (sample_weight_tensor != NULL) { + sample_weight = + cuvs::core::from_dlpack>(sample_weight_tensor); + } + + auto kmeans_params = convert_params(params); + cuvs::cluster::kmeans::fit(*res_ptr, + kmeans_params, + cuvs::core::from_dlpack(X_tensor), + sample_weight, + cuvs::core::from_dlpack(centroids_tensor), + raft::make_host_scalar_view(&inertia_temp), + raft::make_host_scalar_view(&n_iter_temp)); + *inertia = inertia_temp; + *n_iter = static_cast(n_iter_temp); } } @@ -179,10 +174,9 @@ void _cluster_cost(cuvsResources_t res, } template -void _fit_batched(cuvsResources_t res, +void _fit_host(cuvsResources_t res, const cuvsKMeansParams& params, DLManagedTensor* X_tensor, - IdxT batch_size, DLManagedTensor* sample_weight_tensor, DLManagedTensor* centroids_tensor, double* inertia, @@ -196,7 +190,7 @@ void _fit_batched(cuvsResources_t res, // X must be on host (CPU) memory if (X.device.device_type != kDLCPU) { - RAFT_FAIL("X dataset must be on host (CPU) memory for fit_batched"); + RAFT_FAIL("X dataset must be on host (CPU) memory for batched fit"); } // centroids must be on device memory @@ -214,7 +208,7 @@ void _fit_batched(cuvsResources_t res, if (sample_weight_tensor != NULL) { auto sw = sample_weight_tensor->dl_tensor; if (sw.device.device_type != kDLCPU) { - RAFT_FAIL("sample_weight must be on host (CPU) memory for fit_batched"); + RAFT_FAIL("sample_weight must be on host (CPU) memory for batched fit"); } sample_weight = raft::make_host_vector_view( reinterpret_cast(sw.data), n_samples); @@ -224,14 +218,13 @@ void _fit_batched(cuvsResources_t res, IdxT n_iter_temp; auto kmeans_params = convert_params(params); - cuvs::cluster::kmeans::fit_batched(*res_ptr, - kmeans_params, - X_view, - batch_size, - sample_weight, - centroids_view, - raft::make_host_scalar_view(&inertia_temp), - raft::make_host_scalar_view(&n_iter_temp)); + cuvs::cluster::kmeans::fit(*res_ptr, + kmeans_params, + X_view, + sample_weight, + centroids_view, + raft::make_host_scalar_view(&inertia_temp), + raft::make_host_scalar_view(&n_iter_temp)); *inertia = inertia_temp; *n_iter = n_iter_temp; @@ -254,7 +247,7 @@ extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params) .batch_samples = cpp_params.batch_samples, .batch_centroids = cpp_params.batch_centroids, .inertia_check = cpp_params.inertia_check, - .final_inertia_check = cpp_params.batched.final_inertia_check, + .batch_size = cpp_params.batched.batch_size, .hierarchical = false, .hierarchical_n_iters = static_cast(cpp_balanced_params.n_iters)}; }); @@ -271,14 +264,22 @@ extern "C" cuvsError_t cuvsKMeansFit(cuvsResources_t res, DLManagedTensor* sample_weight, DLManagedTensor* centroids, double* inertia, - int* n_iter) + int64_t* n_iter) { return cuvs::core::translate_exceptions([=] { - auto dataset = X->dl_tensor; + auto dataset = X->dl_tensor; + bool is_host = (dataset.device.device_type == kDLCPU); + if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { - _fit(res, *params, X, sample_weight, centroids, inertia, n_iter); + if (is_host) + _fit_host(res, *params, X, sample_weight, centroids, inertia, n_iter); + else + _fit_device(res, *params, X, sample_weight, centroids, inertia, n_iter); } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { - _fit(res, *params, X, sample_weight, centroids, inertia, n_iter); + if (is_host) + _fit_host(res, *params, X, sample_weight, centroids, inertia, n_iter); + else + _fit_device(res, *params, X, sample_weight, centroids, inertia, n_iter); } else { RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", dataset.dtype.code, @@ -330,25 +331,3 @@ extern "C" cuvsError_t cuvsKMeansClusterCost(cuvsResources_t res, }); } -extern "C" cuvsError_t cuvsKMeansFitBatched(cuvsResources_t res, - cuvsKMeansParams_t params, - DLManagedTensor* X, - int64_t batch_size, - DLManagedTensor* sample_weight, - DLManagedTensor* centroids, - double* inertia, - int64_t* n_iter) -{ - return cuvs::core::translate_exceptions([=] { - auto dataset = X->dl_tensor; - if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { - _fit_batched(res, *params, X, batch_size, sample_weight, centroids, inertia, n_iter); - } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { - _fit_batched(res, *params, X, batch_size, sample_weight, centroids, inertia, n_iter); - } else { - RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", - dataset.dtype.code, - dataset.dtype.bits); - } - }); -} diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 00836d1002..b07d02ace4 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -110,23 +110,22 @@ struct params : base_params { int batch_centroids = 0; /** - * If true, check inertia during iterations for early convergence (used by both fit and - * fit_batched). + * If true, check inertia during iterations for early convergence. */ bool inertia_check = false; /** - * Parameters specific to batched k-means (fit_batched). - * These parameters are only used when calling fit_batched() and are ignored by regular fit(). + * Parameters specific to batched k-means (host-data overloads of fit/predict). + * These parameters are only used when calling fit() or predict() with host_matrix_view data + * and are ignored by the device_matrix_view overloads. */ struct batched_params { /** - * If true, compute the final inertia after fit_batched completes. This requires an additional - * full pass over all the host data, which can be expensive for large datasets. - * Only used by fit_batched(); regular fit() always computes final inertia. - * Default: false (skip final inertia computation for performance). + * Number of samples to process per GPU batch. + * When set to 0, a default batch size will be chosen automatically. + * Default: 0 (auto). */ - bool final_inertia_check = false; + int64_t batch_size = 0; } batched; }; @@ -166,10 +165,11 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; */ /** - * @brief Find clusters with k-means algorithm using batched processing. + * @brief Find clusters with k-means algorithm using batched processing of host data. * - * This version supports out-of-core computation where the dataset resides - * on the host. Data is processed in batches, streaming from host to device. + * This overload supports out-of-core computation where the dataset resides + * on the host. Data is processed in GPU-sized batches, streaming from host to device. + * The batch size is controlled by params.batched.batch_size. * * @code{.cpp} * #include @@ -179,6 +179,7 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * raft::resources handle; * cuvs::cluster::kmeans::params params; * params.n_clusters = 100; + * params.batched.batch_size = 100000; * int n_features = 15; * float inertia; * int n_iter; @@ -190,22 +191,21 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * // Centroids on device * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); * - * kmeans::fit_batched(handle, - * params, - * X, - * 100000, // batch_size - * std::nullopt, - * centroids.view(), - * raft::make_host_scalar_view(&inertia), - * raft::make_host_scalar_view(&n_iter)); + * kmeans::fit(handle, + * params, + * X, + * std::nullopt, + * centroids.view(), + * raft::make_host_scalar_view(&inertia), + * raft::make_host_scalar_view(&n_iter)); * @endcode * * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. + * @param[in] params Parameters for KMeans model. Batch size is read from + * params.batched.batch_size. * @param[in] X Training instances on HOST memory. The data must * be in row-major format. * [dim = n_samples x n_features] - * @param[in] batch_size Number of samples to process per batch. * @param[in] sample_weight Optional weights for each observation in X (on host). * [len = n_samples] * @param[inout] centroids [in] When init is InitMethod::Array, use @@ -218,86 +218,49 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * closest cluster center. * @param[out] n_iter Number of iterations run. */ -void fit_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); /** - * @brief Find clusters with k-means algorithm using batched processing. - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances on HOST memory. - * [dim = n_samples x n_features] - * @param[in] batch_size Number of samples to process per batch. - * @param[in] sample_weight Optional weights for each observation in X (on host). - * @param[inout] centroids Cluster centers on device. - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances to nearest centroid. - * @param[out] n_iter Number of iterations run. + * @brief Find clusters with k-means algorithm using batched processing of host data. */ -void fit_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); /** - * @brief Find clusters with k-means algorithm using batched processing. - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances on HOST memory. - * [dim = n_samples x n_features] - * @param[in] batch_size Number of samples to process per batch. - * @param[in] sample_weight Optional weights for each observation in X (on host). - * @param[inout] centroids Cluster centers on device. - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances to nearest centroid. - * @param[out] n_iter Number of iterations run. + * @brief Find clusters with k-means algorithm using batched processing of host data. */ -void fit_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); /** - * @brief Find clusters with k-means algorithm using batched processing. - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances on HOST memory. - * [dim = n_samples x n_features] - * @param[in] batch_size Number of samples to process per batch. - * @param[in] sample_weight Optional weights for each observation in X (on host). - * @param[inout] centroids Cluster centers on device. - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances to nearest centroid. - * @param[out] n_iter Number of iterations run. + * @brief Find clusters with k-means algorithm using batched processing of host data. */ -void fit_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); /** - * @defgroup predict_batched Batched K-Means Predict + * @defgroup predict_host K-Means Predict (host data) * @{ */ @@ -306,86 +269,81 @@ void fit_batched(raft::resources const& handle, * * Streams data from host to GPU in batches, assigns each sample to its nearest * centroid, and writes labels back to host memory. + * The batch size is controlled by params.batched.batch_size. * * @param[in] handle The raft handle. * @param[in] params Parameters for KMeans model. * @param[in] X Input samples on HOST memory. [dim = n_samples x n_features] - * @param[in] batch_size Number of samples to process per batch. * @param[in] sample_weight Optional weights for each observation (on host). * @param[in] centroids Cluster centers on device. [dim = n_clusters x n_features] * @param[out] labels Predicted cluster labels on HOST memory. [dim = n_samples] * @param[in] normalize_weight Whether to normalize sample weights. * @param[out] inertia Sum of squared distances to nearest centroid. */ -void predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia); +void predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia); /** * @brief Predict cluster labels for host data using batched processing (double). */ -void predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia); +void predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia); /** * @} */ /** - * @defgroup fit_predict_batched Batched K-Means Fit + Predict + * @defgroup fit_predict_host K-Means Fit + Predict (host data) * @{ */ /** - * @brief Fit k-means and predict cluster labels using batched processing. + * @brief Fit k-means and predict cluster labels using batched processing of host data. * - * Combines fit_batched and predict_batched into a single call. + * Combines batched fit and batched predict into a single call. * * @param[in] handle The raft handle. * @param[in] params Parameters for KMeans model. * @param[in] X Training instances on HOST memory. [dim = n_samples x n_features] - * @param[in] batch_size Number of samples to process per batch. * @param[in] sample_weight Optional weights for each observation (on host). * @param[inout] centroids Cluster centers on device. [dim = n_clusters x n_features] * @param[out] labels Predicted cluster labels on HOST memory. [dim = n_samples] * @param[out] inertia Sum of squared distances to nearest centroid. * @param[out] n_iter Number of iterations run. */ -void fit_predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); +void fit_predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); /** * @brief Fit k-means and predict cluster labels using batched processing (double). */ -void fit_predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); +void fit_predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); /** * @} diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 921a70851d..c4e5fb278e 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -267,9 +267,8 @@ T compute_batched_host_inertia( * @tparam IdxT Index type (int, int64_t) * * @param[in] handle RAFT resources handle - * @param[in] params K-means parameters + * @param[in] params K-means parameters (batch size read from params.batched.batch_size) * @param[in] X Input data on HOST [n_samples x n_features] - * @param[in] batch_size Number of samples to process per GPU batch * @param[in] sample_weight Optional weights per sample (on host) * @param[inout] centroids Initial/output cluster centers [n_clusters x n_features] * @param[out] inertia Sum of squared distances to nearest centroid @@ -279,7 +278,6 @@ template void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, raft::host_matrix_view X, - IdxT batch_size, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, @@ -291,6 +289,10 @@ void fit(raft::resources const& handle, auto n_clusters = params.n_clusters; auto metric = params.metric; + // Read batch_size from params; default to n_samples if 0 (auto) + IdxT batch_size = static_cast(params.batched.batch_size); + if (batch_size <= 0) { batch_size = static_cast(n_samples); } + RAFT_EXPECTS(batch_size > 0, "batch_size must be positive"); RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); RAFT_EXPECTS(static_cast(centroids.extent(0)) == n_clusters, @@ -321,7 +323,6 @@ void fit(raft::resources const& handle, IdxT best_n_iter = 0; std::mt19937 gen(params.rng_state.seed); - bool compute_final_inertia = (n_init > 1) || params.batched.final_inertia_check; // ----- Allocate reusable work buffers (shared across n_init iterations) ----- auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); @@ -471,7 +472,8 @@ void fit(raft::resources const& handle, } } - if (compute_final_inertia) { + // Always compute final inertia (like regular kmeans) + { auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); inertia[0] = compute_batched_host_inertia( @@ -488,11 +490,6 @@ void fit(raft::resources const& handle, raft::copy(best_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); raft::resource::sync_stream(handle, stream); } - } else { - inertia[0] = 0; - RAFT_LOG_DEBUG("KMeans batched: n_init %d/%d completed (inertia computation skipped)", - seed_iter + 1, - n_init); } if (n_init > 1) { @@ -518,7 +515,6 @@ template void predict(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, raft::host_matrix_view X, - IdxT batch_size, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_vector_view labels, @@ -530,6 +526,10 @@ void predict(raft::resources const& handle, auto n_features = X.extent(1); auto n_clusters = params.n_clusters; + // Read batch_size from params; default to n_samples if 0 (auto) + IdxT batch_size = static_cast(params.batched.batch_size); + if (batch_size <= 0) { batch_size = static_cast(n_samples); } + RAFT_EXPECTS(batch_size > 0, "batch_size must be positive"); RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); RAFT_EXPECTS(centroids.extent(0) == static_cast(n_clusters), @@ -597,7 +597,6 @@ template void fit_predict(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, raft::host_matrix_view X, - IdxT batch_size, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_vector_view labels, @@ -608,7 +607,6 @@ void fit_predict(raft::resources const& handle, fit(handle, params, X, - batch_size, sample_weight, centroids, raft::make_host_scalar_view(&fit_inertia), @@ -618,7 +616,7 @@ void fit_predict(raft::resources const& handle, centroids.data_handle(), centroids.extent(0), centroids.extent(1)); predict( - handle, params, X, batch_size, sample_weight, centroids_const, labels, false, inertia); + handle, params, X, sample_weight, centroids_const, labels, false, inertia); } } // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/cluster/kmeans_fit_double.cu index 4edd430b3b..91830716d5 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/cluster/kmeans_fit_double.cu @@ -40,30 +40,28 @@ INSTANTIATE_FIT(double, int64_t) #undef INSTANTIATE_FIT_MAIN #undef INSTANTIATE_FIT -void fit_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { cuvs::cluster::kmeans::detail::fit( - handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); + handle, params, X, sample_weight, centroids, inertia, n_iter); } -void fit_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { cuvs::cluster::kmeans::detail::fit( - handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); + handle, params, X, sample_weight, centroids, inertia, n_iter); } void fit(raft::resources const& handle, @@ -90,32 +88,30 @@ void fit(raft::resources const& handle, handle, params, X, sample_weight, centroids, inertia, n_iter); } -void predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia) +void predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia) { cuvs::cluster::kmeans::detail::predict( - handle, params, X, batch_size, sample_weight, centroids, labels, normalize_weight, inertia); + handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); } -void fit_predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) +void fit_predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { cuvs::cluster::kmeans::detail::fit_predict( - handle, params, X, batch_size, sample_weight, centroids, labels, inertia, n_iter); + handle, params, X, sample_weight, centroids, labels, inertia, n_iter); } } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/cluster/kmeans_fit_float.cu index e55b543f77..732a34c214 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/cluster/kmeans_fit_float.cu @@ -40,30 +40,28 @@ INSTANTIATE_FIT(float, int64_t) #undef INSTANTIATE_FIT_MAIN #undef INSTANTIATE_FIT -void fit_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { cuvs::cluster::kmeans::detail::fit( - handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); + handle, params, X, sample_weight, centroids, inertia, n_iter); } -void fit_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { cuvs::cluster::kmeans::detail::fit( - handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); + handle, params, X, sample_weight, centroids, inertia, n_iter); } void fit(raft::resources const& handle, @@ -90,32 +88,30 @@ void fit(raft::resources const& handle, handle, params, X, sample_weight, centroids, inertia, n_iter); } -void predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia) +void predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia) { cuvs::cluster::kmeans::detail::predict( - handle, params, X, batch_size, sample_weight, centroids, labels, normalize_weight, inertia); + handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); } -void fit_predict_batched(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - int64_t batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) +void fit_predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { cuvs::cluster::kmeans::detail::fit_predict( - handle, params, X, batch_size, sample_weight, centroids, labels, inertia, n_iter); + handle, params, X, sample_weight, centroids, labels, inertia, n_iter); } } // namespace cuvs::cluster::kmeans diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index 15f929e090..6c0b6e2c00 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -347,7 +347,7 @@ TEST_P(KmeansTestF, Result) { ASSERT_TRUE(score == 1.0); } INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestF, ::testing::ValuesIn(inputsf2)); // ============================================================================ -// Batched KMeans Tests (fit_batched + predict_batched) +// Batched KMeans Tests (fit + predict with host data) // ============================================================================ template @@ -382,8 +382,6 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(handle, n_samples, n_features); auto labels = raft::make_device_vector(handle, n_samples); @@ -448,6 +446,7 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam> h_sw = std::nullopt; std::vector h_sample_weight; @@ -457,18 +456,16 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(h_sample_weight.data(), n_samples)); } - T inertia = 0; - int n_iter = 0; - int batch_size = std::min(n_samples, 256); + T inertia = 0; + int n_iter = 0; - cuvs::cluster::kmeans::fit_batched(handle, - batched_params, - h_X_view, - batch_size, - h_sw, - d_centroids_view, - raft::make_host_scalar_view(&inertia), - raft::make_host_scalar_view(&n_iter)); + cuvs::cluster::kmeans::fit(handle, + batched_params, + h_X_view, + h_sw, + d_centroids_view, + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter)); raft::resource::sync_stream(handle, stream); @@ -619,14 +616,13 @@ class KmeansPredictBatchedTest : public ::testing::TestWithParam auto h_labels_view = raft::make_host_vector_view(h_labels.data(), (int64_t)n_samples); - T pred_inertia = 0; - int64_t batch_size = std::min((int64_t)n_samples, (int64_t)256); + T pred_inertia = 0; + params.batched.batch_size = std::min((int64_t)n_samples, (int64_t)256); - cuvs::cluster::kmeans::predict_batched( + cuvs::cluster::kmeans::predict( handle, params, h_X_view, - batch_size, std::optional>(std::nullopt), centroids_const_view, h_labels_view, @@ -641,7 +637,7 @@ class KmeansPredictBatchedTest : public ::testing::TestWithParam } raft::update_device(d_labels.data(), h_labels_int.data(), n_samples, stream); - // Compare labels directly: predict_batched should produce exact same labels + // Compare labels directly: batched predict should produce exact same labels // as device predict given the same centroids labels_match = devArrMatch(d_labels_ref.data(), d_labels.data(), n_samples, Compare(), stream); @@ -683,7 +679,7 @@ const std::vector> batched_inputsd2 = { }; // ============================================================================ -// fit_batched tests +// fit (host/batched) tests // ============================================================================ typedef KmeansFitBatchedTest KmeansFitBatchedTestF; typedef KmeansFitBatchedTest KmeansFitBatchedTestD; @@ -708,7 +704,7 @@ INSTANTIATE_TEST_CASE_P(KmeansFitBatchedTests, ::testing::ValuesIn(batched_inputsd2)); // ============================================================================ -// predict_batched tests +// predict (host/batched) tests // ============================================================================ typedef KmeansPredictBatchedTest KmeansPredictBatchedTestF; typedef KmeansPredictBatchedTest KmeansPredictBatchedTestD; diff --git a/python/cuvs/cuvs/cluster/kmeans/__init__.py b/python/cuvs/cuvs/cluster/kmeans/__init__.py index 56c7645ebb..b547ec72f4 100644 --- a/python/cuvs/cuvs/cluster/kmeans/__init__.py +++ b/python/cuvs/cuvs/cluster/kmeans/__init__.py @@ -2,6 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 -from .kmeans import KMeansParams, cluster_cost, fit, fit_batched, predict +from .kmeans import KMeansParams, cluster_cost, fit, predict -__all__ = ["KMeansParams", "cluster_cost", "fit", "fit_batched", "predict"] +__all__ = ["KMeansParams", "cluster_cost", "fit", "predict"] diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd index 40ac632b23..2d7294a4ef 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -33,7 +33,7 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: int batch_samples, int batch_centroids, bool inertia_check, - bool final_inertia_check, + int64_t batch_size, bool hierarchical, int hierarchical_n_iters @@ -49,7 +49,7 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: DLManagedTensor* sample_weight, DLManagedTensor * centroids, double * inertia, - int * n_iter) except + + int64_t * n_iter) except + cuvsError_t cuvsKMeansPredict(cuvsResources_t res, cuvsKMeansParams_t params, @@ -64,12 +64,3 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: DLManagedTensor* X, DLManagedTensor* centroids, double* cost) - - cuvsError_t cuvsKMeansFitBatched(cuvsResources_t res, - cuvsKMeansParams_t params, - DLManagedTensor* X, - int64_t batch_size, - DLManagedTensor* sample_weight, - DLManagedTensor* centroids, - double* inertia, - int64_t* n_iter) except + diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index b8a28957fa..7b8d755dd1 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -79,11 +79,11 @@ cdef class KMeansParams: Number of centroids to process in each batch. If 0, uses n_clusters. inertia_check : bool If True, check inertia during iterations for early convergence. - final_inertia_check : bool - If True, compute the final inertia after fit_batched completes. - This requires an additional full pass over all the host data. - Only used by fit_batched(); regular fit() always computes final inertia. - Default: False (skip final inertia computation for performance). + batch_size : int + Number of samples to process per GPU batch when fitting with host + (numpy) data. When set to 0, defaults to n_samples (process all + at once). Only used by the batched (host-data) code path. + Default: 0 (auto). hierarchical : bool Whether to use hierarchical (balanced) kmeans or not hierarchical_n_iters : int @@ -109,7 +109,7 @@ cdef class KMeansParams: batch_samples=None, batch_centroids=None, inertia_check=None, - final_inertia_check=None, + batch_size=None, hierarchical=None, hierarchical_n_iters=None): if metric is not None: @@ -133,8 +133,8 @@ cdef class KMeansParams: self.params.batch_centroids = batch_centroids if inertia_check is not None: self.params.inertia_check = inertia_check - if final_inertia_check is not None: - self.params.final_inertia_check = final_inertia_check + if batch_size is not None: + self.params.batch_size = batch_size if hierarchical is not None: self.params.hierarchical = hierarchical if hierarchical_n_iters is not None: @@ -184,8 +184,8 @@ cdef class KMeansParams: return self.params.inertia_check @property - def final_inertia_check(self): - return self.params.final_inertia_check + def batch_size(self): + return self.params.batch_size @property def hierarchical(self): @@ -207,16 +207,26 @@ def fit( """ Find clusters with the k-means algorithm + When X is a device array (CUDA array interface), standard on-device + k-means is used. When X is a host array (numpy ndarray or + ``__array_interface__``), data is streamed to the GPU in batches + controlled by ``params.batch_size``. + Parameters ---------- params : KMeansParams - Parameters to use to fit KMeans model - X : Input CUDA array interface compliant matrix shape (m, k) + Parameters to use to fit KMeans model. For host data, + ``params.batch_size`` controls how many samples are sent to the + GPU per batch. + X : array-like + Training instances, shape (m, k). Accepts both device arrays + (cupy / CUDA array interface) and host arrays (numpy). centroids : Optional writable CUDA array interface compliant matrix shape (n_clusters, k) - sample_weights : Optional input CUDA array interface compliant matrix shape - (n_clusters, 1) default: None + sample_weights : Optional weights per observation. Must reside on + the same memory space as X (device or host). + default: None {resources_docstring} Returns @@ -244,10 +254,36 @@ def fit( >>> params = KMeansParams(n_clusters=n_clusters) >>> centroids, inertia, n_iter = fit(params, X) + + Host-data (batched) example: + + >>> import numpy as np + >>> X_host = np.random.random((10_000_000, 128)).astype(np.float32) + >>> params = KMeansParams(n_clusters=1000, batch_size=1_000_000) + >>> centroids, inertia, n_iter = fit(params, X_host) """ + # ---- detect host vs device data for data-preparation ---- + is_host = isinstance(X, np.ndarray) or ( + hasattr(X, '__array_interface__') and + not hasattr(X, '__cuda_array_interface__') + ) + + if is_host: + if not isinstance(X, np.ndarray): + X = np.asarray(X) + if not X.flags['C_CONTIGUOUS']: + X = np.ascontiguousarray(X) + if sample_weights is not None: + if not isinstance(sample_weights, np.ndarray): + sample_weights = np.asarray(sample_weights) + if not sample_weights.flags['C_CONTIGUOUS']: + sample_weights = np.ascontiguousarray(sample_weights) + x_ai = wrap_array(X) - _check_input_array(x_ai, [np.dtype('float32'), np.dtype('float64')]) + _check_input_array( + x_ai, [np.dtype('float32'), np.dtype('float64')] + ) cdef cydlpack.DLManagedTensor* x_dlpack = cydlpack.dlpack_c(x_ai) cdef cydlpack.DLManagedTensor* sample_weight_dlpack = NULL @@ -255,18 +291,20 @@ def fit( cdef cuvsResources_t res = resources.get_c_obj() cdef double inertia = 0 - cdef int n_iter = 0 + cdef int64_t n_iter = 0 if centroids is None: - centroids = device_ndarray.empty((params.n_clusters, x_ai.shape[1]), - dtype=x_ai.dtype) + centroids = device_ndarray.empty( + (params.n_clusters, x_ai.shape[1]), dtype=x_ai.dtype + ) centroids_ai = wrap_array(centroids) - cdef cydlpack.DLManagedTensor * centroids_dlpack = \ + cdef cydlpack.DLManagedTensor* centroids_dlpack = \ cydlpack.dlpack_c(centroids_ai) if sample_weights is not None: - sample_weight_dlpack = cydlpack.dlpack_c(wrap_array(sample_weights)) + sample_weight_dlpack = \ + cydlpack.dlpack_c(wrap_array(sample_weights)) with cuda_interruptible(): check_cuvs(cuvsKMeansFit( @@ -436,108 +474,3 @@ def cluster_cost(X, centroids, resources=None): return inertia -@auto_sync_resources -@auto_convert_output -def fit_batched( - KMeansParams params, X, batch_size, centroids=None, sample_weights=None, - resources=None -): - """ - Find clusters with the k-means algorithm using batched processing. - - This function processes data from HOST memory in batches, streaming - to the GPU. Useful when the dataset is too large to fit in GPU memory. - - Parameters - ---------- - - params : KMeansParams - Parameters to use to fit KMeans model - X : numpy array or array with __array_interface__ - Input HOST memory array shape (n_samples, n_features). - Must be C-contiguous. Supported dtypes: float32, float64. - batch_size : int - Number of samples to process per batch. Recommended: 500K-2M - depending on GPU memory. - centroids : Optional writable CUDA array interface compliant matrix - shape (n_clusters, n_features) - sample_weights : Optional input HOST memory array shape (n_samples,) - default: None - {resources_docstring} - - Returns - ------- - centroids : raft.device_ndarray - The computed centroids for each cluster (on device) - inertia : float - Sum of squared distances of samples to their closest cluster center - n_iter : int - The number of iterations used to fit the model - - Examples - -------- - - >>> import numpy as np - >>> import cupy as cp - >>> - >>> from cuvs.cluster.kmeans import fit_batched, KMeansParams - >>> - >>> n_samples = 10_000_000 - >>> n_features = 128 - >>> n_clusters = 1000 - >>> - >>> # Data on host (numpy array) - >>> X = np.random.random((n_samples, n_features)).astype(np.float32) - >>> - >>> params = KMeansParams(n_clusters=n_clusters, max_iter=20) - >>> centroids, inertia, n_iter = fit_batched(params, X, batch_size=1_000_000) - """ - # Ensure X is a numpy array (host memory) - if not isinstance(X, np.ndarray): - X = np.asarray(X) - - if not X.flags['C_CONTIGUOUS']: - X = np.ascontiguousarray(X) - - _check_input_array(wrap_array(X), [np.dtype('float32'), np.dtype('float64')]) - - cdef int64_t n_samples = X.shape[0] - cdef int64_t n_features = X.shape[1] - - # Create DLPack tensor for host data - cdef cydlpack.DLManagedTensor* x_dlpack = cydlpack.dlpack_c(wrap_array(X)) - cdef cydlpack.DLManagedTensor* sample_weight_dlpack = NULL - - cdef cuvsResources_t res = resources.get_c_obj() - - cdef double inertia = 0 - cdef int64_t n_iter = 0 - cdef int64_t c_batch_size = batch_size - - if centroids is None: - centroids = device_ndarray.empty((params.n_clusters, n_features), - dtype=X.dtype) - - centroids_ai = wrap_array(centroids) - cdef cydlpack.DLManagedTensor* centroids_dlpack = \ - cydlpack.dlpack_c(centroids_ai) - - if sample_weights is not None: - if not isinstance(sample_weights, np.ndarray): - sample_weights = np.asarray(sample_weights) - if not sample_weights.flags['C_CONTIGUOUS']: - sample_weights = np.ascontiguousarray(sample_weights) - sample_weight_dlpack = cydlpack.dlpack_c(wrap_array(sample_weights)) - - with cuda_interruptible(): - check_cuvs(cuvsKMeansFitBatched( - res, - params.params, - x_dlpack, - c_batch_size, - sample_weight_dlpack, - centroids_dlpack, - &inertia, - &n_iter)) - - return FitOutput(centroids, inertia, n_iter) From d629ca8c0927c3814a3959821875140d9420cffd Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 5 Mar 2026 17:13:25 -0800 Subject: [PATCH 57/81] fix python tests, style --- cpp/tests/cluster/kmeans.cu | 2 +- python/cuvs/cuvs/cluster/kmeans/__init__.py | 2 +- python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 2 -- python/cuvs/cuvs/tests/test_kmeans.py | 25 ++++++++++++--------- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index 6c0b6e2c00..e1bf7a8ea5 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -616,7 +616,7 @@ class KmeansPredictBatchedTest : public ::testing::TestWithParam auto h_labels_view = raft::make_host_vector_view(h_labels.data(), (int64_t)n_samples); - T pred_inertia = 0; + T pred_inertia = 0; params.batched.batch_size = std::min((int64_t)n_samples, (int64_t)256); cuvs::cluster::kmeans::predict( diff --git a/python/cuvs/cuvs/cluster/kmeans/__init__.py b/python/cuvs/cuvs/cluster/kmeans/__init__.py index b547ec72f4..f4765bcb64 100644 --- a/python/cuvs/cuvs/cluster/kmeans/__init__.py +++ b/python/cuvs/cuvs/cluster/kmeans/__init__.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index 7b8d755dd1..e532c78da1 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -472,5 +472,3 @@ def cluster_cost(X, centroids, resources=None): &inertia)) return inertia - - diff --git a/python/cuvs/cuvs/tests/test_kmeans.py b/python/cuvs/cuvs/tests/test_kmeans.py index c464666901..d614cb69ca 100644 --- a/python/cuvs/cuvs/tests/test_kmeans.py +++ b/python/cuvs/cuvs/tests/test_kmeans.py @@ -10,7 +10,6 @@ KMeansParams, cluster_cost, fit, - fit_batched, predict, ) from cuvs.distance import pairwise_distance @@ -82,12 +81,12 @@ def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): @pytest.mark.parametrize("n_clusters", [8, 16]) @pytest.mark.parametrize("batch_size", [100, 500]) @pytest.mark.parametrize("dtype", [np.float64]) -def test_fit_batched_matches_fit( +def test_fit_host_matches_fit_device( n_rows, n_cols, n_clusters, batch_size, dtype ): """ - Test that fit_batched FullBatch produces the same centroids as regular fit - when given the same initial centroids. + Test that fit() with host (numpy) data produces the same centroids as + fit() with device data, when given the same initial centroids. """ rng = np.random.default_rng(99) X_host = rng.random((n_rows, n_cols)).astype(dtype) @@ -98,23 +97,29 @@ def test_fit_batched_matches_fit( initial_centroids_host = X_host[:n_clusters].copy() - params = KMeansParams( + params_device = KMeansParams( n_clusters=n_clusters, init_method="Array", max_iter=100, tol=1e-10, ) centroids_regular, _, _ = fit( - params, + params_device, device_ndarray(X_host), device_ndarray(initial_centroids_host.copy()), ) centroids_regular = centroids_regular.copy_to_host() - centroids_batched, _, _ = fit_batched( - params, - X_host, + params_host = KMeansParams( + n_clusters=n_clusters, + init_method="Array", + max_iter=100, + tol=1e-10, batch_size=batch_size, + ) + centroids_batched, _, _ = fit( + params_host, + X_host, centroids=device_ndarray(initial_centroids_host.copy()), ) centroids_batched = centroids_batched.copy_to_host() @@ -122,5 +127,3 @@ def test_fit_batched_matches_fit( assert np.allclose( centroids_regular, centroids_batched, rtol=1e-3, atol=1e-3 ), f"max diff: {np.max(np.abs(centroids_regular - centroids_batched))}" - - From c8ac47717dbcd6d7585ad2ce2c041de37ba450ab Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 5 Mar 2026 17:18:41 -0800 Subject: [PATCH 58/81] fix style --- cpp/include/cuvs/cluster/kmeans.hpp | 1 - cpp/src/cluster/detail/kmeans_batched.cuh | 12 +++--------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index b07d02ace4..703281d187 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -50,7 +50,6 @@ struct params : base_params { Array }; - /** * The number of clusters to form as well as the number of centroids to generate (default:8). */ diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index c4e5fb278e..bb30c7952a 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -604,19 +604,13 @@ void fit_predict(raft::resources const& handle, raft::host_scalar_view n_iter) { T fit_inertia = 0; - fit(handle, - params, - X, - sample_weight, - centroids, - raft::make_host_scalar_view(&fit_inertia), - n_iter); + fit( + handle, params, X, sample_weight, centroids, raft::make_host_scalar_view(&fit_inertia), n_iter); auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), centroids.extent(0), centroids.extent(1)); - predict( - handle, params, X, sample_weight, centroids_const, labels, false, inertia); + predict(handle, params, X, sample_weight, centroids_const, labels, false, inertia); } } // namespace cuvs::cluster::kmeans::detail From 13b4084896df4fccb3edb94892adb0e004877f9d Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 5 Mar 2026 17:27:19 -0800 Subject: [PATCH 59/81] rm extra c helpers --- c/src/cluster/kmeans.cpp | 194 +++++++++++++++++++-------------------- 1 file changed, 93 insertions(+), 101 deletions(-) diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index a6c912af61..e73a2966b6 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -40,55 +40,112 @@ cuvs::cluster::kmeans::balanced_params convert_balanced_params(const cuvsKMeansP return kmeans_params; } -template -void _fit_device(cuvsResources_t res, - const cuvsKMeansParams& params, - DLManagedTensor* X_tensor, - DLManagedTensor* sample_weight_tensor, - DLManagedTensor* centroids_tensor, - double* inertia, - int64_t* n_iter) +template +void _fit(cuvsResources_t res, + const cuvsKMeansParams& params, + DLManagedTensor* X_tensor, + DLManagedTensor* sample_weight_tensor, + DLManagedTensor* centroids_tensor, + double* inertia, + int64_t* n_iter) { + auto X = X_tensor->dl_tensor; auto res_ptr = reinterpret_cast(res); + bool is_host = (X.device.device_type == kDLCPU); - using const_mdspan_type = raft::device_matrix_view; - using mdspan_type = raft::device_matrix_view; + if (is_host) { + // ---- host / batched path (IdxT = int64_t) ---- + using IdxT = int64_t; + auto n_samples = static_cast(X.shape[0]); + auto n_features = static_cast(X.shape[1]); - if (params.hierarchical) { - if (sample_weight_tensor != NULL) { - RAFT_FAIL("sample_weight cannot be used with hierarchical kmeans"); + if (params.hierarchical) { + RAFT_FAIL("hierarchical kmeans is not supported with host data"); } - if constexpr (std::is_same_v) { - RAFT_FAIL("float64 is an unsupported dtype for hierarchical kmeans"); - } else { - auto kmeans_params = convert_balanced_params(params); - T inertia_temp; - auto inertia_view = raft::make_host_scalar_view(&inertia_temp); - cuvs::cluster::kmeans::fit(*res_ptr, kmeans_params, cuvs::core::from_dlpack(X_tensor), cuvs::core::from_dlpack(centroids_tensor), std::make_optional(inertia_view)); - *inertia = inertia_temp; - *n_iter = params.hierarchical_n_iters; + auto centroids_dl = centroids_tensor->dl_tensor; + if (!cuvs::core::is_dlpack_device_compatible(centroids_dl)) { + RAFT_FAIL("centroids must be on device memory"); } - } else { - T inertia_temp; - IdxT n_iter_temp; - std::optional> sample_weight; + auto X_view = raft::make_host_matrix_view( + reinterpret_cast(X.data), n_samples, n_features); + auto centroids_view = + cuvs::core::from_dlpack>( + centroids_tensor); + + std::optional> sample_weight; if (sample_weight_tensor != NULL) { - sample_weight = - cuvs::core::from_dlpack>(sample_weight_tensor); + auto sw = sample_weight_tensor->dl_tensor; + if (sw.device.device_type != kDLCPU) { + RAFT_FAIL("sample_weight must be on host memory when X is on host"); + } + sample_weight = raft::make_host_vector_view( + reinterpret_cast(sw.data), n_samples); } + T inertia_temp; + IdxT n_iter_temp; + auto kmeans_params = convert_params(params); cuvs::cluster::kmeans::fit(*res_ptr, kmeans_params, - cuvs::core::from_dlpack(X_tensor), + X_view, sample_weight, - cuvs::core::from_dlpack(centroids_tensor), - raft::make_host_scalar_view(&inertia_temp), - raft::make_host_scalar_view(&n_iter_temp)); + centroids_view, + raft::make_host_scalar_view(&inertia_temp), + raft::make_host_scalar_view(&n_iter_temp)); + *inertia = inertia_temp; - *n_iter = static_cast(n_iter_temp); + *n_iter = n_iter_temp; + + } else { + // ---- device path (IdxT = int32_t) ---- + using IdxT = int32_t; + using const_mdspan_type = raft::device_matrix_view; + using mdspan_type = raft::device_matrix_view; + + if (params.hierarchical) { + if (sample_weight_tensor != NULL) { + RAFT_FAIL("sample_weight cannot be used with hierarchical kmeans"); + } + + if constexpr (std::is_same_v) { + RAFT_FAIL("float64 is an unsupported dtype for hierarchical kmeans"); + } else { + auto kmeans_params = convert_balanced_params(params); + T inertia_temp; + auto inertia_view = raft::make_host_scalar_view(&inertia_temp); + cuvs::cluster::kmeans::fit( + *res_ptr, + kmeans_params, + cuvs::core::from_dlpack(X_tensor), + cuvs::core::from_dlpack(centroids_tensor), + std::make_optional(inertia_view)); + *inertia = inertia_temp; + *n_iter = params.hierarchical_n_iters; + } + } else { + T inertia_temp; + IdxT n_iter_temp; + + std::optional> sample_weight; + if (sample_weight_tensor != NULL) { + sample_weight = + cuvs::core::from_dlpack>(sample_weight_tensor); + } + + auto kmeans_params = convert_params(params); + cuvs::cluster::kmeans::fit(*res_ptr, + kmeans_params, + cuvs::core::from_dlpack(X_tensor), + sample_weight, + cuvs::core::from_dlpack(centroids_tensor), + raft::make_host_scalar_view(&inertia_temp), + raft::make_host_scalar_view(&n_iter_temp)); + *inertia = inertia_temp; + *n_iter = static_cast(n_iter_temp); + } } } @@ -172,63 +229,6 @@ void _cluster_cost(cuvsResources_t res, *cost = cost_temp; } - -template -void _fit_host(cuvsResources_t res, - const cuvsKMeansParams& params, - DLManagedTensor* X_tensor, - DLManagedTensor* sample_weight_tensor, - DLManagedTensor* centroids_tensor, - double* inertia, - IdxT* n_iter) -{ - auto X = X_tensor->dl_tensor; - auto centroids = centroids_tensor->dl_tensor; - auto res_ptr = reinterpret_cast(res); - auto n_samples = static_cast(X.shape[0]); - auto n_features = static_cast(X.shape[1]); - - // X must be on host (CPU) memory - if (X.device.device_type != kDLCPU) { - RAFT_FAIL("X dataset must be on host (CPU) memory for batched fit"); - } - - // centroids must be on device memory - if (!cuvs::core::is_dlpack_device_compatible(centroids)) { - RAFT_FAIL("centroids must be on device memory"); - } - - // Create host matrix view from X - auto X_view = raft::make_host_matrix_view( - reinterpret_cast(X.data), n_samples, n_features); - - auto centroids_view = cuvs::core::from_dlpack>(centroids_tensor); - - std::optional> sample_weight; - if (sample_weight_tensor != NULL) { - auto sw = sample_weight_tensor->dl_tensor; - if (sw.device.device_type != kDLCPU) { - RAFT_FAIL("sample_weight must be on host (CPU) memory for batched fit"); - } - sample_weight = raft::make_host_vector_view( - reinterpret_cast(sw.data), n_samples); - } - - T inertia_temp; - IdxT n_iter_temp; - - auto kmeans_params = convert_params(params); - cuvs::cluster::kmeans::fit(*res_ptr, - kmeans_params, - X_view, - sample_weight, - centroids_view, - raft::make_host_scalar_view(&inertia_temp), - raft::make_host_scalar_view(&n_iter_temp)); - - *inertia = inertia_temp; - *n_iter = n_iter_temp; -} } // namespace extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params) @@ -267,19 +267,11 @@ extern "C" cuvsError_t cuvsKMeansFit(cuvsResources_t res, int64_t* n_iter) { return cuvs::core::translate_exceptions([=] { - auto dataset = X->dl_tensor; - bool is_host = (dataset.device.device_type == kDLCPU); - + auto dataset = X->dl_tensor; if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { - if (is_host) - _fit_host(res, *params, X, sample_weight, centroids, inertia, n_iter); - else - _fit_device(res, *params, X, sample_weight, centroids, inertia, n_iter); + _fit(res, *params, X, sample_weight, centroids, inertia, n_iter); } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { - if (is_host) - _fit_host(res, *params, X, sample_weight, centroids, inertia, n_iter); - else - _fit_device(res, *params, X, sample_weight, centroids, inertia, n_iter); + _fit(res, *params, X, sample_weight, centroids, inertia, n_iter); } else { RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", dataset.dtype.code, From 0bb59a911cd930f4e23c487f7aa89c64285835e6 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 5 Mar 2026 17:45:23 -0800 Subject: [PATCH 60/81] add eof --- c/src/cluster/kmeans.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index e73a2966b6..356df657e8 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -322,4 +322,3 @@ extern "C" cuvsError_t cuvsKMeansClusterCost(cuvsResources_t res, } }); } - From fa7715123ea4705849cd8699b808e9a69d32234c Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 6 Mar 2026 11:42:15 -0800 Subject: [PATCH 61/81] fix docs --- cpp/include/cuvs/cluster/kmeans.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 703281d187..52c3e96fa3 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -114,6 +114,7 @@ struct params : base_params { bool inertia_check = false; /** + * @struct batched_params * Parameters specific to batched k-means (host-data overloads of fit/predict). * These parameters are only used when calling fit() or predict() with host_matrix_view data * and are ignored by the device_matrix_view overloads. From 068d66f665c0743a5347e99ea12288978ec313fe Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 9 Mar 2026 14:06:09 -0700 Subject: [PATCH 62/81] address pr reviews; update inertia comp --- cpp/src/cluster/detail/kmeans_batched.cuh | 97 ++++++----------------- 1 file changed, 23 insertions(+), 74 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index bb30c7952a..ca75a82b3c 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include #include @@ -38,64 +40,7 @@ namespace cuvs::cluster::kmeans::detail { /** - * @brief Sample data from host to device for initialization/validation. - * - * When sample weights are provided, the corresponding weights are gathered - * alongside the sampled rows and copied to device - * - * @tparam T Input data type - * @tparam IdxT Index type - */ -template -void prepare_init_sample( - raft::resources const& handle, - raft::host_matrix_view X, - raft::device_matrix_view X_sample, - uint64_t seed, - std::optional> weight_in = std::nullopt, - std::optional> weight_out = std::nullopt) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_samples_out = X_sample.extent(0); - - std::mt19937 gen(seed); - std::uniform_int_distribution dist(0, n_samples - 1); - - // Generate n_samples_out unique random indices using rejection sampling - // Since n_samples_out << n_samples, collisions are rare and this is O(n_samples_out) - std::unordered_set selected_indices; - selected_indices.reserve(n_samples_out); - - while (static_cast(selected_indices.size()) < n_samples_out) { - selected_indices.insert(dist(gen)); - } - - std::vector indices(selected_indices.begin(), selected_indices.end()); - - bool copy_weights = weight_in.has_value() && weight_out.has_value(); - std::vector host_weights; - if (copy_weights) { host_weights.resize(n_samples_out); } - - std::vector host_sample(n_samples_out * n_features); -#pragma omp parallel for - for (IdxT i = 0; i < static_cast(n_samples_out); i++) { - IdxT src_idx = indices[i]; - std::memcpy(host_sample.data() + i * n_features, - X.data_handle() + src_idx * n_features, - n_features * sizeof(T)); - if (copy_weights) { host_weights[i] = weight_in->data_handle()[src_idx]; } - } - - raft::copy(X_sample.data_handle(), host_sample.data(), host_sample.size(), stream); - if (copy_weights) { - raft::copy(weight_out->data_handle(), host_weights.data(), n_samples_out, stream); - } -} - -/** - * @brief Initialize centroids using k-means++ on a sample of the host data + * @brief Initialize centroids from host data * * @tparam T Input data type * @tparam IdxT Index type @@ -113,14 +58,15 @@ void init_centroids_from_host_sample(raft::resources const& handle, auto n_features = X.extent(1); auto n_clusters = params.n_clusters; - IdxT init_sample_size = 3 * batch_size; - if (init_sample_size < n_clusters) { init_sample_size = 3 * n_clusters; } - init_sample_size = std::min(init_sample_size, n_samples); + if (params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { + IdxT init_sample_size = 3 * batch_size; + if (init_sample_size < n_clusters) { init_sample_size = 3 * n_clusters; } + init_sample_size = std::min(init_sample_size, n_samples); - auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); - prepare_init_sample(handle, X, init_sample.view(), params.rng_state.seed); + auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); + raft::random::RngState random_state(params.rng_state.seed); + raft::matrix::sample_rows(handle, random_state, X, init_sample.view()); - if (params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { auto init_sample_view = raft::make_device_matrix_view( init_sample.data_handle(), init_sample_size, n_features); @@ -132,7 +78,8 @@ void init_centroids_from_host_sample(raft::resources const& handle, handle, params, init_sample_view, centroids, workspace); } } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { - raft::copy(centroids.data_handle(), init_sample.data_handle(), n_clusters * n_features, stream); + raft::random::RngState random_state(params.rng_state.seed); + raft::matrix::sample_rows(handle, random_state, X, centroids); } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { // already provided } else { @@ -365,8 +312,7 @@ void fit(raft::resources const& handle, raft::matrix::fill(handle, centroid_sums.view(), T{0}); raft::matrix::fill(handle, cluster_counts.view(), T{0}); - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), n_clusters, n_features); + auto centroids_const = raft::make_const_mdspan(centroids); using namespace cuvs::spatial::knn::detail::utils; batch_load_iterator data_batches( @@ -423,6 +369,7 @@ void fit(raft::resources const& handle, cluster_counts.view()); if (params.inertia_check) { + // Compute cluster cost for this batch and accumulate cuvs::cluster::kmeans::detail::computeClusterCost( handle, minClusterAndDistance.view(), @@ -472,12 +419,15 @@ void fit(raft::resources const& handle, } } - // Always compute final inertia (like regular kmeans) { - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), n_clusters, n_features); - inertia[0] = compute_batched_host_inertia( - handle, X, batch_size, centroids_const, sample_weight); + // If inertia_check was enabled, we already computed inertia during iterations + if (params.inertia_check) { + inertia[0] = priorClusteringCost; + } else { + auto centroids_const = raft::make_const_mdspan(centroids); + inertia[0] = compute_batched_host_inertia( + handle, X, batch_size, centroids_const, sample_weight); + } RAFT_LOG_DEBUG("KMeans batched: n_init %d/%d completed with inertia=%f", seed_iter + 1, @@ -607,8 +557,7 @@ void fit_predict(raft::resources const& handle, fit( handle, params, X, sample_weight, centroids, raft::make_host_scalar_view(&fit_inertia), n_iter); - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)); + auto centroids_const = raft::make_const_mdspan(centroids); predict(handle, params, X, sample_weight, centroids_const, labels, false, inertia); } From 8e0be372272abdbabbaf5a7795b0b8bbf296115d Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 9 Mar 2026 14:18:15 -0700 Subject: [PATCH 63/81] revert abi change --- c/include/cuvs/cluster/kmeans.h | 2 +- c/src/cluster/kmeans.cpp | 10 ++++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index 62ab09ec59..5057d1a86b 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -179,7 +179,7 @@ cuvsError_t cuvsKMeansFit(cuvsResources_t res, DLManagedTensor* sample_weight, DLManagedTensor* centroids, double* inertia, - int64_t* n_iter); + int* n_iter); /** * @brief Predict the closest cluster each sample in X belongs to. diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index 356df657e8..89fe912048 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -40,14 +40,14 @@ cuvs::cluster::kmeans::balanced_params convert_balanced_params(const cuvsKMeansP return kmeans_params; } -template +template void _fit(cuvsResources_t res, const cuvsKMeansParams& params, DLManagedTensor* X_tensor, DLManagedTensor* sample_weight_tensor, DLManagedTensor* centroids_tensor, double* inertia, - int64_t* n_iter) + int* n_iter) { auto X = X_tensor->dl_tensor; auto res_ptr = reinterpret_cast(res); @@ -100,8 +100,6 @@ void _fit(cuvsResources_t res, *n_iter = n_iter_temp; } else { - // ---- device path (IdxT = int32_t) ---- - using IdxT = int32_t; using const_mdspan_type = raft::device_matrix_view; using mdspan_type = raft::device_matrix_view; @@ -144,7 +142,7 @@ void _fit(cuvsResources_t res, raft::make_host_scalar_view(&inertia_temp), raft::make_host_scalar_view(&n_iter_temp)); *inertia = inertia_temp; - *n_iter = static_cast(n_iter_temp); + *n_iter = n_iter_temp; } } } @@ -264,7 +262,7 @@ extern "C" cuvsError_t cuvsKMeansFit(cuvsResources_t res, DLManagedTensor* sample_weight, DLManagedTensor* centroids, double* inertia, - int64_t* n_iter) + int* n_iter) { return cuvs::core::translate_exceptions([=] { auto dataset = X->dl_tensor; From 0a7f0261d3780c08f8ca65831d68ddcb52bc230c Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 9 Mar 2026 14:35:39 -0700 Subject: [PATCH 64/81] rm null dataset norm --- cpp/src/cluster/detail/kmeans_balanced.cuh | 11 +++++------ cpp/src/cluster/kmeans_balanced.cuh | 1 - 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index cab01221e7..5f2164f71a 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -1012,7 +1012,6 @@ auto build_fine_clusters(const raft::resources& handle, * @param[out] cluster_centers A device pointer to the found cluster centers [n_clusters, dim] * @param[in] n_clusters Requested number of clusters * @param[in] mapping_op Mapping operation from T to MathT - * @param[in] dataset_norm (optional) Pre-computed L2 norms of each row in the dataset [n_rows] * @param[out] inertia (optional) If non-null, the sum of squared distances of samples to * their closest cluster center is written here. * Only supported when T == MathT (float/double). @@ -1026,8 +1025,7 @@ void build_hierarchical(const raft::resources& handle, MathT* cluster_centers, IdxT n_clusters, MappingOpT mapping_op, - const MathT* dataset_norm = nullptr, - MathT* inertia = nullptr) + MathT* inertia = nullptr) { auto stream = raft::resource::get_cuda_stream(handle); using LabelT = uint32_t; @@ -1046,9 +1044,10 @@ void build_hierarchical(const raft::resources& handle, // Precompute the L2 norm of the dataset if relevant and not yet computed. rmm::device_uvector dataset_norm_buf(0, stream, device_memory); - if (dataset_norm == nullptr && (params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || - params.metric == cuvs::distance::DistanceType::CosineExpanded)) { + const MathT* dataset_norm = nullptr; + if ((params.metric == cuvs::distance::DistanceType::L2Expanded || + params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || + params.metric == cuvs::distance::DistanceType::CosineExpanded)) { dataset_norm_buf.resize(n_rows, stream); for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { IdxT minibatch_size = std::min(max_minibatch_size, n_rows - offset); diff --git a/cpp/src/cluster/kmeans_balanced.cuh b/cpp/src/cluster/kmeans_balanced.cuh index 2b250b92cf..0c0df03397 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -91,7 +91,6 @@ void fit(const raft::resources& handle, centroids.data_handle(), centroids.extent(0), mapping_op, - static_cast(nullptr), inertia_ptr); } From 10be5c41650bdd64eceff14b47b3b9426c9cc0f4 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 9 Mar 2026 14:40:45 -0700 Subject: [PATCH 65/81] add warning when T and MathT are different --- cpp/src/cluster/detail/kmeans_balanced.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 5f2164f71a..f5dc759725 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -1175,6 +1175,8 @@ void build_hierarchical(const raft::resources& handle, raft::make_device_matrix_view(cluster_centers, n_clusters, dim); cuvs::cluster::kmeans::cluster_cost( handle, X_view, centroids_view, raft::make_host_scalar_view(inertia)); + } else { + RAFT_LOG_WARN("Inertia is not computed for non float/double types"); } } } From 6a2a681c3fc48c170102be776f4c5f9ca3b35220 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 9 Mar 2026 14:54:12 -0700 Subject: [PATCH 66/81] use raft::mul_op --- cpp/src/cluster/kmeans.cuh | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/cpp/src/cluster/kmeans.cuh b/cpp/src/cluster/kmeans.cuh index 8e26372b46..e4f9821990 100644 --- a/cpp/src/cluster/kmeans.cuh +++ b/cpp/src/cluster/kmeans.cuh @@ -455,12 +455,11 @@ void cluster_cost( // Apply sample weights if provided if (sample_weight.has_value()) { - raft::linalg::map( - handle, - min_cluster_distance.view(), - [] __device__(DataT d, DataT w) { return d * w; }, - raft::make_const_mdspan(min_cluster_distance.view()), - sample_weight.value()); + raft::linalg::map(handle, + min_cluster_distance.view(), + raft::mul_op{}, + raft::make_const_mdspan(min_cluster_distance.view()), + sample_weight.value()); } auto device_cost = raft::make_device_scalar(handle, DataT(0)); From e7a9b3a7f5120c603a62790a0dce51ca0e991348 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 9 Mar 2026 15:13:14 -0700 Subject: [PATCH 67/81] put batch size at the end of the c header struct --- c/include/cuvs/cluster/kmeans.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index 5057d1a86b..7296c1109f 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -94,12 +94,6 @@ struct cuvsKMeansParams { /** Check inertia during iterations for early convergence. */ bool inertia_check; - /** - * Number of samples to process per GPU batch for the batched (host-data) API. - * When set to 0, defaults to n_samples (process all at once). - */ - int64_t batch_size; - /** * Whether to use hierarchical (balanced) kmeans or not */ @@ -109,6 +103,12 @@ struct cuvsKMeansParams { * For hierarchical k-means , defines the number of training iterations */ int hierarchical_n_iters; + + /** + * Number of samples to process per GPU batch for the batched (host-data) API. + * When set to 0, defaults to n_samples (process all at once). + */ + int64_t batch_size; }; typedef struct cuvsKMeansParams* cuvsKMeansParams_t; From 6439d72293ab53b77f7975d211afe441d2c99025 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 9 Mar 2026 16:44:43 -0700 Subject: [PATCH 68/81] fix c and python compilation --- c/src/cluster/kmeans.cpp | 6 ++---- python/cuvs/cuvs/cluster/kmeans/kmeans.pxd | 2 +- python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index 89fe912048..8997ef8bb4 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -54,8 +54,6 @@ void _fit(cuvsResources_t res, bool is_host = (X.device.device_type == kDLCPU); if (is_host) { - // ---- host / batched path (IdxT = int64_t) ---- - using IdxT = int64_t; auto n_samples = static_cast(X.shape[0]); auto n_features = static_cast(X.shape[1]); @@ -245,9 +243,9 @@ extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params) .batch_samples = cpp_params.batch_samples, .batch_centroids = cpp_params.batch_centroids, .inertia_check = cpp_params.inertia_check, - .batch_size = cpp_params.batched.batch_size, .hierarchical = false, - .hierarchical_n_iters = static_cast(cpp_balanced_params.n_iters)}; + .hierarchical_n_iters = static_cast(cpp_balanced_params.n_iters), + .batch_size = cpp_params.batched.batch_size}; }); } diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd index 2d7294a4ef..3b90526417 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -49,7 +49,7 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: DLManagedTensor* sample_weight, DLManagedTensor * centroids, double * inertia, - int64_t * n_iter) except + + int * n_iter) except + cuvsError_t cuvsKMeansPredict(cuvsResources_t res, cuvsKMeansParams_t params, diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index e532c78da1..c594199511 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -291,7 +291,7 @@ def fit( cdef cuvsResources_t res = resources.get_c_obj() cdef double inertia = 0 - cdef int64_t n_iter = 0 + cdef int n_iter = 0 if centroids is None: centroids = device_ndarray.empty( From 6a094cbd56ecfabee4b46ac15cc8305f18f7b93b Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 9 Mar 2026 16:47:52 -0700 Subject: [PATCH 69/81] add docs --- cpp/include/cuvs/cluster/kmeans.hpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 52c3e96fa3..f7a4e1db0b 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -100,6 +100,10 @@ struct params : base_params { * useful to optimize/control the memory footprint * Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0 * then don't tile the centroids + * + * NB: These parameters are unrelated to batched_params.batch_size, which + * controls how many samples to transfer from host to device per batch when + * processing out-of-core data. */ int batch_samples = 1 << 15; From 2d14e9a563dcdf5420692d5c23a2d49263ab009b Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 9 Mar 2026 17:49:25 -0700 Subject: [PATCH 70/81] style --- cpp/include/cuvs/cluster/kmeans.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index f7a4e1db0b..4e41ef2a52 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -100,7 +100,7 @@ struct params : base_params { * useful to optimize/control the memory footprint * Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0 * then don't tile the centroids - * + * * NB: These parameters are unrelated to batched_params.batch_size, which * controls how many samples to transfer from host to device per batch when * processing out-of-core data. From 37ce40434ff1d1d6200e438a932e6f4222dab322 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 12 Mar 2026 10:04:34 -0700 Subject: [PATCH 71/81] correct treatment for optional --- cpp/src/cluster/detail/kmeans_batched.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index ca75a82b3c..f72c6a9454 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -181,7 +181,7 @@ T compute_batched_host_inertia( // Build optional device weight view for this batch std::optional> batch_weight_view; - if (sample_weight) { + if (sample_weight.has_value()) { auto weight_offset = static_cast(data_batch.offset()); raft::copy(batch_weights.data_handle(), sample_weight->data_handle() + weight_offset, @@ -326,7 +326,7 @@ void fit(raft::resources const& handle, auto batch_weights_fill_view = raft::make_device_vector_view(batch_weights.data_handle(), current_batch_size); - if (sample_weight) { + if (sample_weight.has_value()) { raft::copy(batch_weights.data_handle(), sample_weight->data_handle() + data_batch.offset(), current_batch_size, @@ -501,7 +501,7 @@ void predict(raft::resources const& handle, current_batch_size * n_features, stream); - if (sample_weight) { + if (sample_weight.has_value()) { raft::copy(batch_weights.data_handle(), sample_weight->data_handle() + batch_idx, current_batch_size, @@ -512,7 +512,7 @@ void predict(raft::resources const& handle, batch_data.data_handle(), current_batch_size, n_features); std::optional> batch_weights_view = std::nullopt; - if (sample_weight) { + if (sample_weight.has_value()) { batch_weights_view = std::make_optional(raft::make_device_vector_view( batch_weights.data_handle(), current_batch_size)); } From 1b3c3415eed4eff32f9fc3d87bc389684403e65c Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 12 Mar 2026 10:21:00 -0700 Subject: [PATCH 72/81] fill outside loop --- cpp/src/cluster/detail/kmeans_batched.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index f72c6a9454..99a25454fb 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -299,6 +299,8 @@ void fit(raft::resources const& handle, init_centroids_from_host_sample(handle, iter_params, batch_size, X, centroids, workspace); } + if (!sample_weight.has_value()) { raft::matrix::fill(handle, batch_weights.view(), T{1}); } + // Reset per-iteration state T priorClusteringCost = 0; @@ -331,8 +333,6 @@ void fit(raft::resources const& handle, sample_weight->data_handle() + data_batch.offset(), current_batch_size, stream); - } else { - raft::matrix::fill(handle, batch_weights_fill_view, T{1}); } auto batch_weights_view = raft::make_device_vector_view( From 914628cf237e09763d53b7fc676d0533db2ef367 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 12 Mar 2026 10:29:14 -0700 Subject: [PATCH 73/81] add warning --- cpp/src/cluster/detail/kmeans_batched.cuh | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 99a25454fb..56c00a6e0e 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -241,6 +241,17 @@ void fit(raft::resources const& handle, if (batch_size <= 0) { batch_size = static_cast(n_samples); } RAFT_EXPECTS(batch_size > 0, "batch_size must be positive"); + + // Warn if user explicitly set batch_size larger than dataset size + if (params.batched.batch_size > 0 && static_cast(params.batched.batch_size) > n_samples) { + RAFT_LOG_WARN( + "batch_size (%zu) is larger than dataset size (%zu). " + "batch_size will be effectively clamped to %zu.", + static_cast(params.batched.batch_size), + static_cast(n_samples), + static_cast(n_samples)); + } + RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); RAFT_EXPECTS(static_cast(centroids.extent(0)) == n_clusters, "centroids.extent(0) must equal n_clusters"); @@ -481,6 +492,17 @@ void predict(raft::resources const& handle, if (batch_size <= 0) { batch_size = static_cast(n_samples); } RAFT_EXPECTS(batch_size > 0, "batch_size must be positive"); + + // Warn if user explicitly set batch_size larger than dataset size + if (params.batched.batch_size > 0 && static_cast(params.batched.batch_size) > n_samples) { + RAFT_LOG_WARN( + "batch_size (%zu) is larger than dataset size (%zu). " + "batch_size will be effectively clamped to %zu.", + static_cast(params.batched.batch_size), + static_cast(n_samples), + static_cast(n_samples)); + } + RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); RAFT_EXPECTS(centroids.extent(0) == static_cast(n_clusters), "centroids.extent(0) must equal n_clusters"); From f1d3f8a34c70fc848546ab90be15954b9f12df49 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 12 Mar 2026 10:36:20 -0700 Subject: [PATCH 74/81] python docs --- python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index c594199511..1ef294051f 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -82,8 +82,12 @@ cdef class KMeansParams: batch_size : int Number of samples to process per GPU batch when fitting with host (numpy) data. When set to 0, defaults to n_samples (process all - at once). Only used by the batched (host-data) code path. - Default: 0 (auto). + at once). Only used by the batched (host-data) code path. Reducing + batch_size can help reduce GPU memory pressure but increases + overhead as the number of times centroid adjustments are computed + increases. + + Default: 0 (process all data at once). hierarchical : bool Whether to use hierarchical (balanced) kmeans or not hierarchical_n_iters : int @@ -210,7 +214,8 @@ def fit( When X is a device array (CUDA array interface), standard on-device k-means is used. When X is a host array (numpy ndarray or ``__array_interface__``), data is streamed to the GPU in batches - controlled by ``params.batch_size``. + controlled by ``params.batch_size``. For large host datasets, consider + reducing ``batch_size`` to reduce GPU memory usage. Parameters ---------- From 552f7360f5d8ba2d5f4956c5cd87c652888834e6 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 12 Mar 2026 13:40:06 -0700 Subject: [PATCH 75/81] fix compilation warning --- cpp/src/cluster/detail/kmeans_batched.cuh | 2 -- 1 file changed, 2 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 56c00a6e0e..8df3f10093 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -337,8 +337,6 @@ void fit(raft::resources const& handle, auto batch_data_view = raft::make_device_matrix_view( data_batch.data(), current_batch_size, n_features); - auto batch_weights_fill_view = - raft::make_device_vector_view(batch_weights.data_handle(), current_batch_size); if (sample_weight.has_value()) { raft::copy(batch_weights.data_handle(), sample_weight->data_handle() + data_batch.offset(), From 7f6e61576eeb2a9ad2e84d598339625af31813cf Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 13 Mar 2026 12:42:57 -0700 Subject: [PATCH 76/81] optimizations and cleanups --- cpp/src/cluster/detail/kmeans_batched.cuh | 111 ++++------------------ cpp/src/cluster/detail/kmeans_common.cuh | 6 +- 2 files changed, 20 insertions(+), 97 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 8df3f10093..4921f644f2 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -142,68 +142,6 @@ void accumulate_batch_centroids( stream); } -/** - * @brief Compute total inertia over host data using batched GPU processing. - * - * Iterates over the host data in batches, computing the (optionally weighted) - * sum of squared distances from each sample to its nearest centroid. - * - * @param[in] sample_weight Optional per-sample weights on host. When provided, - * each squared distance is multiplied by its weight - * before summing (matching sklearn's weighted inertia). - */ -template -T compute_batched_host_inertia( - raft::resources const& handle, - raft::host_matrix_view X, - IdxT batch_size, - raft::device_matrix_view centroids, - std::optional> sample_weight = std::nullopt) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - - IdxT effective_batch = std::min(batch_size, static_cast(n_samples)); - - // Device buffer for per-batch weights (only used when sample_weight is provided) - auto batch_weights = - raft::make_device_vector(handle, sample_weight ? effective_batch : IdxT{0}); - - T total_inertia = 0; - using namespace cuvs::spatial::knn::detail::utils; - batch_load_iterator data_batches(X.data_handle(), n_samples, n_features, batch_size, stream); - - for (const auto& data_batch : data_batches) { - IdxT current_batch_size = static_cast(data_batch.size()); - auto batch_view = raft::make_device_matrix_view( - data_batch.data(), current_batch_size, n_features); - - // Build optional device weight view for this batch - std::optional> batch_weight_view; - if (sample_weight.has_value()) { - auto weight_offset = static_cast(data_batch.offset()); - raft::copy(batch_weights.data_handle(), - sample_weight->data_handle() + weight_offset, - current_batch_size, - stream); - batch_weight_view = raft::make_device_vector_view(batch_weights.data_handle(), - current_batch_size); - } - - T batch_cost; - cuvs::cluster::kmeans::cluster_cost(handle, - batch_view, - centroids, - raft::make_host_scalar_view(&batch_cost), - batch_weight_view); - - total_inertia += batch_cost; - } - - return total_inertia; -} - /** * @brief Main fit function for batched k-means with host data (full-batch / Lloyd's algorithm). * @@ -276,7 +214,7 @@ void fit(raft::resources const& handle, n_init = 1; } - auto best_centroids = raft::make_device_matrix(handle, n_clusters, n_features); + auto best_centroids = n_init > 1 ? raft::make_device_matrix(handle, n_clusters, n_features) : raft::make_device_matrix(handle, 0, 0); T best_inertia = std::numeric_limits::max(); IdxT best_n_iter = 0; @@ -294,8 +232,6 @@ void fit(raft::resources const& handle, auto cluster_counts = raft::make_device_vector(handle, n_clusters); auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); - rmm::device_scalar clusterCostD(stream); - // ---- Main n_init loop ---- for (int seed_iter = 0; seed_iter < n_init; ++seed_iter) { cuvs::cluster::kmeans::params iter_params = params; @@ -313,17 +249,16 @@ void fit(raft::resources const& handle, if (!sample_weight.has_value()) { raft::matrix::fill(handle, batch_weights.view(), T{1}); } // Reset per-iteration state - T priorClusteringCost = 0; + T prior_cluster_cost = 0; for (n_iter[0] = 1; n_iter[0] <= iter_params.max_iter; ++n_iter[0]) { RAFT_LOG_DEBUG("KMeans batched: Iteration %d", n_iter[0]); raft::copy(new_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); - T total_cost = 0; - raft::matrix::fill(handle, centroid_sums.view(), T{0}); raft::matrix::fill(handle, cluster_counts.view(), T{0}); + auto cluster_cost = raft::make_device_scalar(handle, T{0}); auto centroids_const = raft::make_const_mdspan(centroids); @@ -377,19 +312,14 @@ void fit(raft::resources const& handle, centroid_sums.view(), cluster_counts.view()); - if (params.inertia_check) { + if (params.inertia_check || n_iter[0] == iter_params.max_iter) { // Compute cluster cost for this batch and accumulate - cuvs::cluster::kmeans::detail::computeClusterCost( - handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - auto clusterCost_host = raft::make_host_scalar(0); - raft::copy(clusterCost_host.data_handle(), clusterCostD.data(), 1, stream); - raft::resource::sync_stream(handle, stream); - total_cost += clusterCost_host.data_handle()[0]; + cuvs::cluster::kmeans::detail::computeClusterCost(handle, + minClusterAndDistance.view(), + workspace, + cluster_cost.view(), + raft::value_op{}, + raft::add_op{}); } } @@ -413,11 +343,13 @@ void fit(raft::resources const& handle, bool done = false; if (params.inertia_check) { + raft::copy(inertia.data_handle(), cluster_cost.data_handle(), 1, stream); + raft::resource::sync_stream(handle); if (n_iter[0] > 1) { - T delta = total_cost / priorClusteringCost; + T delta = inertia[0] / prior_cluster_cost; if (delta > 1 - params.tol) done = true; } - priorClusteringCost = total_cost; + prior_cluster_cost = inertia[0]; } if (sqrdNormError < params.tol) done = true; @@ -429,13 +361,10 @@ void fit(raft::resources const& handle, } { - // If inertia_check was enabled, we already computed inertia during iterations - if (params.inertia_check) { - inertia[0] = priorClusteringCost; - } else { - auto centroids_const = raft::make_const_mdspan(centroids); - inertia[0] = compute_batched_host_inertia( - handle, X, batch_size, centroids_const, sample_weight); + // Inertia for the last iteration is always computed + if (!params.inertia_check) { + raft::copy(inertia.data_handle(), cluster_cost.data_handle(), 1, stream); + raft::resource::sync_stream(handle); } RAFT_LOG_DEBUG("KMeans batched: n_init %d/%d completed with inertia=%f", @@ -443,17 +372,15 @@ void fit(raft::resources const& handle, n_init, static_cast(inertia[0])); - if (inertia[0] < best_inertia) { + if (n_init > 1 && inertia[0] < best_inertia) { best_inertia = inertia[0]; best_n_iter = n_iter[0]; raft::copy(best_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); - raft::resource::sync_stream(handle, stream); } } if (n_init > 1) { raft::copy(centroids.data_handle(), best_centroids.data_handle(), centroids.size(), stream); - raft::resource::sync_stream(handle, stream); inertia[0] = best_inertia; n_iter[0] = best_n_iter; RAFT_LOG_DEBUG("KMeans batched: Best of %d runs: inertia=%f, n_iter=%d", diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 24524b7207..a5160be677 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -552,12 +552,8 @@ void finalize_centroids(raft::resources const& handle, { cudaStream_t stream = raft::resource::get_cuda_stream(handle); - // new_centroids = centroid_sums / cluster_counts (0 when count is 0) - raft::copy( - new_centroids.data_handle(), centroid_sums.data_handle(), centroid_sums.size(), stream); - raft::linalg::matrix_vector_op(handle, - raft::make_const_mdspan(new_centroids), + raft::make_const_mdspan(centroid_sums), cluster_counts, new_centroids, raft::div_checkzero_op{}); From 0719636a27f6c4cc2e46cb6b2ffa4eda4038c142 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 13 Mar 2026 13:00:42 -0700 Subject: [PATCH 77/81] fix compilation --- cpp/src/cluster/detail/kmeans_batched.cuh | 38 +++++++++++------------ 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 4921f644f2..d6c4bd85af 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -258,7 +258,7 @@ void fit(raft::resources const& handle, raft::matrix::fill(handle, centroid_sums.view(), T{0}); raft::matrix::fill(handle, cluster_counts.view(), T{0}); - auto cluster_cost = raft::make_device_scalar(handle, T{0}); + auto clustering_cost = raft::make_device_scalar(handle, T{0}); auto centroids_const = raft::make_const_mdspan(centroids); @@ -317,7 +317,7 @@ void fit(raft::resources const& handle, cuvs::cluster::kmeans::detail::computeClusterCost(handle, minClusterAndDistance.view(), workspace, - cluster_cost.view(), + clustering_cost.view(), raft::value_op{}, raft::add_op{}); } @@ -343,7 +343,7 @@ void fit(raft::resources const& handle, bool done = false; if (params.inertia_check) { - raft::copy(inertia.data_handle(), cluster_cost.data_handle(), 1, stream); + raft::copy(inertia.data_handle(), clustering_cost.data_handle(), 1, stream); raft::resource::sync_stream(handle); if (n_iter[0] > 1) { T delta = inertia[0] / prior_cluster_cost; @@ -354,19 +354,18 @@ void fit(raft::resources const& handle, if (sqrdNormError < params.tol) done = true; - if (done) { + if (done || n_iter[0] == iter_params.max_iter) { RAFT_LOG_DEBUG("KMeans batched: Converged after %d iterations", n_iter[0]); + // Inertia for the last iteration is always computed + if (!params.inertia_check) { + raft::copy(inertia.data_handle(), clustering_cost.data_handle(), 1, stream); + raft::resource::sync_stream(handle); + } break; } } { - // Inertia for the last iteration is always computed - if (!params.inertia_check) { - raft::copy(inertia.data_handle(), cluster_cost.data_handle(), 1, stream); - raft::resource::sync_stream(handle); - } - RAFT_LOG_DEBUG("KMeans batched: n_init %d/%d completed with inertia=%f", seed_iter + 1, n_init, @@ -378,16 +377,15 @@ void fit(raft::resources const& handle, raft::copy(best_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); } } - - if (n_init > 1) { - raft::copy(centroids.data_handle(), best_centroids.data_handle(), centroids.size(), stream); - inertia[0] = best_inertia; - n_iter[0] = best_n_iter; - RAFT_LOG_DEBUG("KMeans batched: Best of %d runs: inertia=%f, n_iter=%d", - n_init, - static_cast(best_inertia), - best_n_iter); - } + } + if (n_init > 1) { + raft::copy(centroids.data_handle(), best_centroids.data_handle(), centroids.size(), stream); + inertia[0] = best_inertia; + n_iter[0] = best_n_iter; + RAFT_LOG_DEBUG("KMeans batched: Best of %d runs: inertia=%f, n_iter=%d", + n_init, + static_cast(best_inertia), + best_n_iter); } } From 1a7b6445dfc579b22520b5e087ae0fd264f305c8 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 13 Mar 2026 13:33:39 -0700 Subject: [PATCH 78/81] style --- cpp/src/cluster/detail/kmeans_batched.cuh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index d6c4bd85af..997d538f31 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -214,7 +214,9 @@ void fit(raft::resources const& handle, n_init = 1; } - auto best_centroids = n_init > 1 ? raft::make_device_matrix(handle, n_clusters, n_features) : raft::make_device_matrix(handle, 0, 0); + auto best_centroids = n_init > 1 + ? raft::make_device_matrix(handle, n_clusters, n_features) + : raft::make_device_matrix(handle, 0, 0); T best_inertia = std::numeric_limits::max(); IdxT best_n_iter = 0; From d4c53eed79b633614e8742e1293397629b7c4cc1 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 13 Mar 2026 13:53:46 -0700 Subject: [PATCH 79/81] address python reviews --- cpp/include/cuvs/cluster/kmeans.hpp | 17 ----------------- cpp/src/cluster/detail/kmeans_batched.cuh | 2 -- python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 15 +++++++++++++++ python/cuvs/cuvs/tests/test_kmeans.py | 2 +- 4 files changed, 16 insertions(+), 20 deletions(-) diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 4e41ef2a52..c7141a5d47 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -163,11 +163,6 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * @{ */ -/** - * @defgroup kmeans_batched Batched k-means for out-of-core / host data - * @{ - */ - /** * @brief Find clusters with k-means algorithm using batched processing of host data. * @@ -305,14 +300,6 @@ void predict(raft::resources const& handle, bool normalize_weight, raft::host_scalar_view inertia); -/** - * @} - */ - -/** - * @defgroup fit_predict_host K-Means Fit + Predict (host data) - * @{ - */ /** * @brief Fit k-means and predict cluster labels using batched processing of host data. @@ -349,10 +336,6 @@ void fit_predict(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter); -/** - * @} - */ - /** * @brief Find clusters with k-means algorithm. * Initial centroids are chosen with k-means++ algorithm. Empty diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 997d538f31..6d293b6790 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -195,8 +195,6 @@ void fit(raft::resources const& handle, "centroids.extent(0) must equal n_clusters"); RAFT_EXPECTS(centroids.extent(1) == n_features, "centroids.extent(1) must equal n_features"); - raft::default_logger().set_level(params.verbosity); - RAFT_LOG_DEBUG("KMeans batched fit: n_samples=%zu, n_features=%zu, n_clusters=%d, batch_size=%zu", static_cast(n_samples), static_cast(n_features), diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index 1ef294051f..349656d6ca 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -274,6 +274,21 @@ def fit( not hasattr(X, '__cuda_array_interface__') ) + # Check that sample_weights has the same residency as X + if sample_weights is not None: + is_sample_weight_host = isinstance(sample_weights, np.ndarray) or ( + hasattr(sample_weights, '__array_interface__') and + not hasattr(sample_weights, '__cuda_array_interface__') + ) + if is_host != is_sample_weight_host: + raise ValueError( + "X and sample_weights must have the same memory residency " + "(both host or both device). X is {}, sample_weights is {}.".format( + "host" if is_host else "device", + "host" if is_sample_weight_host else "device" + ) + ) + if is_host: if not isinstance(X, np.ndarray): X = np.asarray(X) diff --git a/python/cuvs/cuvs/tests/test_kmeans.py b/python/cuvs/cuvs/tests/test_kmeans.py index d614cb69ca..7df9dc86d7 100644 --- a/python/cuvs/cuvs/tests/test_kmeans.py +++ b/python/cuvs/cuvs/tests/test_kmeans.py @@ -79,7 +79,7 @@ def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): @pytest.mark.parametrize("n_rows", [1000, 5000]) @pytest.mark.parametrize("n_cols", [10, 100]) @pytest.mark.parametrize("n_clusters", [8, 16]) -@pytest.mark.parametrize("batch_size", [100, 500]) +@pytest.mark.parametrize("batch_size", [0, 100, 500]) @pytest.mark.parametrize("dtype", [np.float64]) def test_fit_host_matches_fit_device( n_rows, n_cols, n_clusters, batch_size, dtype From 8bec6d5f8875c98657ad5774c4b81c1a275b7042 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 13 Mar 2026 13:57:07 -0700 Subject: [PATCH 80/81] style --- cpp/include/cuvs/cluster/kmeans.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index c7141a5d47..8834ce5e41 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -300,7 +300,6 @@ void predict(raft::resources const& handle, bool normalize_weight, raft::host_scalar_view inertia); - /** * @brief Fit k-means and predict cluster labels using batched processing of host data. * From e920ca0aa9f833187244f833fe30493f8c6d42df Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 13 Mar 2026 14:06:50 -0700 Subject: [PATCH 81/81] rm batch_size struct --- c/src/cluster/kmeans.cpp | 4 +-- cpp/include/cuvs/cluster/kmeans.hpp | 32 +++++++++-------------- cpp/src/cluster/detail/kmeans_batched.cuh | 14 +++++----- cpp/tests/cluster/kmeans.cu | 6 ++--- 4 files changed, 25 insertions(+), 31 deletions(-) diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index f2161e3478..af3b49a5a3 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -28,7 +28,7 @@ cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params) kmeans_params.batch_samples = params.batch_samples; kmeans_params.batch_centroids = params.batch_centroids; kmeans_params.inertia_check = params.inertia_check; - kmeans_params.batched.batch_size = params.batch_size; + kmeans_params.batch_size = params.batch_size; return kmeans_params; } @@ -240,7 +240,7 @@ extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params) .inertia_check = cpp_params.inertia_check, .hierarchical = false, .hierarchical_n_iters = static_cast(cpp_balanced_params.n_iters), - .batch_size = cpp_params.batched.batch_size}; + .batch_size = cpp_params.batch_size}; }); } diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 8834ce5e41..7d3629e760 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -101,9 +101,9 @@ struct params : base_params { * Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0 * then don't tile the centroids * - * NB: These parameters are unrelated to batched_params.batch_size, which - * controls how many samples to transfer from host to device per batch when - * processing out-of-core data. + * NB: These parameters are unrelated to batch_size, which controls how many + * samples to transfer from host to device per batch when processing out-of-core + * data. */ int batch_samples = 1 << 15; @@ -118,19 +118,13 @@ struct params : base_params { bool inertia_check = false; /** - * @struct batched_params - * Parameters specific to batched k-means (host-data overloads of fit/predict). - * These parameters are only used when calling fit() or predict() with host_matrix_view data - * and are ignored by the device_matrix_view overloads. + * Number of samples to process per GPU batch when fitting with host data. + * When set to 0, defaults to n_samples (process all at once). + * Only used by the batched (host-data) code path and ignored by device-data + * overloads. + * Default: 0 (process all data at once). */ - struct batched_params { - /** - * Number of samples to process per GPU batch. - * When set to 0, a default batch size will be chosen automatically. - * Default: 0 (auto). - */ - int64_t batch_size = 0; - } batched; + int64_t batch_size = 0; }; /** @@ -168,7 +162,7 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * * This overload supports out-of-core computation where the dataset resides * on the host. Data is processed in GPU-sized batches, streaming from host to device. - * The batch size is controlled by params.batched.batch_size. + * The batch size is controlled by params.batch_size. * * @code{.cpp} * #include @@ -178,7 +172,7 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * raft::resources handle; * cuvs::cluster::kmeans::params params; * params.n_clusters = 100; - * params.batched.batch_size = 100000; + * params.batch_size = 100000; * int n_features = 15; * float inertia; * int n_iter; @@ -201,7 +195,7 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * * @param[in] handle The raft handle. * @param[in] params Parameters for KMeans model. Batch size is read from - * params.batched.batch_size. + * params.batch_size. * @param[in] X Training instances on HOST memory. The data must * be in row-major format. * [dim = n_samples x n_features] @@ -268,7 +262,7 @@ void fit(raft::resources const& handle, * * Streams data from host to GPU in batches, assigns each sample to its nearest * centroid, and writes labels back to host memory. - * The batch size is controlled by params.batched.batch_size. + * The batch size is controlled by params.batch_size. * * @param[in] handle The raft handle. * @param[in] params Parameters for KMeans model. diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 6d293b6790..b5899aaf89 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -152,7 +152,7 @@ void accumulate_batch_centroids( * @tparam IdxT Index type (int, int64_t) * * @param[in] handle RAFT resources handle - * @param[in] params K-means parameters (batch size read from params.batched.batch_size) + * @param[in] params K-means parameters * @param[in] X Input data on HOST [n_samples x n_features] * @param[in] sample_weight Optional weights per sample (on host) * @param[inout] centroids Initial/output cluster centers [n_clusters x n_features] @@ -175,17 +175,17 @@ void fit(raft::resources const& handle, auto metric = params.metric; // Read batch_size from params; default to n_samples if 0 (auto) - IdxT batch_size = static_cast(params.batched.batch_size); + IdxT batch_size = static_cast(params.batch_size); if (batch_size <= 0) { batch_size = static_cast(n_samples); } RAFT_EXPECTS(batch_size > 0, "batch_size must be positive"); // Warn if user explicitly set batch_size larger than dataset size - if (params.batched.batch_size > 0 && static_cast(params.batched.batch_size) > n_samples) { + if (params.batch_size > 0 && static_cast(params.batch_size) > n_samples) { RAFT_LOG_WARN( "batch_size (%zu) is larger than dataset size (%zu). " "batch_size will be effectively clamped to %zu.", - static_cast(params.batched.batch_size), + static_cast(params.batch_size), static_cast(n_samples), static_cast(n_samples)); } @@ -411,17 +411,17 @@ void predict(raft::resources const& handle, auto n_clusters = params.n_clusters; // Read batch_size from params; default to n_samples if 0 (auto) - IdxT batch_size = static_cast(params.batched.batch_size); + IdxT batch_size = static_cast(params.batch_size); if (batch_size <= 0) { batch_size = static_cast(n_samples); } RAFT_EXPECTS(batch_size > 0, "batch_size must be positive"); // Warn if user explicitly set batch_size larger than dataset size - if (params.batched.batch_size > 0 && static_cast(params.batched.batch_size) > n_samples) { + if (params.batch_size > 0 && static_cast(params.batch_size) > n_samples) { RAFT_LOG_WARN( "batch_size (%zu) is larger than dataset size (%zu). " "batch_size will be effectively clamped to %zu.", - static_cast(params.batched.batch_size), + static_cast(params.batch_size), static_cast(n_samples), static_cast(n_samples)); } diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index e1bf7a8ea5..4e79d81a0a 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -446,7 +446,7 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam> h_sw = std::nullopt; std::vector h_sample_weight; @@ -616,8 +616,8 @@ class KmeansPredictBatchedTest : public ::testing::TestWithParam auto h_labels_view = raft::make_host_vector_view(h_labels.data(), (int64_t)n_samples); - T pred_inertia = 0; - params.batched.batch_size = std::min((int64_t)n_samples, (int64_t)256); + T pred_inertia = 0; + params.batch_size = std::min((int64_t)n_samples, (int64_t)256); cuvs::cluster::kmeans::predict( handle,