diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index d4e9f2c6b4..e1d995bd9c 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -538,6 +538,7 @@ if(NOT BUILD_CPU_ONLY) src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int64_t.cu src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int64_t.cu src/distance/distance.cu + src/distance/kde.cu src/distance/pairwise_distance.cu src/distance/sparse_distance.cu src/neighbors/all_neighbors/all_neighbors.cu diff --git a/cpp/include/cuvs/distance/kde.hpp b/cpp/include/cuvs/distance/kde.hpp new file mode 100644 index 0000000000..4d78d25218 --- /dev/null +++ b/cpp/include/cuvs/distance/kde.hpp @@ -0,0 +1,103 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include + +namespace cuvs::distance { + +/** + * @brief Density kernel type for Kernel Density Estimation. + * + * These are the smoothing kernels used in KDE — distinct from the dot-product + * kernels (RBF, Polynomial, etc.) in cuvs::distance::kernels used by SVMs. + */ +enum class DensityKernelType : int { + Gaussian = 0, + Tophat = 1, + Epanechnikov = 2, + Exponential = 3, + Linear = 4, + Cosine = 5 +}; + +/** + * @brief Compute log-density estimates for query points using kernel density estimation. + * + * Fuses pairwise distance computation, kernel evaluation, logsumexp reduction, + * and normalization into a single CUDA kernel pass. O(N+M) memory usage — + * the full N×M pairwise distance matrix is never materialised. + * + * Supports 13 distance metrics (all expressible as per-feature accumulation), + * 6 density kernel functions, float32 and float64, and both uniform and + * weighted training sets. + * + * When the query count is small relative to the number of GPU SMs, the + * training set is automatically split across a 2D grid (multi-pass mode) to + * keep the GPU fully utilised. 
Partial logsumexp results are merged by a + * reduction kernel. + * + * @tparam T float or double + * + * @param[in] handle RAFT resources handle for stream management + * @param[in] query Query points, row-major (n_query × n_features) + * @param[in] train Training points, row-major (n_train × n_features) + * @param[in] weights Per-training-point weights (n_train,), or nullptr for uniform + * @param[out] output Log-density estimates (n_query,) + * @param[in] n_query Number of query points + * @param[in] n_train Number of training points + * @param[in] n_features Dimensionality of the data + * @param[in] bandwidth Kernel bandwidth (must be > 0) + * @param[in] sum_weights Sum of sample weights (or n_train if uniform) + * @param[in] kernel Density kernel function + * @param[in] metric Distance metric + * @param[in] metric_arg Metric parameter (e.g. p for Minkowski; ignored otherwise) + */ +template +void kde_score_samples(raft::resources const& handle, + const T* query, + const T* train, + const T* weights, + T* output, + int n_query, + int n_train, + int n_features, + T bandwidth, + T sum_weights, + DensityKernelType kernel, + cuvs::distance::DistanceType metric, + T metric_arg); + +extern template void kde_score_samples(raft::resources const&, + const float*, + const float*, + const float*, + float*, + int, + int, + int, + float, + float, + DensityKernelType, + cuvs::distance::DistanceType, + float); + +extern template void kde_score_samples(raft::resources const&, + const double*, + const double*, + const double*, + double*, + int, + int, + int, + double, + double, + DensityKernelType, + cuvs::distance::DistanceType, + double); + +} // namespace cuvs::distance diff --git a/cpp/src/distance/kde.cu b/cpp/src/distance/kde.cu new file mode 100644 index 0000000000..c8c13f99a2 --- /dev/null +++ b/cpp/src/distance/kde.cu @@ -0,0 +1,679 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
/*
 * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
 * SPDX-License-Identifier: Apache-2.0
 */

// cpp/src/distance/kde.cu

// NOTE(review): the original #include targets were lost in extraction; the
// list below is reconstructed from the symbols this TU uses — confirm.
#include <cuvs/distance/distance.hpp>
#include <cuvs/distance/kde.hpp>

#include <raft/core/error.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>

#include <rmm/device_uvector.hpp>

#include <cuda/std/limits>

#include <cmath>
#include <cstddef>
#include <stdexcept>

namespace cuvs::distance {

// ============================================================================
// Distance accumulator ops.
//
// Every supported metric is decomposed into init / accumulate / finalize so
// the tiled kernel can stream over features while keeping partial state in
// registers. Each specialisation provides:
//   N_ACC                    — accumulator slots per distance computation
//   init(acc)                — reset the accumulators
//   accumulate(acc, a, b, p) — fold one feature pair into the accumulators
//   finalize(acc, d, p)      — reduce the accumulators to the scalar distance
// ============================================================================

template <typename T, cuvs::distance::DistanceType M>
struct DistOp;

// Euclidean: sqrt(sum((a-b)^2))
template <typename T>
struct DistOp<T, cuvs::distance::DistanceType::L2SqrtUnexpanded> {
  static constexpr int N_ACC = 1;
  inline __device__ static void init(T* acc) { acc[0] = T(0); }
  inline __device__ static void accumulate(T* acc, T a, T b, T)
  {
    const T diff = a - b;
    acc[0] += diff * diff;
  }
  inline __device__ static T finalize(T* acc, int, T) { return sqrt(acc[0]); }
};

// Squared euclidean: sum((a-b)^2)
template <typename T>
struct DistOp<T, cuvs::distance::DistanceType::L2Expanded> {
  static constexpr int N_ACC = 1;
  inline __device__ static void init(T* acc) { acc[0] = T(0); }
  inline __device__ static void accumulate(T* acc, T a, T b, T)
  {
    const T diff = a - b;
    acc[0] += diff * diff;
  }
  inline __device__ static T finalize(T* acc, int, T) { return acc[0]; }
};

// Manhattan: sum(|a-b|)
template <typename T>
struct DistOp<T, cuvs::distance::DistanceType::L1> {
  static constexpr int N_ACC = 1;
  inline __device__ static void init(T* acc) { acc[0] = T(0); }
  inline __device__ static void accumulate(T* acc, T a, T b, T) { acc[0] += abs(a - b); }
  inline __device__ static T finalize(T* acc, int, T) { return acc[0]; }
};

// Chebyshev: max(|a-b|)
template <typename T>
struct DistOp<T, cuvs::distance::DistanceType::Linf> {
  static constexpr int N_ACC = 1;
  inline __device__ static void init(T* acc) { acc[0] = T(0); }
  inline __device__ static void accumulate(T* acc, T a, T b, T)
  {
    acc[0] = max(acc[0], abs(a - b));
  }
  inline __device__ static T finalize(T* acc, int, T) { return acc[0]; }
};

// Minkowski: (sum(|a-b|^p))^(1/p); p arrives via metric_arg
template <typename T>
struct DistOp<T, cuvs::distance::DistanceType::LpUnexpanded> {
  static constexpr int N_ACC = 1;
  inline __device__ static void init(T* acc) { acc[0] = T(0); }
  inline __device__ static void accumulate(T* acc, T a, T b, T p) { acc[0] += pow(abs(a - b), p); }
  inline __device__ static T finalize(T* acc, int, T p) { return pow(acc[0], T(1) / p); }
};

// Cosine: 1 - dot(a,b)/(||a||*||b||)
// acc[0]=dot, acc[1]=||a||^2, acc[2]=||b||^2
template <typename T>
struct DistOp<T, cuvs::distance::DistanceType::CosineExpanded> {
  static constexpr int N_ACC = 3;
  inline __device__ static void init(T* acc) { acc[0] = acc[1] = acc[2] = T(0); }
  inline __device__ static void accumulate(T* acc, T a, T b, T)
  {
    acc[0] += a * b;
    acc[1] += a * a;
    acc[2] += b * b;
  }
  inline __device__ static T finalize(T* acc, int, T)
  {
    // Degenerate (zero-norm) vectors yield distance 0 rather than NaN.
    const T denom = sqrt(acc[1]) * sqrt(acc[2]);
    return (denom > T(0)) ? (T(1) - acc[0] / denom) : T(0);
  }
};

// Correlation: cosine of the mean-centred vectors, computed in a single pass
// via the sum identities. acc = {sum_a, sum_b, sum_a2, sum_b2, sum_ab}.
template <typename T>
struct DistOp<T, cuvs::distance::DistanceType::CorrelationExpanded> {
  static constexpr int N_ACC = 5;
  inline __device__ static void init(T* acc) { acc[0] = acc[1] = acc[2] = acc[3] = acc[4] = T(0); }
  inline __device__ static void accumulate(T* acc, T a, T b, T)
  {
    acc[0] += a;
    acc[1] += b;
    acc[2] += a * a;
    acc[3] += b * b;
    acc[4] += a * b;
  }
  inline __device__ static T finalize(T* acc, int d, T)
  {
    const T mean_a = acc[0] / T(d);
    const T mean_b = acc[1] / T(d);
    const T dot    = acc[4] - T(d) * mean_a * mean_b;
    const T var_a  = acc[2] - T(d) * mean_a * mean_a;
    const T var_b  = acc[3] - T(d) * mean_b * mean_b;
    const T den    = sqrt(var_a) * sqrt(var_b);
    return (den > T(0)) ? (T(1) - dot / den) : T(0);
  }
};

// Canberra: sum(|a-b| / (|a|+|b|)); branchless guard keeps 0/0 terms at 0.
template <typename T>
struct DistOp<T, cuvs::distance::DistanceType::Canberra> {
  static constexpr int N_ACC = 1;
  inline __device__ static void init(T* acc) { acc[0] = T(0); }
  inline __device__ static void accumulate(T* acc, T a, T b, T)
  {
    const T diff = abs(a - b);
    const T add  = abs(a) + abs(b);
    acc[0] += ((add != T(0)) * diff / (add + (add == T(0))));
  }
  inline __device__ static T finalize(T* acc, int, T) { return acc[0]; }
};

// Hellinger: sqrt(1 - sum(sqrt(a)*sqrt(b))); signbit() clamps tiny negative
// round-off under the sqrt to zero.
template <typename T>
struct DistOp<T, cuvs::distance::DistanceType::HellingerExpanded> {
  static constexpr int N_ACC = 1;
  inline __device__ static void init(T* acc) { acc[0] = T(0); }
  inline __device__ static void accumulate(T* acc, T a, T b, T) { acc[0] += sqrt(a) * sqrt(b); }
  inline __device__ static T finalize(T* acc, int, T)
  {
    const T val = T(1) - acc[0];
    return sqrt((!signbit(val)) * val);
  }
};

// Jensen-Shannon; the boolean flags keep log() off exact zeros.
template <typename T>
struct DistOp<T, cuvs::distance::DistanceType::JensenShannon> {
  static constexpr int N_ACC = 1;
  inline __device__ static void init(T* acc) { acc[0] = T(0); }
  inline __device__ static void accumulate(T* acc, T a, T b, T)
  {
    const T m       = T(0.5) * (a + b);
    const bool mz   = (m == T(0));
    const T logM    = (!mz) * log(m + mz);
    const bool xz   = (a == T(0));
    const bool yz   = (b == T(0));
    acc[0] += (-a * (logM - log(a + xz))) + (-b * (logM - log(b + yz)));
  }
  inline __device__ static T finalize(T* acc, int, T) { return sqrt(T(0.5) * acc[0]); }
};

// Hamming: count(a != b) / d
template <typename T>
struct DistOp<T, cuvs::distance::DistanceType::HammingUnexpanded> {
  static constexpr int N_ACC = 1;
  inline __device__ static void init(T* acc) { acc[0] = T(0); }
  inline __device__ static void accumulate(T* acc, T a, T b, T) { acc[0] += (a != b); }
  inline __device__ static T finalize(T* acc, int d, T) { return acc[0] / T(d); }
};

// KL divergence: sum(a*log(a/b)); terms with a==0 or b==0 contribute nothing.
template <typename T>
struct DistOp<T, cuvs::distance::DistanceType::KLDivergence> {
  static constexpr int N_ACC = 1;
  inline __device__ static void init(T* acc) { acc[0] = T(0); }
  inline __device__ static void accumulate(T* acc, T a, T b, T)
  {
    if (a > T(0) && b > T(0)) { acc[0] += a * log(a / b); }
  }
  inline __device__ static T finalize(T* acc, int, T) { return acc[0]; }
};

// Russell-Rao: (d - sum(a*b)) / d
template <typename T>
struct DistOp<T, cuvs::distance::DistanceType::RusselRaoExpanded> {
  static constexpr int N_ACC = 1;
  inline __device__ static void init(T* acc) { acc[0] = T(0); }
  inline __device__ static void accumulate(T* acc, T a, T b, T) { acc[0] += a * b; }
  inline __device__ static T finalize(T* acc, int d, T) { return (T(d) - acc[0]) / T(d); }
};

// ============================================================================
// Log-kernel traits — one specialisation per DensityKernelType. eval(x, h)
// returns the unnormalised log kernel value at distance x with bandwidth h;
// out-of-support points return numeric_limits::lowest() (a finite stand-in
// for -inf so downstream exp() arithmetic stays NaN-free).
// ============================================================================

template <typename T, DensityKernelType K>
struct LogKernel;

template <typename T>
struct LogKernel<T, DensityKernelType::Gaussian> {
  inline __device__ static T eval(T x, T h) { return -(x * x) / (T(2) * h * h); }
};

template <typename T>
struct LogKernel<T, DensityKernelType::Tophat> {
  inline __device__ static T eval(T x, T h)
  {
    return (x < h) ? T(0) : cuda::std::numeric_limits<T>::lowest();
  }
};

template <typename T>
struct LogKernel<T, DensityKernelType::Epanechnikov> {
  inline __device__ static T eval(T x, T h)
  {
    // Floor at 1e-30 so log() stays finite as x -> h from below.
    const T z = max(T(1) - (x * x) / (h * h), T(1e-30));
    return (x < h) ? log(z) : cuda::std::numeric_limits<T>::lowest();
  }
};

template <typename T>
struct LogKernel<T, DensityKernelType::Exponential> {
  inline __device__ static T eval(T x, T h) { return -x / h; }
};

template <typename T>
struct LogKernel<T, DensityKernelType::Linear> {
  inline __device__ static T eval(T x, T h)
  {
    const T z = max(T(1) - x / h, T(1e-30));
    return (x < h) ? log(z) : cuda::std::numeric_limits<T>::lowest();
  }
};

template <typename T>
struct LogKernel<T, DensityKernelType::Cosine> {
  inline __device__ static T eval(T x, T h)
  {
    const T z = max(cos(T(0.5) * T(M_PI) * x / h), T(1e-30));
    return (x < h) ? log(z) : cuda::std::numeric_limits<T>::lowest();
  }
};

// ============================================================================
// Host-side normalisation helpers (mirror the Python implementations)
// ============================================================================

// log of the volume of the unit n-ball.
template <typename T>
T logVn(int n)
{
  return T(0.5) * n * std::log(T(M_PI)) - std::lgamma(T(0.5) * n + T(1));
}

// log of the surface area of the unit n-sphere (S_n = 2*pi*V_{n-1}).
template <typename T>
T logSn(int n)
{
  return std::log(T(2) * T(M_PI)) + logVn<T>(n - 1);
}

// Log normalisation constant for the given kernel/bandwidth/dimension, i.e.
// log of the integral of the unnormalised kernel over R^d.
template <typename T>
T norm_factor(DensityKernelType kernel, T h, int d)
{
  T factor;
  switch (kernel) {
    case DensityKernelType::Gaussian: factor = T(0.5) * d * std::log(T(2) * T(M_PI)); break;
    case DensityKernelType::Tophat: factor = logVn<T>(d); break;
    case DensityKernelType::Epanechnikov:
      factor = logVn<T>(d) + std::log(T(2) / T(d + 2));
      break;
    case DensityKernelType::Exponential: factor = logSn<T>(d - 1) + std::lgamma(T(d)); break;
    case DensityKernelType::Linear: factor = logVn<T>(d) - std::log(T(d + 1)); break;
    case DensityKernelType::Cosine: {
      // Compute integral_0^1 cos(pi/2 * t) * t^{d-1} dt using the recurrence:
      //   I_n = (2/pi) - n*(n-1)*(2/pi)^2 * I_{n-2}
      //   I_0 = 2/pi,  I_1 = 2/pi - (2/pi)^2
      // Derived from repeated integration by parts; both sin and cos boundary
      // terms must be included (the old loop-based formula missed the cos
      // terms at t=0 for even d).
      const T two_over_pi    = T(2) / T(M_PI);
      const T two_over_pi_sq = two_over_pi * two_over_pi;
      T I_prev = two_over_pi;                   // I_0
      T I_curr = two_over_pi - two_over_pi_sq;  // I_1
      const int n = d - 1;                      // need I_n
      if (n == 0) {
        factor = std::log(I_prev) + logSn<T>(d - 1);
      } else {
        for (int j = 2; j <= n; ++j) {
          const T I_next = two_over_pi - T(j) * T(j - 1) * two_over_pi_sq * I_prev;
          I_prev = I_curr;
          I_curr = I_next;
        }
        factor = std::log(I_curr) + logSn<T>(d - 1);
      }
    } break;
    default: throw std::invalid_argument("Unsupported kernel type");
  }
  // Bandwidth scaling is common to all kernels: + d*log(h).
  return factor + d * std::log(h);
}
// ============================================================================
// Tiled CUDA kernel — edistance-style CELL_TILE + FEAT_TILE optimisation.
//
// Grid/block layout: one thread per query point along blockIdx.x/threadIdx.x;
// blockIdx.y (multi-pass mode only) selects a chunk of the train set.
// Dynamic shared memory requirement: feat_tile * CELL_TILE * sizeof(T) bytes.
//
// Train vectors are cooperatively staged through shared memory in tiles of
// [feat_tile][CELL_TILE]; each thread accumulates distances from its query
// point to CELL_TILE train points simultaneously, amortising query feature
// reads and reducing global memory traffic.
//
// Modes:
//   * single-pass (out_b == nullptr): full train set -> final log-density in
//     out_a[i];
//   * multi-pass  (out_b != nullptr): train subset -> partial logsumexp state
//     (max in out_a, sum in out_b), merged later by kde_reduce_kernel. Used
//     when n_query alone cannot fill the GPU.
// ============================================================================

template <typename T, cuvs::distance::DistanceType M, DensityKernelType K, int CELL_TILE>
__global__ void kde_tiled_kernel(const T* __restrict__ query,
                                 const T* __restrict__ train,
                                 const T* __restrict__ weights,
                                 T* __restrict__ out_a,
                                 T* __restrict__ out_b,
                                 int n_query,
                                 int n_train,
                                 int d,
                                 T bandwidth,
                                 T metric_arg,
                                 T log_norm,
                                 int train_chunk,
                                 int feat_tile)
{
  using DOp = DistOp<T, M>;

  extern __shared__ char smem_raw[];
  T* smem_train = reinterpret_cast<T*>(smem_raw);  // layout: [feat_tile][CELL_TILE]

  const int i      = blockIdx.x * blockDim.x + threadIdx.x;
  const bool valid = (i < n_query);  // out-of-range threads still join barriers

  constexpr int N_ACC = DOp::N_ACC;

  // Train range covered by this block (whole set in single-pass mode).
  const int j_begin = blockIdx.y * train_chunk;
  const int j_end   = min(j_begin + train_chunk, n_train);

  // Initialize to lowest() (not -inf) so that out-of-support points returning
  // lowest() don't produce 0*exp(+inf)=NaN via exp(-inf - lowest()) = exp(+inf).
  T running_max = cuda::std::numeric_limits<T>::lowest();
  T running_sum = T(0);

  // Tile over train points in groups of CELL_TILE.
  for (int j_base = j_begin; j_base < j_end; j_base += CELL_TILE) {
    const int cells_in_tile = min(CELL_TILE, j_end - j_base);

    // Per-train-point accumulators held in registers.
    T acc[CELL_TILE * N_ACC];
#pragma unroll
    for (int c = 0; c < CELL_TILE; ++c)
      DOp::init(&acc[c * N_ACC]);

    // Tile over features.
    for (int feat_base = 0; feat_base < d; feat_base += feat_tile) {
      const int feats_in_tile = min(feat_tile, d - feat_base);

      // Cooperative load into shared memory as smem[feat][cell]; the feature
      // index is fastest across threads so global reads coalesce along rows.
      const int total_elems = feat_tile * CELL_TILE;
      for (int idx = threadIdx.x; idx < total_elems; idx += blockDim.x) {
        const int cell = idx / feat_tile;
        const int feat = idx % feat_tile;
        T val = T(0);
        if (cell < cells_in_tile && feat < feats_in_tile) {
          val = train[static_cast<size_t>(j_base + cell) * d + feat_base + feat];
        }
        smem_train[feat * CELL_TILE + cell] = val;
      }

      __syncthreads();

      if (valid) {
        for (int f = 0; f < feats_in_tile; ++f) {
          const T val_q = query[static_cast<size_t>(i) * d + feat_base + f];
#pragma unroll
          for (int c = 0; c < CELL_TILE; ++c) {
            const T val_t = smem_train[f * CELL_TILE + c];
            DOp::accumulate(&acc[c * N_ACC], val_q, val_t, metric_arg);
          }
        }
      }

      __syncthreads();  // protect smem before the next tile overwrites it
    }

    // Finalize distances and fold into the streaming logsumexp.
    if (valid) {
#pragma unroll
      for (int c = 0; c < CELL_TILE; ++c) {
        if (c >= cells_in_tile) break;
        const T dist = DOp::finalize(&acc[c * N_ACC], d, metric_arg);
        T log_k      = LogKernel<T, K>::eval(dist, bandwidth);
        if (weights) log_k += log(weights[j_base + c]);

        if (log_k > running_max) {
          // New maximum: rescale the existing sum, then count this term as 1.
          running_sum = running_sum * exp(running_max - log_k) + T(1);
          running_max = log_k;
        } else {
          running_sum += exp(log_k - running_max);
        }
      }
    }
  }

  if (valid) {
    if (out_b == nullptr) {
      // Single-pass: write final log-probability.
      out_a[i] = log(running_sum) + running_max - log_norm;
    } else {
      // Multi-pass: write partial (max, sum) for later reduction.
      const size_t idx = static_cast<size_t>(i) * gridDim.y + blockIdx.y;
      out_a[idx] = running_max;
      out_b[idx] = running_sum;
    }
  }
}

// ============================================================================
// Reduction kernel — merges the per-chunk partial logsumexp results from the
// multi-pass mode into the final per-query log-density. One thread per query.
// ============================================================================

template <typename T>
__global__ void kde_reduce_kernel(const T* __restrict__ partial_max,
                                  const T* __restrict__ partial_sum,
                                  T* __restrict__ output,
                                  int n_query,
                                  int n_blocks,
                                  T log_norm)
{
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n_query) return;

  T rmax = cuda::std::numeric_limits<T>::lowest();
  T rsum = T(0);

  for (int b = 0; b < n_blocks; ++b) {
    const size_t idx = static_cast<size_t>(i) * n_blocks + b;
    const T pm = partial_max[idx];
    const T ps = partial_sum[idx];
    if (pm > rmax) {
      rsum = rsum * exp(rmax - pm) + ps;
      rmax = pm;
    } else {
      rsum += ps * exp(pm - rmax);
    }
  }
  output[i] = log(rsum) + rmax - log_norm;
}

// ============================================================================
// Double dispatch: runtime enum -> compile-time template tag
// ============================================================================

template <typename Fn>
void dispatch_metric(cuvs::distance::DistanceType metric, Fn&& fn)
{
  using DT = cuvs::distance::DistanceType;
  switch (metric) {
    case DT::L2SqrtUnexpanded: fn(std::integral_constant<DT, DT::L2SqrtUnexpanded>{}); break;
    case DT::L2Expanded: fn(std::integral_constant<DT, DT::L2Expanded>{}); break;
    case DT::L1: fn(std::integral_constant<DT, DT::L1>{}); break;
    case DT::Linf: fn(std::integral_constant<DT, DT::Linf>{}); break;
    case DT::LpUnexpanded: fn(std::integral_constant<DT, DT::LpUnexpanded>{}); break;
    case DT::CosineExpanded: fn(std::integral_constant<DT, DT::CosineExpanded>{}); break;
    case DT::CorrelationExpanded:
      fn(std::integral_constant<DT, DT::CorrelationExpanded>{});
      break;
    case DT::Canberra: fn(std::integral_constant<DT, DT::Canberra>{}); break;
    case DT::HellingerExpanded: fn(std::integral_constant<DT, DT::HellingerExpanded>{}); break;
    case DT::JensenShannon: fn(std::integral_constant<DT, DT::JensenShannon>{}); break;
    case DT::HammingUnexpanded: fn(std::integral_constant<DT, DT::HammingUnexpanded>{}); break;
    case DT::KLDivergence: fn(std::integral_constant<DT, DT::KLDivergence>{}); break;
    case DT::RusselRaoExpanded: fn(std::integral_constant<DT, DT::RusselRaoExpanded>{}); break;
    default: throw std::invalid_argument("Unsupported distance metric for KDE");
  }
}

template <typename Fn>
void dispatch_kernel(DensityKernelType kernel, Fn&& fn)
{
  using KT = DensityKernelType;
  switch (kernel) {
    case KT::Gaussian: fn(std::integral_constant<KT, KT::Gaussian>{}); break;
    case KT::Tophat: fn(std::integral_constant<KT, KT::Tophat>{}); break;
    case KT::Epanechnikov: fn(std::integral_constant<KT, KT::Epanechnikov>{}); break;
    case KT::Exponential: fn(std::integral_constant<KT, KT::Exponential>{}); break;
    case KT::Linear: fn(std::integral_constant<KT, KT::Linear>{}); break;
    case KT::Cosine: fn(std::integral_constant<KT, KT::Cosine>{}); break;
    default: throw std::invalid_argument("Unsupported kernel type for KDE");
  }
}

// ============================================================================
// Implementation: launches the tiled kernel (1-pass or 2-pass)
// ============================================================================

template <typename T>
void kde_score_samples(raft::resources const& handle,
                       const T* query,
                       const T* train,
                       const T* weights,
                       T* output,
                       int n_query,
                       int n_train,
                       int d,
                       T bandwidth,
                       T sum_weights,
                       DensityKernelType kernel,
                       cuvs::distance::DistanceType metric,
                       T metric_arg)
{
  RAFT_EXPECTS(n_query > 0, "n_query must be > 0");
  RAFT_EXPECTS(n_train > 0, "n_train must be > 0");
  RAFT_EXPECTS(d > 0, "n_features must be > 0");
  RAFT_EXPECTS(bandwidth > T(0), "bandwidth must be > 0");

  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
  T log_norm          = std::log(sum_weights) + norm_factor<T>(kernel, bandwidth, d);

  // Cap feature tile to the actual dimension to avoid wasted shared memory
  // and cooperative load cycles for low-dimensional data (e.g. 2D embeddings).
  // FIX: use self-contained expressions instead of unqualified min()/max() in
  // host code — the latter resolve only through CUDA's injected global-
  // namespace helpers, which is non-standard and fragile across toolchains.
  const int feat_tile = (d < 64) ? d : 64;
  // 512 threads for float32 (more cooperative load throughput, better GPU fill).
  // 256 for float64 to avoid exceeding per-block register limits with
  // CELL_TILE=64 double-precision accumulators (64×2 regs × 512 threads > 65536).
  const int threads   = (sizeof(T) == 4) ? 512 : 256;
  int n_query_blocks  = (n_query + threads - 1) / threads;

  dispatch_metric(metric, [&](auto metric_tag) {
    dispatch_kernel(kernel, [&](auto kernel_tag) {
      constexpr auto M = decltype(metric_tag)::value;
      constexpr auto K = decltype(kernel_tag)::value;

      // Adapt CELL_TILE to keep accumulator register pressure under ~128 regs.
      constexpr int N_ACC    = DistOp<T, M>::N_ACC;
      constexpr int ACC_REGS = sizeof(T) / 4;
      constexpr int RAW_TILE = 128 / (N_ACC * ACC_REGS);
      constexpr int CELL_TILE = RAW_TILE >= 64 ? 64
                                : RAW_TILE >= 32 ? 32
                                : RAW_TILE >= 16 ? 16
                                : RAW_TILE >= 8 ? 8
                                                : 4;

      const size_t smem_bytes = static_cast<size_t>(feat_tile) * CELL_TILE * sizeof(T);

      // Determine whether to split the train dimension across blocks.
      // When n_query is small the GPU is underutilised; splitting the train
      // set across a 2D grid exposes more parallelism.
      int dev, sm_count;
      RAFT_CUDA_TRY(cudaGetDevice(&dev));
      RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev));
      const int target_blocks = sm_count * 4;
      int n_train_blocks      = target_blocks / n_query_blocks;
      if (n_train_blocks < 1) n_train_blocks = 1;
      const int min_train_chunk = CELL_TILE * 4;
      int chunk_cap             = n_train / min_train_chunk;
      if (chunk_cap < 1) chunk_cap = 1;
      if (n_train_blocks > chunk_cap) n_train_blocks = chunk_cap;

      if (n_train_blocks <= 1) {
        // Single-pass: process all train points, write directly to output.
        dim3 grid(n_query_blocks);
        kde_tiled_kernel<T, M, K, CELL_TILE>
          <<<grid, threads, smem_bytes, stream>>>(query,
                                                  train,
                                                  weights,
                                                  output,
                                                  static_cast<T*>(nullptr),
                                                  n_query,
                                                  n_train,
                                                  d,
                                                  bandwidth,
                                                  metric_arg,
                                                  log_norm,
                                                  n_train,
                                                  feat_tile);
        RAFT_CUDA_TRY(cudaPeekAtLastError());
      } else {
        // Multi-pass: split train dimension, write partial (max, sum), then reduce.
        int train_chunk = (n_train + n_train_blocks - 1) / n_train_blocks;
        // Round up to CELL_TILE for clean tiling.
        train_chunk = ((train_chunk + CELL_TILE - 1) / CELL_TILE) * CELL_TILE;
        // Recompute actual number of blocks after rounding.
        n_train_blocks = (n_train + train_chunk - 1) / train_chunk;

        const size_t buf_elems = static_cast<size_t>(n_query) * n_train_blocks;
        rmm::device_uvector<T> partial_max(buf_elems, stream);
        rmm::device_uvector<T> partial_sum(buf_elems, stream);

        dim3 grid(n_query_blocks, n_train_blocks);
        kde_tiled_kernel<T, M, K, CELL_TILE>
          <<<grid, threads, smem_bytes, stream>>>(query,
                                                  train,
                                                  weights,
                                                  partial_max.data(),
                                                  partial_sum.data(),
                                                  n_query,
                                                  n_train,
                                                  d,
                                                  bandwidth,
                                                  metric_arg,
                                                  log_norm,
                                                  train_chunk,
                                                  feat_tile);
        RAFT_CUDA_TRY(cudaPeekAtLastError());

        kde_reduce_kernel<<<n_query_blocks, threads, 0, stream>>>(
          partial_max.data(), partial_sum.data(), output, n_query, n_train_blocks, log_norm);
        RAFT_CUDA_TRY(cudaPeekAtLastError());
      }
    });
  });
}

// Explicit instantiations
template void kde_score_samples<float>(raft::resources const&,
                                       const float*,
                                       const float*,
                                       const float*,
                                       float*,
                                       int,
                                       int,
                                       int,
                                       float,
                                       float,
                                       DensityKernelType,
                                       cuvs::distance::DistanceType,
                                       float);

template void kde_score_samples<double>(raft::resources const&,
                                        const double*,
                                        const double*,
                                        const double*,
                                        double*,
                                        int,
                                        int,
                                        int,
                                        double,
                                        double,
                                        DensityKernelType,
                                        cuvs::distance::DistanceType,
                                        double);

}  // namespace cuvs::distance