Add KDE kernel #1915
New file (103 lines):

```cpp
/*
 * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
 * SPDX-License-Identifier: Apache-2.0
 */

#pragma once

#include <cuvs/distance/distance.hpp>
#include <raft/core/resources.hpp>

namespace cuvs::distance {

/**
 * @brief Density kernel type for Kernel Density Estimation.
 *
 * These are the smoothing kernels used in KDE — distinct from the dot-product
 * kernels (RBF, Polynomial, etc.) in cuvs::distance::kernels used by SVMs.
 */
enum class DensityKernelType : int {
  Gaussian     = 0,
  Tophat       = 1,
  Epanechnikov = 2,
  Exponential  = 3,
  Linear       = 4,
  Cosine       = 5
};
```
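For reference, the six kernel names above correspond to the standard unnormalized KDE profiles. The sketch below is not from this PR; it follows the conventions used by scikit-learn's `KernelDensity` (with `u = distance / bandwidth`, and the compact kernels vanishing for `u >= 1`), and the dictionary name is illustrative only:

```python
import numpy as np

# Unnormalized density-kernel profiles as functions of u = distance / bandwidth.
# The compact kernels (tophat, epanechnikov, linear, cosine) are zero for u >= 1.
DENSITY_KERNELS = {
    "gaussian":     lambda u: np.exp(-0.5 * u**2),
    "tophat":       lambda u: np.where(u < 1.0, 1.0, 0.0),
    "epanechnikov": lambda u: np.where(u < 1.0, 1.0 - u**2, 0.0),
    "exponential":  lambda u: np.exp(-u),
    "linear":       lambda u: np.where(u < 1.0, 1.0 - u, 0.0),
    "cosine":       lambda u: np.where(u < 1.0, np.cos(0.5 * np.pi * u), 0.0),
}
```

In practice the GPU kernel would evaluate these in log space so they can feed directly into the logsumexp reduction described below.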
```cpp
/**
 * @brief Compute log-density estimates for query points using kernel density estimation.
 *
 * Fuses pairwise distance computation, kernel evaluation, logsumexp reduction,
 * and normalization into a single CUDA kernel pass. O(N+M) memory usage —
 * the full N×M pairwise distance matrix is never materialised.
 *
 * Supports 13 distance metrics (all expressible as per-feature accumulation),
 * 6 density kernel functions, float32 and float64, and both uniform and
 * weighted training sets.
 *
 * When the query count is small relative to the number of GPU SMs, the
 * training set is automatically split across a 2D grid (multi-pass mode) to
 * keep the GPU fully utilised. Partial logsumexp results are merged by a
 * reduction kernel.
 *
 * @tparam T float or double
 *
 * @param[in]  handle      RAFT resources handle for stream management
 * @param[in]  query       Query points, row-major (n_query × n_features)
 * @param[in]  train       Training points, row-major (n_train × n_features)
 * @param[in]  weights     Per-training-point weights (n_train,), or nullptr for uniform
 * @param[out] output      Log-density estimates (n_query,)
 * @param[in]  n_query     Number of query points
 * @param[in]  n_train     Number of training points
 * @param[in]  n_features  Dimensionality of the data
 * @param[in]  bandwidth   Kernel bandwidth (must be > 0)
 * @param[in]  sum_weights Sum of sample weights (or n_train if uniform)
 * @param[in]  kernel      Density kernel function
 * @param[in]  metric      Distance metric
 * @param[in]  metric_arg  Metric parameter (e.g. p for Minkowski; ignored otherwise)
 */
template <typename T>
void kde_score_samples(raft::resources const& handle,
```
Member
Would it make sense to just call this "kde", or is there significant enough post-processing being done after calling this?
```cpp
                       const T* query,
```
Member
Can we accept mdspans here instead of raw pointers? That would also simplify the function signatures for the public APIs. We can still expose pointers at the cuML layer for now (for the Python layer) if that's easier.
```cpp
                       const T* train,
                       const T* weights,
                       T* output,
                       int n_query,
                       int n_train,
                       int n_features,
                       T bandwidth,
                       T sum_weights,
                       DensityKernelType kernel,
                       cuvs::distance::DistanceType metric,
                       T metric_arg);

extern template void kde_score_samples<float>(raft::resources const&,
                                              const float*,
                                              const float*,
                                              const float*,
                                              float*,
                                              int,
                                              int,
                                              int,
                                              float,
                                              float,
                                              DensityKernelType,
                                              cuvs::distance::DistanceType,
                                              float);

extern template void kde_score_samples<double>(raft::resources const&,
                                               const double*,
                                               const double*,
                                               const double*,
                                               double*,
                                               int,
                                               int,
                                               int,
                                               double,
                                               double,
                                               DensityKernelType,
                                               cuvs::distance::DistanceType,
                                               double);

}  // namespace cuvs::distance
```
Can we at least consolidate this struct with the existing KernelType struct in distance.hpp?
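To make the docstring's fused computation concrete, here is a minimal NumPy sketch of the math only, restricted to the Gaussian kernel and Euclidean distance; `kde_score_samples_ref` is a hypothetical helper name, not part of this PR, and unlike the CUDA kernel it materialises the pairwise distance matrix:

```python
import numpy as np

def kde_score_samples_ref(query, train, bandwidth, weights=None):
    """CPU reference for Gaussian-kernel KDE log-densities.

    log p(x) = logsumexp_j( log w_j - d(x, x_j)^2 / (2 h^2) )
               - log(sum_j w_j) - D*log(h) - (D/2)*log(2*pi)
    """
    n_train, n_features = train.shape
    if weights is None:
        weights = np.ones(n_train)  # uniform training set

    # Pairwise squared Euclidean distances, shape (n_query, n_train).
    d2 = ((query[:, None, :] - train[None, :, :]) ** 2).sum(axis=-1)

    # Per-pair log kernel values, with per-point log weights folded in.
    log_kernel = -d2 / (2.0 * bandwidth**2) + np.log(weights)[None, :]

    # Numerically stable logsumexp over the training axis.
    m = log_kernel.max(axis=1, keepdims=True)
    lse = m[:, 0] + np.log(np.exp(log_kernel - m).sum(axis=1))

    # Normalization: sum of weights and the Gaussian norm h^D (2*pi)^(D/2).
    log_norm = (np.log(weights.sum()) + n_features * np.log(bandwidth)
                + 0.5 * n_features * np.log(2.0 * np.pi))
    return lse - log_norm
```

Note that splitting the training set into chunks and merging partial logsumexp values (take the running max `m`, then `m + log(exp(a - m) + exp(b - m))`) reproduces the same result, which is the basis of the multi-pass mode described in the docstring.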