Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
5e5eecf
fragments, tags, instantiations, json and cmake
tarang-jain May 15, 2026
060aee2
Merge branch 'main' into jit-lto-pw
tarang-jain May 15, 2026
cf6ccb0
Update cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/comput…
tarang-jain May 18, 2026
9c0b716
Merge branch 'main' into jit-lto-pw
tarang-jain May 18, 2026
f9ff9e8
update arch_include
tarang-jain May 19, 2026
3faba38
Merge branch 'jit-lto-pw' of github.com:tarang-jain/cuvs into jit-lto-pw
tarang-jain May 19, 2026
7dce6e3
correct tags
tarang-jain May 19, 2026
e1f6a5c
update template
tarang-jain May 19, 2026
9e80919
reapply cmake comment
tarang-jain May 19, 2026
8fcb1bb
rm jit boolean
tarang-jain May 19, 2026
4ae7507
address pr reviews
tarang-jain May 19, 2026
e465980
update index type in json
tarang-jain May 19, 2026
150e705
Merge branch 'main' into jit-lto-pw
tarang-jain May 19, 2026
7df2f95
do not switch off clang
tarang-jain May 19, 2026
ec5ae32
Merge branch 'jit-lto-pw' of github.com:tarang-jain/cuvs into jit-lto-pw
tarang-jain May 19, 2026
15bb587
style check and header includes
tarang-jain May 19, 2026
a2e77d7
style check and header includes
tarang-jain May 19, 2026
5bb0293
compilation errors
tarang-jain May 19, 2026
5da125a
fix header includes
tarang-jain May 19, 2026
d5e942d
restore old version of dispatch-inl
tarang-jain May 19, 2026
60d3344
style and docs
tarang-jain May 19, 2026
680b55d
add the ifdef
tarang-jain May 20, 2026
86100d4
Merge branch 'main' into jit-lto-pw
tarang-jain May 20, 2026
9957163
Merge branch 'main' into jit-lto-pw
tarang-jain Jun 4, 2026
3ea9e91
Merge branch 'main' of https://github.com/rapidsai/cuvs into jit-lto-pw
tarang-jain Jun 4, 2026
e3e6d27
Merge branch 'jit-lto-pw' of github.com:tarang-jain/cuvs into jit-lto-pw
tarang-jain Jun 4, 2026
9964674
Merge branch 'main' into jit-lto-pw
tarang-jain Jun 5, 2026
846d48f
Merge branch 'main' into jit-lto-pw
tarang-jain Jun 5, 2026
884caf2
simplify json
tarang-jain Jun 6, 2026
4d4f178
Merge branch 'jit-lto-pw' of github.com:tarang-jain/cuvs into jit-lto-pw
tarang-jain Jun 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,93 @@ if(NOT BUILD_CPU_ONLY)
OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/filter"
KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements
)
set(distance_ns "cuvs::distance::detail")
set(pairwise_matrix_jit_dir
"${CMAKE_CURRENT_SOURCE_DIR}/src/distance/detail/pairwise_matrix/jit_lto_kernels"
)
generate_jit_lto_kernels(
jit_lto_files
NAME_FORMAT
"pairwise_matrix_@distance_abbrev@_data_@data_abbrev@_acc_@acc_abbrev@_out_@out_abbrev@_index_@index_abbrev@_layout_@layout_abbrev@_veclen_@veclen@"
MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_matrix.json"
KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/pairwise_matrix_kernel.cu.in"
FRAGMENT_TAG_FORMAT
"${distance_ns}::fragment_tag_pairwise_matrix<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_@out_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_fin_op_identity, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>"
FRAGMENT_TAG_HEADER_FILES "<cuvs/detail/jit_lto/pairwise_matrix/pairwise_matrix_fragments.hpp>"
"<cuvs/detail/jit_lto/common_fragments.hpp>"
OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/kernel"
KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements
)
generate_jit_lto_kernels(
jit_lto_files
NAME_FORMAT
"pairwise_matrix_@distance_abbrev@_data_@data_abbrev@_acc_@acc_abbrev@_out_@out_abbrev@_index_@index_abbrev@_fin_op_rbf_layout_@layout_abbrev@_veclen_@veclen@"
MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_rbf_matrix.json"
KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/pairwise_matrix_rbf_kernel.cu.in"
FRAGMENT_TAG_FORMAT
"${distance_ns}::fragment_tag_pairwise_matrix<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_@out_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_fin_op_rbf, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>"
FRAGMENT_TAG_HEADER_FILES "<cuvs/detail/jit_lto/pairwise_matrix/pairwise_matrix_fragments.hpp>"
"<cuvs/detail/jit_lto/common_fragments.hpp>"
OUTPUT_DIRECTORY
"${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/kernel_rbf"
KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements
)
generate_jit_lto_kernels(
jit_lto_files
NAME_FORMAT
"pairwise_matrix_compute_distance_@distance_abbrev@_data_@data_abbrev@_acc_@acc_abbrev@_index_@index_abbrev@"
MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_matrix.json"
KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_kernel.cu.in"
FRAGMENT_TAG_FORMAT
"${distance_ns}::fragment_tag_compute_distance<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@>"
FRAGMENT_TAG_HEADER_FILES "<cuvs/detail/jit_lto/pairwise_matrix/pairwise_matrix_fragments.hpp>"
"<cuvs/detail/jit_lto/common_fragments.hpp>"
OUTPUT_DIRECTORY
"${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance"
KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements
)
generate_jit_lto_kernels(
jit_lto_files
NAME_FORMAT
"pairwise_matrix_compute_distance_@distance_abbrev@_data_@data_abbrev@_acc_@acc_abbrev@_index_@index_abbrev@"
MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_rbf_matrix.json"
KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_kernel.cu.in"
FRAGMENT_TAG_FORMAT
"${distance_ns}::fragment_tag_compute_distance<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@>"
FRAGMENT_TAG_HEADER_FILES "<cuvs/detail/jit_lto/pairwise_matrix/pairwise_matrix_fragments.hpp>"
"<cuvs/detail/jit_lto/common_fragments.hpp>"
OUTPUT_DIRECTORY
"${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance_rbf"
KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements
)
generate_jit_lto_kernels(
jit_lto_files
NAME_FORMAT
"pairwise_matrix_compute_distance_epilog_@distance_abbrev@_data_@data_abbrev@_acc_@acc_abbrev@_index_@index_abbrev@_layout_@layout_abbrev@_veclen_@veclen@"
MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_matrix.json"
KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_kernel.cu.in"
FRAGMENT_TAG_FORMAT
"${distance_ns}::fragment_tag_compute_distance_epilog<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>"
FRAGMENT_TAG_HEADER_FILES "<cuvs/detail/jit_lto/pairwise_matrix/pairwise_matrix_fragments.hpp>"
"<cuvs/detail/jit_lto/common_fragments.hpp>"
OUTPUT_DIRECTORY
"${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance_epilog"
KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements
)
generate_jit_lto_kernels(
jit_lto_files
NAME_FORMAT
"pairwise_matrix_compute_distance_epilog_@distance_abbrev@_data_@data_abbrev@_acc_@acc_abbrev@_index_@index_abbrev@_layout_@layout_abbrev@_veclen_@veclen@"
MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_rbf_matrix.json"
KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_kernel.cu.in"
FRAGMENT_TAG_FORMAT
"${distance_ns}::fragment_tag_compute_distance_epilog<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>"
FRAGMENT_TAG_HEADER_FILES "<cuvs/detail/jit_lto/pairwise_matrix/pairwise_matrix_fragments.hpp>"
"<cuvs/detail/jit_lto/common_fragments.hpp>"
OUTPUT_DIRECTORY
"${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance_epilog_rbf"
KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements
)
endblock()

# Note that this matrix contains an `arch_includes` placeholder, since we don't currently have a
Expand Down
2 changes: 2 additions & 0 deletions cpp/include/cuvs/detail/jit_lto/common_fragments.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ namespace cuvs::neighbors::detail {

struct tag_f {};
struct tag_h {};
struct tag_d {};
Comment thread
tarang-jain marked this conversation as resolved.
struct tag_i8 {};
struct tag_u8 {};
struct tag_filter_none {};
struct tag_filter_bitset {};

struct tag_bitset_u32 {};

struct tag_index_i32 {};
struct tag_index_u32 {};
struct tag_index_i64 {};

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

namespace cuvs::distance::detail {

struct tag_layout_row {};
struct tag_layout_col {};

struct tag_fin_op_identity {};
struct tag_fin_op_rbf {};

struct tag_distance_canberra {};
struct tag_distance_correlation {};
struct tag_distance_cosine {};
struct tag_distance_hamming_unexpanded {};
struct tag_distance_hellinger_expanded {};
struct tag_distance_jensen_shannon {};
struct tag_distance_kl_divergence {};
struct tag_distance_l1 {};
struct tag_distance_l2_expanded {};
struct tag_distance_l2_unexpanded {};
struct tag_distance_l_inf {};
struct tag_distance_lp_unexpanded {};
struct tag_distance_russel_rao {};

template <typename DistanceTag,
typename DataTag,
typename AccTag,
typename OutTag,
typename IndexTag,
typename FinOpTag,
typename LayoutTag,
int Veclen>
struct fragment_tag_pairwise_matrix {};

template <typename DistanceTag, typename DataTag, typename AccTag, typename IndexTag>
struct fragment_tag_compute_distance {};

template <typename DistanceTag,
typename DataTag,
typename AccTag,
typename IndexTag,
typename LayoutTag,
int Veclen>
struct fragment_tag_compute_distance_epilog {};

} // namespace cuvs::distance::detail
2 changes: 0 additions & 2 deletions cpp/src/distance/detail/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

#include "distance_ops/all_ops.cuh"
#include "pairwise_matrix/dispatch.cuh"
#include "pairwise_matrix/dispatch_sm60.cuh"
#include "pairwise_matrix/dispatch_sm80.cuh"
#include <cuvs/distance/distance.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/cuda_stream.hpp>
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/distance/detail/distance_ops/correlation.cuh
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <raft/core/operators.hpp> // raft::sqrt
#include <raft/util/cuda_dev_essentials.cuh> // DI

#include <cuda_fp16.h>
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/distance/detail/distance_ops/hellinger.cuh
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once
#include <raft/core/operators.hpp> // raft::sqrt
#include <raft/util/cuda_dev_essentials.cuh> // DI

namespace cuvs::distance::detail::ops {
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/distance/detail/distance_ops/l1.cuh
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once
#include <raft/core/operators.hpp> // raft::abs
#include <raft/util/cuda_dev_essentials.cuh> // DI

namespace cuvs::distance::detail::ops {
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/distance/detail/distance_ops/l2_unexp.cuh
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <raft/core/operators.hpp> // raft::sqrt
#include <raft/util/cuda_dev_essentials.cuh> // DI

#include <cuda_fp16.h>
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/distance/detail/distance_ops/l_inf.cuh
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <raft/core/operators.hpp> // raft::abs, raft::max
#include <raft/util/cuda_dev_essentials.cuh> // DI

namespace cuvs::distance::detail::ops {
Expand Down
20 changes: 19 additions & 1 deletion cpp/src/distance/detail/pairwise_distance_base.cuh
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#ifdef CUVS_DISTANCE_PAIRWISE_USE_JIT
Comment thread
KyleFromNVIDIA marked this conversation as resolved.
#include "pairwise_matrix/jit_lto_kernels/device_functions.cuh"
#endif
#include <raft/linalg/contractions.cuh> // raft::linalg::Contractions_NT
#include <raft/util/cuda_dev_essentials.cuh> // ceildiv
#include <raft/util/cuda_rt_essentials.hpp> // RAFT_CUDA_TRY
Expand Down Expand Up @@ -150,7 +153,12 @@ struct PairwiseDistances : public BaseClass {
// Calculate distance_op epilog.
// Use .template to disambiguate (See:
// https://en.cppreference.com/w/cpp/language/dependent_name)
#ifdef CUVS_DISTANCE_PAIRWISE_USE_JIT
compute_distance_epilog<Policy, OpT, AccT, IdxT>(
distance_op, acc, regxn, regyn, tile_idx_n, tile_idx_m);
#else
distance_op.template epilog<Policy>(acc, regxn, regyn, tile_idx_n, tile_idx_m);
#endif
// And any possible additional epilogs
epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m);
} else {
Expand All @@ -159,7 +167,12 @@ struct PairwiseDistances : public BaseClass {
// Calculate distance_op epilog.
// Use .template to disambiguate (See:
// https://en.cppreference.com/w/cpp/language/dependent_name)
#ifdef CUVS_DISTANCE_PAIRWISE_USE_JIT
compute_distance_epilog<Policy, OpT, AccT, IdxT>(
distance_op, acc, nullptr, nullptr, tile_idx_n, tile_idx_m);
#else
distance_op.template epilog<Policy>(acc, nullptr, nullptr, tile_idx_n, tile_idx_m);
#endif
// And any possible additional epilogs
epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m);
}
Expand Down Expand Up @@ -203,7 +216,12 @@ struct PairwiseDistances : public BaseClass {
for (int i = 0; i < P::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
#ifdef CUVS_DISTANCE_PAIRWISE_USE_JIT
compute_distance<OpT, DataT, AccT, IdxT>(
distance_op, acc[i][j], reg_x[i][v], reg_y[j][v]);
#else
distance_op.core(acc[i][j], reg_x[i][v], reg_y[j][v]);
#endif
}
}
}
Expand Down
Loading
Loading