From 5e5eecf5579e973ecc550e6d7e2845a9dee73021 Mon Sep 17 00:00:00 2001 From: tarang-jain Date: Fri, 15 May 2026 15:59:09 -0700 Subject: [PATCH 01/18] fragments, tags, instantiations, json and cmake --- cpp/CMakeLists.txt | 90 +++++- cpp/src/distance/detail/distance.cuh | 2 - .../detail/pairwise_distance_base.cuh | 25 +- .../detail/pairwise_matrix/dispatch-inl.cuh | 85 +---- .../pairwise_matrix/dispatch_matrix.json | 28 +- .../pairwise_matrix/dispatch_rbf_inst.cu.in | 2 +- .../pairwise_matrix/dispatch_rbf_matrix.json | 2 +- .../detail/pairwise_matrix/dispatch_sm60.cuh | 74 ----- .../compute_distance_epilog_kernel.cu.in | 40 +++ .../compute_distance_epilog_matrix.json | 290 ++++++++++++++++++ .../compute_distance_epilog_rbf_matrix.json | 128 ++++++++ .../compute_distance_kernel.cu.in | 35 +++ .../compute_distance_matrix.json | 155 ++++++++++ .../compute_distance_rbf_matrix.json | 45 +++ .../jit_lto_kernels/device_functions.cuh | 21 ++ .../jit_lto_kernels/pairwise_matrix_jit.cuh | 213 +++++++++++++ .../pairwise_matrix_kernel.cu.in | 77 +++++ .../pairwise_matrix_planner.hpp | 73 +++++ .../pairwise_matrix_rbf_kernel.cu.in | 78 +++++ .../jit_lto_kernels/registration_tags.hpp | 58 ++++ .../detail/pairwise_matrix/kernel_sm60.cuh | 144 --------- 21 files changed, 1343 insertions(+), 322 deletions(-) delete mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_sm60.cuh create mode 100644 cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_kernel.cu.in create mode 100644 cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_matrix.json create mode 100644 cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_rbf_matrix.json create mode 100644 cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_kernel.cu.in create mode 100644 cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_matrix.json create mode 100644 cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_rbf_matrix.json create mode 100644 cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/device_functions.cuh create mode 100644 cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh create mode 100644 cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in create mode 100644 cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp create mode 100644 cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in create mode 100644 cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/registration_tags.hpp delete mode 100644 cpp/src/distance/detail/pairwise_matrix/kernel_sm60.cuh diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index ac682a9cfb..12b37c61ac 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -768,11 +768,95 @@ if(NOT BUILD_CPU_ONLY) OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/filter" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements ) + set(distance_ns "cuvs::distance::detail") + set(pairwise_matrix_jit_dir + "${CMAKE_CURRENT_SOURCE_DIR}/src/distance/detail/pairwise_matrix/jit_lto_kernels" + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "pairwise_matrix_@distance_abbrev@_data_@data_abbrev@_acc_@acc_abbrev@_out_@out_abbrev@_index_@index_abbrev@_layout_@layout_abbrev@_veclen_@veclen@" + MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_matrix.json" + KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/pairwise_matrix_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${distance_ns}::fragment_tag_pairwise_matrix<${distance_ns}::tag_distance_@distance_abbrev@, ${distance_ns}::tag_@data_abbrev@, ${distance_ns}::tag_@acc_abbrev@, ${distance_ns}::tag_@out_abbrev@, ${distance_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_fin_op_identity, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" + FRAGMENT_TAG_HEADER_FILES + "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/kernel" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "pairwise_matrix_@distance_abbrev@_data_@data_abbrev@_acc_@acc_abbrev@_out_@out_abbrev@_index_@index_abbrev@_fin_op_rbf_layout_@layout_abbrev@_veclen_@veclen@" + MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_rbf_matrix.json" + KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/pairwise_matrix_rbf_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${distance_ns}::fragment_tag_pairwise_matrix<${distance_ns}::tag_distance_@distance_abbrev@, ${distance_ns}::tag_@data_abbrev@, ${distance_ns}::tag_@acc_abbrev@, ${distance_ns}::tag_@out_abbrev@, ${distance_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_fin_op_rbf, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" + FRAGMENT_TAG_HEADER_FILES + "" + OUTPUT_DIRECTORY + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/kernel_rbf" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "pairwise_matrix_compute_distance_@distance_abbrev@_data_@data_abbrev@_acc_@acc_abbrev@_index_@index_abbrev@" + MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_matrix.json" + KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${distance_ns}::fragment_tag_compute_distance<${distance_ns}::tag_distance_@distance_abbrev@, ${distance_ns}::tag_@data_abbrev@, ${distance_ns}::tag_@acc_abbrev@, ${distance_ns}::tag_index_@index_abbrev@>" + FRAGMENT_TAG_HEADER_FILES + "" + OUTPUT_DIRECTORY + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "pairwise_matrix_compute_distance_@distance_abbrev@_data_@data_abbrev@_acc_@acc_abbrev@_index_@index_abbrev@" + MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_rbf_matrix.json" + KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${distance_ns}::fragment_tag_compute_distance<${distance_ns}::tag_distance_@distance_abbrev@, ${distance_ns}::tag_@data_abbrev@, ${distance_ns}::tag_@acc_abbrev@, ${distance_ns}::tag_index_@index_abbrev@>" + FRAGMENT_TAG_HEADER_FILES + "" + OUTPUT_DIRECTORY + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance_rbf" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "pairwise_matrix_compute_distance_epilog_@distance_abbrev@_data_@data_abbrev@_acc_@acc_abbrev@_index_@index_abbrev@_layout_@layout_abbrev@_veclen_@veclen@" + MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_matrix.json" + KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${distance_ns}::fragment_tag_compute_distance_epilog<${distance_ns}::tag_distance_@distance_abbrev@, ${distance_ns}::tag_@data_abbrev@, ${distance_ns}::tag_@acc_abbrev@, ${distance_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" + FRAGMENT_TAG_HEADER_FILES + "" + OUTPUT_DIRECTORY + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance_epilog" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "pairwise_matrix_compute_distance_epilog_@distance_abbrev@_data_@data_abbrev@_acc_@acc_abbrev@_index_@index_abbrev@_layout_@layout_abbrev@_veclen_@veclen@" + MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_rbf_matrix.json" + KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${distance_ns}::fragment_tag_compute_distance_epilog<${distance_ns}::tag_distance_@distance_abbrev@, ${distance_ns}::tag_@data_abbrev@, ${distance_ns}::tag_@acc_abbrev@, ${distance_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" + FRAGMENT_TAG_HEADER_FILES + "" + OUTPUT_DIRECTORY + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance_epilog_rbf" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) endblock() - # Note that this matrix contains an `arch_includes` placeholder, since we don't currently have a - # way to do an item-wise transform on a list after computing the matrix product and before - # configuring the file generate_inst_matrix( pairwise_matrix_dispatch_inst_files MATRIX_JSON_FILE diff --git a/cpp/src/distance/detail/distance.cuh b/cpp/src/distance/detail/distance.cuh index ab06bc3ba1..a8b73d58e3 100644 --- a/cpp/src/distance/detail/distance.cuh +++ b/cpp/src/distance/detail/distance.cuh @@ -7,8 +7,6 @@ #include "distance_ops/all_ops.cuh" #include "pairwise_matrix/dispatch.cuh" -#include "pairwise_matrix/dispatch_sm60.cuh" -#include "pairwise_matrix/dispatch_sm80.cuh" #include #include #include diff --git a/cpp/src/distance/detail/pairwise_distance_base.cuh b/cpp/src/distance/detail/pairwise_distance_base.cuh index b2e886807d..4300901179 100644 --- a/cpp/src/distance/detail/pairwise_distance_base.cuh +++ b/cpp/src/distance/detail/pairwise_distance_base.cuh @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ #pragma once +#include "pairwise_matrix/jit_lto_kernels/device_functions.cuh" #include // raft::linalg::Contractions_NT #include // ceildiv #include // RAFT_CUDA_TRY @@ -54,7 +55,8 @@ template > + typename BaseClass = raft::linalg::Contractions_NT, + bool useJitDeviceFunctions = false> struct PairwiseDistances : public BaseClass { // Get accumulation type from distance_op using AccT = typename OpT::AccT; @@ -150,7 +152,12 @@ struct PairwiseDistances : public BaseClass { // Calculate distance_op epilog. // Use .template to disambiguate (See: // https://en.cppreference.com/w/cpp/language/dependent_name) - distance_op.template epilog(acc, regxn, regyn, tile_idx_n, tile_idx_m); + if constexpr (useJitDeviceFunctions) { + compute_distance_epilog( + distance_op, acc, regxn, regyn, tile_idx_n, tile_idx_m); + } else { + distance_op.template epilog(acc, regxn, regyn, tile_idx_n, tile_idx_m); + } // And any possible additional epilogs epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); } else { @@ -159,7 +166,12 @@ struct PairwiseDistances : public BaseClass { // Calculate distance_op epilog. // Use .template to disambiguate (See: // https://en.cppreference.com/w/cpp/language/dependent_name) - distance_op.template epilog(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + if constexpr (useJitDeviceFunctions) { + compute_distance_epilog( + distance_op, acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + } else { + distance_op.template epilog(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + } // And any possible additional epilogs epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); } @@ -203,7 +215,12 @@ struct PairwiseDistances : public BaseClass { for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - distance_op.core(acc[i][j], reg_x[i][v], reg_y[j][v]); + if constexpr (useJitDeviceFunctions) { + compute_distance( + distance_op, acc[i][j], reg_x[i][v], reg_y[j][v]); + } else { + distance_op.core(acc[i][j], reg_x[i][v], reg_y[j][v]); + } } } } diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh index 79cbb78f43..9cc9f29710 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh @@ -4,51 +4,16 @@ */ #pragma once -/* This file has two responsibilities: - * - * 1. Dispatch to the correct implementation of a kernel based on the - * architecture of the device on which the kernel will be launched. For - * instance, the cosine distance has a CUTLASS-based implementation that can - * be used on SM80+ and the normal implementation that is used on older - * architectures. - * - * 2. Provide concise function templates that can be instantiated in - * src/distance/detail/pairwise_matrix/. Previously, - * cuvs::distance::detail::distance was instantiated. The function - * necessarily required a large set of include files, which slowed down the - * build. The cuvs::distance::detail::pairwise_matrix_arch_dispatch functions - * do not require as large an include files set, which speeds up the build. +/* This file provides concise function templates that can be instantiated in + * src/distance/detail/pairwise_matrix/. Previously, + * cuvs::distance::detail::distance was instantiated. The function necessarily + * required a large set of include files, which slowed down the build. */ -#include "../distance_ops/cutlass.cuh" // ops::has_cutlass_op -#include "../pairwise_matrix/dispatch_sm60.cuh" // dispatch_sm60 -#include "../pairwise_matrix/params.cuh" // pairwise_matrix_params -#include // raft::util::arch::SM_* - -// NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. -// Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). -// Therefore, it is the including file's responsibility to include the correct -// dispatch_smXX.cuh headers, as is done in cuvs/distance/detail/distance.cuh -// and src/distance/detail/pairwise_matrix/dispatch_*.cu. +#include "../pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh" // pairwise_matrix_jit_dispatch namespace cuvs::distance::detail { -// This forward-declaration ensures that we do not need to include -// dispatch_sm80.cuh if we are not calling it in practice. This makes compiling -// all the non-CUTLASS based distance instantiations faster. For CUTLASS-based -// distances, dispatch_sm80.cuh has to be included by the file including this -// file. -template -void pairwise_matrix_sm80_dispatch(OpT, - pairwise_matrix_params, - SM_compat_t, - cudaStream_t); - template (); - - if constexpr (cutlass_op_unavailable) { - // Always execute legacy kernels when no cutlass op is available - auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); - pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); - } else { - auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); - auto legacy_range = arch::SM_range(arch::SM_min(), arch::SM_80()); - - // Get pointer to SM60 kernel to determine the best compute architecture - // out of all for which the kernel was compiled for that matches closely - // to the current device. Other methods to determine the architecture (that do not - // require a pointer) can be error prone. See: - // https://github.com/NVIDIA/cub/issues/545 - auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range); - void* kernel_ptr = reinterpret_cast(sm60_wrapper.kernel_ptr); - auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); - - // TODO: the cutlass doesn't support the odd `k` on half DataT. - bool if_unsupported_on_half = (sizeof(DataT) == 2) && ((k % 2) != 0); - - if (if_unsupported_on_half) { - auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); - pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); - } else if (cutlass_range.contains(runtime_arch) && !if_unsupported_on_half) { - // If device is SM_80 or later, use CUTLASS-based kernel. - pairwise_matrix_sm80_dispatch(distance_op, params, cutlass_range, stream); - } else { - // Reuse kernel wrapper that we obtained above. This avoids performing the - // dispatch twice. - sm60_wrapper.launch(distance_op, params, stream); - } - } + pairwise_matrix_jit_dispatch(distance_op, params, stream); } }; // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_matrix.json b/cpp/src/distance/detail/pairwise_matrix/dispatch_matrix.json index 5412fc897f..ebafa5c245 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_matrix.json +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_matrix.json @@ -30,67 +30,67 @@ { "op_type": "cuvs::distance::detail::ops::canberra_distance_op", "op_abbrev": "canberra", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::correlation_distance_op", "op_abbrev": "correlation", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::cosine_distance_op", "op_abbrev": "cosine", - "arch_includes": "#include \n#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::hamming_distance_op", "op_abbrev": "hamming_unexpanded", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::hellinger_distance_op", "op_abbrev": "hellinger_expanded", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::jensen_shannon_distance_op", "op_abbrev": "jensen_shannon", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::kl_divergence_op", "op_abbrev": "kl_divergence", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::l1_distance_op", "op_abbrev": "l1", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::l2_exp_distance_op", "op_abbrev": "l2_expanded", - "arch_includes": "#include \n#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::l2_unexp_distance_op", "op_abbrev": "l2_unexpanded", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::l_inf_distance_op", "op_abbrev": "l_inf", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::lp_unexp_distance_op", "op_abbrev": "lp_unexpanded", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::russel_rao_distance_op", "op_abbrev": "russel_rao", - "arch_includes": "#include " + "arch_includes": "" } ], "index_type": "int", @@ -119,7 +119,7 @@ { "op_type": "cuvs::distance::detail::ops::l2_exp_distance_op", "op_abbrev": "l2_expanded", - "arch_includes": "#include \n#include " + "arch_includes": "" } ], "index_type": "int64_t", diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf_inst.cu.in b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf_inst.cu.in index 879b221e4c..0e65689603 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf_inst.cu.in +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf_inst.cu.in @@ -14,7 +14,7 @@ using data_t = @data_type@; using acc_t = @acc_type@; using out_t = @out_type@; - using fin_op_t = cuvs::distance::kernels::rbf_fin_op; + using fin_op_t = cuvs::distance::kernels::rbf_fin_op; using index_t = @index_type@; using op_t = @op_type@; } diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf_matrix.json b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf_matrix.json index 876e824733..12213602f1 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf_matrix.json +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf_matrix.json @@ -31,7 +31,7 @@ { "op_type": "cuvs::distance::detail::ops::l2_unexp_distance_op", "op_abbrev": "l2_unexpanded", - "arch_includes": "#include " + "arch_includes": "" } ] } diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_sm60.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch_sm60.cuh deleted file mode 100644 index 034aabef3e..0000000000 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_sm60.cuh +++ /dev/null @@ -1,74 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ -#pragma once - -#include "dispatch_layout.cuh" // dispatch_layout -#include "kernel_sm60.cuh" // pairwise_matrix_sm60_wrapper -#include // raft::linalg::Policy4x4 - -#include // std::min - -namespace cuvs::distance::detail { - -template -pairwise_matrix_sm60_wrapper pairwise_matrix_sm60_get_wrapper( - OpT distance_op, - pairwise_matrix_params params, - SM_compat_t sm_compat_range) -{ - int vec_len = determine_vec_len(params); - - // f takes compile-time constants row_major and vec_len aligned and returns - // the corresponding kernel wrapper. The wrapper contains the launch - // parameters of the kernel: a pointer to the kernel function, grid size, - // block size, and shared memory size. - auto f = [&](auto row_major, auto vec_len_aligned) { - // row_major and vec_len are std::integral_constants of type bool and int - // respectively. - - // To keep compile times in check, we only specialize on veclen > 1 when - // the inner loop is relatively cheap (< 5 flops). - constexpr int vec_len_op = distance_op.expensive_inner_loop ? 1 : vec_len_aligned(); - - // Prevent double, vec_len=4 combination (this is not supported) - constexpr int vec_len = std::min(vec_len_op, static_cast(16 / sizeof(DataT))); - - using RowPolicy = typename raft::linalg::Policy4x4::Policy; - using ColPolicy = typename raft::linalg::Policy4x4::ColPolicy; - using Policy = typename std::conditional::type; - - auto wrapper = - make_pairwise_matrix_sm60_wrapper(distance_op, params, sm_compat_range); - - return wrapper; - }; - - // Dispatch_layout calls f with appropriate compile time constants based on - // the runtime values of params.is_row_major and vec_len. - return dispatch_layout(params.is_row_major, vec_len, f); -} - -template -void pairwise_matrix_sm60_dispatch(OpT distance_op, - pairwise_matrix_params params, - SM_compat_t sm_compat_range, - cudaStream_t stream) -{ - auto wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, sm_compat_range); - - wrapper.launch(distance_op, params, stream); -} - -} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_kernel.cu.in b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_kernel.cu.in new file mode 100644 index 0000000000..d0b4d64645 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_kernel.cu.in @@ -0,0 +1,40 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include + +#include <@header_file@> +#include +#include + +namespace { + +// clang-format off +using data_t = @data_type@; +using acc_t = @acc_type@; +using index_t = @index_type@; +using op_t = @op_type@; +using policy_t = typename raft::linalg::Policy4x4::@policy_type@; +// clang-format on + +} // namespace + +namespace cuvs::distance::detail { + +template <> +__device__ void compute_distance_epilog( + op_t distance_op, + acc_t acc[policy_t::AccRowsPerTh][policy_t::AccColsPerTh], + acc_t* regxn, + acc_t* regyn, + index_t grid_stride_x, + index_t grid_stride_y) +{ + distance_op.template epilog(acc, regxn, regyn, grid_stride_x, grid_stride_y); +} + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_matrix.json b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_matrix.json new file mode 100644 index 0000000000..afe37ce712 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_matrix.json @@ -0,0 +1,290 @@ +[ + { + "_distance": [ + { + "distance_name": "canberra", + "distance_abbrev": "canberra", + "op_type": "cuvs::distance::detail::ops::canberra_distance_op", + "header_file": "distance/detail/distance_ops/canberra.cuh" + }, + { + "distance_name": "correlation", + "distance_abbrev": "correlation", + "op_type": "cuvs::distance::detail::ops::correlation_distance_op", + "header_file": "distance/detail/distance_ops/correlation.cuh" + }, + { + "distance_name": "cosine", + "distance_abbrev": "cosine", + "op_type": "cuvs::distance::detail::ops::cosine_distance_op", + "header_file": "distance/detail/distance_ops/cosine.cuh" + }, + { + "distance_name": "hamming_unexpanded", + "distance_abbrev": "hamming_unexpanded", + "op_type": "cuvs::distance::detail::ops::hamming_distance_op", + "header_file": "distance/detail/distance_ops/hamming.cuh" + }, + { + "distance_name": "hellinger_expanded", + "distance_abbrev": "hellinger_expanded", + "op_type": "cuvs::distance::detail::ops::hellinger_distance_op", + "header_file": "distance/detail/distance_ops/hellinger.cuh" + }, + { + "distance_name": "jensen_shannon", + "distance_abbrev": "jensen_shannon", + "op_type": "cuvs::distance::detail::ops::jensen_shannon_distance_op", + "header_file": "distance/detail/distance_ops/jensen_shannon.cuh" + }, + { + "distance_name": "kl_divergence", + "distance_abbrev": "kl_divergence", + "op_type": "cuvs::distance::detail::ops::kl_divergence_op", + "header_file": "distance/detail/distance_ops/kl_divergence.cuh" + }, + { + "distance_name": "l1", + "distance_abbrev": "l1", + "op_type": "cuvs::distance::detail::ops::l1_distance_op", + "header_file": "distance/detail/distance_ops/l1.cuh" + }, + { + "distance_name": "l2_expanded", + "distance_abbrev": "l2_expanded", + "op_type": "cuvs::distance::detail::ops::l2_exp_distance_op", + "header_file": "distance/detail/distance_ops/l2_exp.cuh" + }, + { + "distance_name": "l2_unexpanded", + "distance_abbrev": "l2_unexpanded", + "op_type": "cuvs::distance::detail::ops::l2_unexp_distance_op", + "header_file": "distance/detail/distance_ops/l2_unexp.cuh" + }, + { + "distance_name": "l_inf", + "distance_abbrev": "l_inf", + "op_type": "cuvs::distance::detail::ops::l_inf_distance_op", + "header_file": "distance/detail/distance_ops/l_inf.cuh" + }, + { + "distance_name": "lp_unexpanded", + "distance_abbrev": "lp_unexpanded", + "op_type": "cuvs::distance::detail::ops::lp_unexp_distance_op", + "header_file": "distance/detail/distance_ops/lp_unexp.cuh" + }, + { + "distance_name": "russel_rao", + "distance_abbrev": "russel_rao", + "op_type": "cuvs::distance::detail::ops::russel_rao_distance_op", + "header_file": "distance/detail/distance_ops/russel_rao.cuh" + } + ], + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "acc_type": "float", + "acc_abbrev": "f", + "out_type": "float", + "out_abbrev": "f", + "_policy": [ + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "1" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "1" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "2" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "2" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "4" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "4" + } + ] + }, + { + "data_type": "double", + "data_abbrev": "d", + "acc_type": "double", + "acc_abbrev": "d", + "out_type": "double", + "out_abbrev": "d", + "_policy": [ + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "1" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "1" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "2" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "2" + } + ] + }, + { + "data_type": "half", + "data_abbrev": "h", + "acc_type": "float", + "acc_abbrev": "f", + "out_type": "float", + "out_abbrev": "f", + "_policy": [ + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "1" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "1" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "2" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "2" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "4" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "4" + } + ] + } + ], + "_index": [ + { + "index_type": "int", + "index_abbrev": "i" + } + ] + }, + { + "_distance": [ + { + "distance_name": "l2_expanded", + "distance_abbrev": "l2_expanded", + "op_type": "cuvs::distance::detail::ops::l2_exp_distance_op", + "header_file": "distance/detail/distance_ops/l2_exp.cuh" + } + ], + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "acc_type": "float", + "acc_abbrev": "f", + "out_type": "float", + "out_abbrev": "f", + "_policy": [ + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "1" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "1" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "2" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "2" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "4" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "4" + } + ] + }, + { + "data_type": "double", + "data_abbrev": "d", + "acc_type": "double", + "acc_abbrev": "d", + "out_type": "double", + "out_abbrev": "d", + "_policy": [ + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "1" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "1" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "2" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "2" + } + ] + } + ], + "_index": [ + { + "index_type": "int64_t", + "index_abbrev": "i64" + } + ] + } +] diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_rbf_matrix.json b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_rbf_matrix.json new file mode 100644 index 0000000000..69959a8326 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_rbf_matrix.json @@ -0,0 +1,128 @@ +{ + "_distance": [ + { + "distance_name": "l2_unexpanded", + "distance_abbrev": "l2_unexpanded", + "op_type": "cuvs::distance::detail::ops::l2_unexp_distance_op", + "header_file": "distance/detail/distance_ops/l2_unexp.cuh" + } + ], + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "acc_type": "float", + "acc_abbrev": "f", + "out_type": "float", + "out_abbrev": "f", + "_policy": [ + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "1" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "1" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "2" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "2" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "4" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "4" + } + ] + }, + { + "data_type": "double", + "data_abbrev": "d", + "acc_type": "double", + "acc_abbrev": "d", + "out_type": "double", + "out_abbrev": "d", + "_policy": [ + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "1" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "1" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "2" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "2" + } + ] + }, + { + "data_type": "half", + "data_abbrev": "h", + "acc_type": "float", + "acc_abbrev": "f", + "out_type": "float", + "out_abbrev": "f", + "_policy": [ + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "1" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "1" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "2" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "2" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "4" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "4" + } + ] + } + ], + "_index": [ + { + "index_type": "int64_t", + "index_abbrev": "i64" + } + ] +} diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_kernel.cu.in b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_kernel.cu.in new file mode 100644 index 0000000000..3a79c66480 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_kernel.cu.in @@ -0,0 +1,35 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include + +#include <@header_file@> +#include + +namespace { + +// clang-format off +using data_t = @data_type@; +using acc_t = @acc_type@; +using index_t = @index_type@; +using op_t = @op_type@; +// clang-format on + +} // namespace + +namespace cuvs::distance::detail { + +template <> +__device__ void compute_distance(op_t distance_op, + acc_t& acc, + data_t x, + data_t y) +{ + distance_op.core(acc, x, y); +} + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_matrix.json b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_matrix.json new file mode 100644 index 0000000000..2dbdd12efa --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_matrix.json @@ -0,0 +1,155 @@ +[ + { + "_distance": [ + { + "distance_name": "canberra", + "distance_abbrev": "canberra", + "op_type": "cuvs::distance::detail::ops::canberra_distance_op", + "header_file": "distance/detail/distance_ops/canberra.cuh" + }, + { + "distance_name": "correlation", + "distance_abbrev": "correlation", + "op_type": "cuvs::distance::detail::ops::correlation_distance_op", + "header_file": "distance/detail/distance_ops/correlation.cuh" + }, + { + "distance_name": "cosine", + "distance_abbrev": "cosine", + "op_type": "cuvs::distance::detail::ops::cosine_distance_op", + "header_file": "distance/detail/distance_ops/cosine.cuh" + }, + { + "distance_name": "hamming_unexpanded", + "distance_abbrev": "hamming_unexpanded", + "op_type": "cuvs::distance::detail::ops::hamming_distance_op", + "header_file": "distance/detail/distance_ops/hamming.cuh" + }, + { + "distance_name": "hellinger_expanded", + "distance_abbrev": "hellinger_expanded", + "op_type": "cuvs::distance::detail::ops::hellinger_distance_op", + "header_file": "distance/detail/distance_ops/hellinger.cuh" + }, + { + "distance_name": "jensen_shannon", + "distance_abbrev": "jensen_shannon", + "op_type": "cuvs::distance::detail::ops::jensen_shannon_distance_op", + "header_file": "distance/detail/distance_ops/jensen_shannon.cuh" + }, + { + "distance_name": "kl_divergence", + "distance_abbrev": "kl_divergence", + "op_type": "cuvs::distance::detail::ops::kl_divergence_op", + "header_file": "distance/detail/distance_ops/kl_divergence.cuh" + }, + { + "distance_name": "l1", + "distance_abbrev": "l1", + "op_type": "cuvs::distance::detail::ops::l1_distance_op", + "header_file": "distance/detail/distance_ops/l1.cuh" + }, + { + "distance_name": "l2_expanded", + "distance_abbrev": "l2_expanded", + "op_type": "cuvs::distance::detail::ops::l2_exp_distance_op", + "header_file": "distance/detail/distance_ops/l2_exp.cuh" + }, + { + "distance_name": "l2_unexpanded", + "distance_abbrev": "l2_unexpanded", + "op_type": "cuvs::distance::detail::ops::l2_unexp_distance_op", + "header_file": "distance/detail/distance_ops/l2_unexp.cuh" + }, + { + "distance_name": "l_inf", + "distance_abbrev": "l_inf", + "op_type": "cuvs::distance::detail::ops::l_inf_distance_op", + "header_file": "distance/detail/distance_ops/l_inf.cuh" + }, + { + "distance_name": "lp_unexpanded", + "distance_abbrev": "lp_unexpanded", + "op_type": "cuvs::distance::detail::ops::lp_unexp_distance_op", + "header_file": "distance/detail/distance_ops/lp_unexp.cuh" + }, + { + "distance_name": "russel_rao", + "distance_abbrev": "russel_rao", + "op_type": "cuvs::distance::detail::ops::russel_rao_distance_op", + "header_file": "distance/detail/distance_ops/russel_rao.cuh" + } + ], + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "type_abbrev": "f", + "acc_type": "float", + "acc_abbrev": "f", + "out_type": "float", + "out_abbrev": "f" + }, + { + "data_type": "double", + "data_abbrev": "d", + "type_abbrev": "d", + "acc_type": "double", + "acc_abbrev": "d", + "out_type": "double", + "out_abbrev": "d" + }, + { + "data_type": "half", + "data_abbrev": "h", + "type_abbrev": "h", + "acc_type": "float", + "acc_abbrev": "f", + "out_type": "float", + "out_abbrev": "f" + } + ], + "_index": [ + { + "index_type": "int", + "index_abbrev": "i" + } + ] + }, + { + "_distance": [ + { + "distance_name": "l2_expanded", + "distance_abbrev": "l2_expanded", + "op_type": "cuvs::distance::detail::ops::l2_exp_distance_op", + "header_file": "distance/detail/distance_ops/l2_exp.cuh" + } + ], + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "type_abbrev": "f", + "acc_type": "float", + "acc_abbrev": "f", + "out_type": "float", + "out_abbrev": "f" + }, + { + "data_type": "double", + "data_abbrev": "d", + "type_abbrev": "d", + "acc_type": "double", + "acc_abbrev": "d", + "out_type": "double", + "out_abbrev": "d" + } + ], + "_index": [ + { + "index_type": "int64_t", + "index_abbrev": "i64" + } + ] + } +] diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_rbf_matrix.json b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_rbf_matrix.json new file mode 100644 index 0000000000..d140025078 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_rbf_matrix.json @@ -0,0 +1,45 @@ +{ + "_distance": [ + { + "distance_name": "l2_unexpanded", + "distance_abbrev": "l2_unexpanded", + "op_type": "cuvs::distance::detail::ops::l2_unexp_distance_op", + "header_file": "distance/detail/distance_ops/l2_unexp.cuh" + } + ], + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "type_abbrev": "f", + "acc_type": "float", + "acc_abbrev": "f", + "out_type": "float", + "out_abbrev": "f" + }, + { + "data_type": "double", + "data_abbrev": "d", + "type_abbrev": "d", + "acc_type": "double", + "acc_abbrev": "d", + "out_type": "double", + "out_abbrev": "d" + }, + { + "data_type": "half", + "data_abbrev": "h", + "type_abbrev": "h", + "acc_type": "float", + "acc_abbrev": "f", + "out_type": "float", + "out_abbrev": "f" + } + ], + "_index": [ + { + "index_type": "int64_t", + "index_abbrev": "i64" + } + ] +} diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/device_functions.cuh b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/device_functions.cuh new file mode 100644 index 0000000000..ff6479a05f --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/device_functions.cuh @@ -0,0 +1,21 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +namespace cuvs::distance::detail { + +template +extern __device__ void compute_distance(OpT distance_op, AccT& acc, DataT x, DataT y); + +template +extern __device__ void compute_distance_epilog(OpT distance_op, + AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + AccT* regxn, + AccT* regyn, + IdxT grid_stride_x, + IdxT grid_stride_y); + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh new file mode 100644 index 0000000000..5941e7f247 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh @@ -0,0 +1,213 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "pairwise_matrix_planner.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace cuvs::distance::detail { + +template +inline constexpr bool pairwise_matrix_jit_always_false_v = false; + +template +constexpr auto get_pairwise_scalar_type_tag() +{ + if constexpr (std::is_same_v) { + return tag_f{}; + } else if constexpr (std::is_same_v) { + return tag_d{}; + } else if constexpr (std::is_same_v || std::is_same_v) { + return tag_h{}; + } else { + static_assert(pairwise_matrix_jit_always_false_v, + "Pairwise matrix JIT LTO does not have a scalar tag for this type"); + } +} + +template +constexpr auto get_pairwise_index_type_tag() +{ + if constexpr (std::is_same_v) { + return tag_index_i{}; + } else if constexpr (std::is_same_v) { + return tag_index_i64{}; + } else { + static_assert(pairwise_matrix_jit_always_false_v, + "Pairwise matrix JIT LTO does not have an index tag for this type"); + } +} + +template +using pairwise_layout_tag_t = std::conditional_t; + +template +struct pairwise_fin_op_tag { + static_assert(pairwise_matrix_jit_always_false_v, + "Pairwise matrix JIT LTO does not have a final-op tag for this type"); +}; + +template <> +struct pairwise_fin_op_tag { + using type = tag_fin_op_identity; +}; + +template +struct pairwise_fin_op_tag> { + using type = tag_fin_op_rbf; +}; + +template +using pairwise_fin_op_tag_t = typename pairwise_fin_op_tag::type; + +template +struct pairwise_distance_op_tag { + static_assert(pairwise_matrix_jit_always_false_v, + "Pairwise matrix JIT LTO does not have a distance-op tag for this type"); +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_canberra; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_correlation; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_cosine; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_hamming_unexpanded; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_hellinger_expanded; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_jensen_shannon; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_kl_divergence; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_l1; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_l2_expanded; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_l2_unexpanded; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_l_inf; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_lp_unexpanded; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_russel_rao; +}; + +template +using pairwise_distance_op_tag_t = typename pairwise_distance_op_tag::type; + +template +using pairwise_matrix_jit_kernel_t = void(OpT, pairwise_matrix_params); + +template +void pairwise_matrix_jit_dispatch(OpT distance_op, + pairwise_matrix_params params, + cudaStream_t stream) +{ + using AccT = typename OpT::AccT; + + int vec_len = determine_vec_len(params); + + auto launch = [&](auto row_major, auto vec_len_aligned) { + constexpr int vec_len_op = OpT::expensive_inner_loop ? 1 : vec_len_aligned(); + constexpr int veclen = std::min(vec_len_op, static_cast(16 / sizeof(DataT))); + + using RowPolicy = typename raft::linalg::Policy4x4::Policy; + using ColPolicy = typename raft::linalg::Policy4x4::ColPolicy; + using Policy = std::conditional_t; + + using DistanceTag = pairwise_distance_op_tag_t; + using DataTag = decltype(get_pairwise_scalar_type_tag()); + using AccTag = decltype(get_pairwise_scalar_type_tag()); + using OutTag = decltype(get_pairwise_scalar_type_tag()); + using IndexTag = decltype(get_pairwise_index_type_tag()); + using FinOpTag = pairwise_fin_op_tag_t; + using LayoutTag = pairwise_layout_tag_t; + + PairwiseMatrixPlanner planner; + planner.add_entrypoint(); + planner.add_compute_distance_function(); + planner.add_compute_distance_epilog_function(); + + auto launcher = planner.get_launcher(); + + dim3 block(Policy::Nthreads); + int smem_size = OpT::template shared_mem_size(); + dim3 grid = + launchConfigGenerator(params.m, params.n, smem_size, launcher->get_kernel()); + + launcher->dispatch>( + stream, grid, block, smem_size, distance_op, params); + }; + + dispatch_layout(params.is_row_major, vec_len, launch); +} + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in new file mode 100644 index 0000000000..976c108758 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in @@ -0,0 +1,77 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +#include + +#include <@header_file@> +#include +#include +#include +#include + +namespace { + +// clang-format off +using data_t = @data_type@; +using acc_t = @acc_type@; +using out_t = @out_type@; +using index_t = @index_type@; +using fin_op_t = raft::identity_op; +using op_t = @op_type@; +using policies_t = raft::linalg::Policy4x4; +using policy_t = typename policies_t::@policy_type@; +constexpr bool row_major = std::is_same_v; +using base_t = raft::linalg::Contractions_NT; +// clang-format on + +} // namespace + +namespace cuvs::distance::detail { + +extern "C" __global__ __launch_bounds__(policy_t::Nthreads, 2) void pairwise_matrix_kernel( + op_t distance_op, pairwise_matrix_params params) +{ + extern __shared__ char smem[]; + + auto epilog_op = raft::void_op(); + auto row_epilog_op = raft::void_op(); + + constexpr bool write_out = true; + constexpr bool use_jit_device_functions = true; + PairwiseDistances + obj(params.x, + params.y, + params.m, + params.n, + params.k, + params.ldx, + params.ldy, + params.ld_out, + params.x_norm, + params.y_norm, + params.out, + smem, + distance_op, + epilog_op, + params.fin_op, + row_epilog_op); + obj.run(); +} + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp new file mode 100644 index 0000000000..6d8191de75 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp @@ -0,0 +1,73 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +#include "registration_tags.hpp" + +#include +#include + +namespace cuvs::distance::detail { + +inline constexpr char kPairwiseMatrixJitEntrypoint[] = "pairwise_matrix_kernel"; + +struct PairwiseMatrixPlanner : AlgorithmPlanner { + inline static LauncherJitCache launcher_jit_cache{}; + + PairwiseMatrixPlanner() : AlgorithmPlanner(kPairwiseMatrixJitEntrypoint, launcher_jit_cache) {} + + explicit PairwiseMatrixPlanner(std::string entrypoint) + : AlgorithmPlanner(std::move(entrypoint), launcher_jit_cache) + { + } + + template + void add_entrypoint() + { + this->add_static_fragment>(); + } + + template + void add_compute_distance_function() + { + this->add_static_fragment< + fragment_tag_compute_distance>(); + } + + template + void add_compute_distance_epilog_function() + { + this->add_static_fragment>(); + } +}; + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in new file mode 100644 index 0000000000..c112594b85 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in @@ -0,0 +1,78 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +#include + +#include <@header_file@> +#include +#include +#include +#include +#include + +namespace { + +// clang-format off +using data_t = @data_type@; +using acc_t = @acc_type@; +using out_t = @out_type@; +using index_t = @index_type@; +using fin_op_t = cuvs::distance::kernels::rbf_fin_op; +using op_t = @op_type@; +using policies_t = raft::linalg::Policy4x4; +using policy_t = typename policies_t::@policy_type@; +constexpr bool row_major = std::is_same_v; +using base_t = raft::linalg::Contractions_NT; +// clang-format on + +} // namespace + +namespace cuvs::distance::detail { + +extern "C" __global__ __launch_bounds__(policy_t::Nthreads, 2) void pairwise_matrix_kernel( + op_t distance_op, pairwise_matrix_params params) +{ + extern __shared__ char smem[]; + + auto epilog_op = raft::void_op(); + auto row_epilog_op = raft::void_op(); + + constexpr bool write_out = true; + constexpr bool use_jit_device_functions = true; + PairwiseDistances + obj(params.x, + params.y, + params.m, + params.n, + params.k, + params.ldx, + params.ldy, + params.ld_out, + params.x_norm, + params.y_norm, + params.out, + smem, + distance_op, + epilog_op, + params.fin_op, + row_epilog_op); + obj.run(); +} + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/registration_tags.hpp b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/registration_tags.hpp new file mode 100644 index 0000000000..3e56bcc5b3 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/registration_tags.hpp @@ -0,0 +1,58 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +namespace cuvs::distance::detail { + +struct tag_f {}; +struct tag_h {}; +struct tag_d {}; + +struct tag_index_i {}; +struct tag_index_i64 {}; + +struct tag_layout_row {}; +struct tag_layout_col {}; + +struct tag_fin_op_identity {}; +struct tag_fin_op_rbf {}; + +struct tag_distance_canberra {}; +struct tag_distance_correlation {}; +struct tag_distance_cosine {}; +struct tag_distance_hamming_unexpanded {}; +struct tag_distance_hellinger_expanded {}; +struct tag_distance_jensen_shannon {}; +struct tag_distance_kl_divergence {}; +struct tag_distance_l1 {}; +struct tag_distance_l2_expanded {}; +struct tag_distance_l2_unexpanded {}; +struct tag_distance_l_inf {}; +struct tag_distance_lp_unexpanded {}; +struct tag_distance_russel_rao {}; + +template +struct fragment_tag_pairwise_matrix {}; + +template +struct fragment_tag_compute_distance {}; + +template +struct fragment_tag_compute_distance_epilog {}; + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/src/distance/detail/pairwise_matrix/kernel_sm60.cuh deleted file mode 100644 index e7c9facef6..0000000000 --- a/cpp/src/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ /dev/null @@ -1,144 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ -#pragma once - -#include "../pairwise_distance_base.cuh" // PairwiseDistances -#include "params.cuh" // pairwise_matrix_params -#include // raft::void_op -#include // raft::util::arch::SM_compute_arch - -#include // assert - -namespace cuvs::distance::detail { - -template -__launch_bounds__(Policy::Nthreads, 2) RAFT_KERNEL - pairwise_matrix_kernel(OpT distance_op, pairwise_matrix_params params) -{ - // Early exit to minimize the size of the kernel when it is not supposed to be compiled. - constexpr SM_compat_t sm_compat_range{}; - if constexpr (!sm_compat_range.contains(raft::util::arch::SM_compute_arch())) { - assert(false); - return; - } - - extern __shared__ char smem[]; - - // The epilog is already provided by distance_op. Do not provide additional - // epilogs. - auto epilog_op = raft::void_op(); - // No support for row_epilog_op. - auto row_epilog_op = raft::void_op(); - - // Always write output - constexpr bool write_out = true; - PairwiseDistances - obj(params.x, - params.y, - params.m, - params.n, - params.k, - params.ldx, - params.ldy, - params.ld_out, - params.x_norm, - params.y_norm, - params.out, - smem, - distance_op, - epilog_op, - params.fin_op, - row_epilog_op); - obj.run(); -} - -// The type of a pointer to the pairwise matrix kernel. The following template -// arguments are type-erased: -// -// - The kernel policy -// - row_major -// - SM_compat_t -template -using pairwise_matrix_kernel_t = void (*)(OpT, pairwise_matrix_params); - -// A wrapper for the pairwise matrix kernel launch. Includes kernel launch -// parameters. -template -struct pairwise_matrix_sm60_wrapper { - dim3 grid; - dim3 block; - int smem_size; - pairwise_matrix_kernel_t kernel_ptr; - - void launch(OpT distance_op, - pairwise_matrix_params params, - cudaStream_t stream) - { - kernel_ptr<<>>(distance_op, params); - RAFT_CUDA_TRY(cudaGetLastError()); - } -}; - -/** @brief: Create kernel launch wrapper for pairwise matrix kernel - * - * This can be used to type-erase the kernel execution policy, row_major, and SM - * compatibility range. - * - * @tparam Policy: Kernel execution policy - * @tparam row_major: Indicates whether input matrices are row major - * @tparam OpT: Type of distance operation - * @tparam IdxT: Index type - * @tparam DataT: Data type - * @tparam OutT: Output data type - * @tparam FinOpT: Final operation type - * @tparam SM_compat_t: Type of the SM architecture compatibility - * - * @param distance_op: Distance operation - * @param params: Parameters - * @param sm_compat_range: Which SM architectures to compile for. - */ -template -pairwise_matrix_sm60_wrapper make_pairwise_matrix_sm60_wrapper( - OpT distance_op, - pairwise_matrix_params params, - SM_compat_t sm_compat_range) -{ - dim3 block(Policy::Nthreads); - // Use ::template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - int smem_size = OpT::template shared_mem_size(); - // Obtain function pointer to kernel - auto kernel = - pairwise_matrix_kernel; - dim3 grid = launchConfigGenerator(params.m, params.n, smem_size, kernel); - - return pairwise_matrix_sm60_wrapper{ - grid, block, smem_size, kernel}; -} - -}; // namespace cuvs::distance::detail From cf6ccb074cd3ab5e16bccf8da9426b3d624ab2e3 Mon Sep 17 00:00:00 2001 From: Tarang Jain <40517122+tarang-jain@users.noreply.github.com> Date: Mon, 18 May 2026 14:31:25 -0700 Subject: [PATCH 02/18] Update cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_kernel.cu.in Co-authored-by: Kyle Edwards --- .../jit_lto_kernels/compute_distance_epilog_kernel.cu.in | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_kernel.cu.in b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_kernel.cu.in index d0b4d64645..19810611df 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_kernel.cu.in +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_kernel.cu.in @@ -13,13 +13,12 @@ namespace { -// clang-format off using data_t = @data_type@; using acc_t = @acc_type@; using index_t = @index_type@; using op_t = @op_type@; -using policy_t = typename raft::linalg::Policy4x4::@policy_type@; -// clang-format on +constexpr int veclen = @veclen@; +using policy_t = typename raft::linalg::Policy4x4::@policy_type@; } // namespace From f9ff9e83600a258f2a2e8c25e5ec7c8b7f8b6763 Mon Sep 17 00:00:00 2001 From: tarang-jain Date: Mon, 18 May 2026 17:51:06 -0700 Subject: [PATCH 03/18] update arch_include --- .../detail/pairwise_matrix/dispatch-inl.cuh | 79 +++++++++++++++++-- .../pairwise_matrix/dispatch_matrix.json | 6 +- 2 files changed, 77 insertions(+), 8 deletions(-) diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh index 9cc9f29710..6bb342733b 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh @@ -4,16 +4,62 @@ */ #pragma once -/* This file provides concise function templates that can be instantiated in - * src/distance/detail/pairwise_matrix/. Previously, - * cuvs::distance::detail::distance was instantiated. The function necessarily - * required a large set of include files, which slowed down the build. +/* This file has two responsibilities: + * + * 1. Dispatch to the correct implementation of a kernel based on the + * architecture of the device on which the kernel will be launched. For + * instance, the cosine distance has a CUTLASS-based implementation that can + * be used on SM80+ and the JIT implementation that is used on older + * architectures. + * + * 2. Provide concise function templates that can be instantiated in + * src/distance/detail/pairwise_matrix/. Previously, + * cuvs::distance::detail::distance was instantiated. The function + * necessarily required a large set of include files, which slowed down the + * build. */ +#include "../distance_ops/cutlass.cuh" // ops::has_cutlass_op #include "../pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh" // pairwise_matrix_jit_dispatch +#include // RAFT_CUDA_TRY +#include // raft::util::arch::SM_* + +// NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. +// Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). +// Therefore, it is the including file's responsibility to include +// dispatch_sm80.cuh for CUTLASS-backed distance ops, as is done in +// src/distance/detail/pairwise_matrix/dispatch_*.cu. namespace cuvs::distance::detail { +// This forward-declaration ensures that we do not need to include +// dispatch_sm80.cuh if we are not calling it in practice. This makes compiling +// all the non-CUTLASS based distance instantiations faster. For CUTLASS-based +// distances, dispatch_sm80.cuh has to be included by the file including this +// file. +template +void pairwise_matrix_sm80_dispatch(OpT, + pairwise_matrix_params, + SM_compat_t, + cudaStream_t); + +inline auto current_device_arch() +{ + int device = 0; + RAFT_CUDA_TRY(cudaGetDevice(&device)); + + int major = 0; + int minor = 0; + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); + return raft::util::arch::SM(major * 10 + minor); +} + template ::value; + + if constexpr (cutlass_op_unavailable) { + pairwise_matrix_jit_dispatch(distance_op, params, stream); + } else { + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + auto runtime_arch = current_device_arch(); + + // TODO: CUTLASS does not support odd `k` with half DataT. + bool unsupported_half = (sizeof(DataT) == 2) && ((k % 2) != 0); + + if (!unsupported_half && cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. + pairwise_matrix_sm80_dispatch(distance_op, params, cutlass_range, stream); + return; + } + + pairwise_matrix_jit_dispatch(distance_op, params, stream); + } } }; // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_matrix.json b/cpp/src/distance/detail/pairwise_matrix/dispatch_matrix.json index ebafa5c245..bf6be6bed6 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_matrix.json +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_matrix.json @@ -40,7 +40,7 @@ { "op_type": "cuvs::distance::detail::ops::cosine_distance_op", "op_abbrev": "cosine", - "arch_includes": "" + "arch_includes": "#include " }, { "op_type": "cuvs::distance::detail::ops::hamming_distance_op", @@ -70,7 +70,7 @@ { "op_type": "cuvs::distance::detail::ops::l2_exp_distance_op", "op_abbrev": "l2_expanded", - "arch_includes": "" + "arch_includes": "#include " }, { "op_type": "cuvs::distance::detail::ops::l2_unexp_distance_op", @@ -119,7 +119,7 @@ { "op_type": "cuvs::distance::detail::ops::l2_exp_distance_op", "op_abbrev": "l2_expanded", - "arch_includes": "" + "arch_includes": "#include " } ], "index_type": "int64_t", From 7dce6e3074b314393ba7a6d841de98ae0dd03628 Mon Sep 17 00:00:00 2001 From: tarang-jain Date: Mon, 18 May 2026 18:05:50 -0700 Subject: [PATCH 04/18] correct tags --- cpp/CMakeLists.txt | 12 ++++++------ cpp/include/cuvs/detail/jit_lto/common_fragments.hpp | 2 ++ .../pairwise_matrix/pairwise_matrix_fragments.hpp} | 6 ------ .../jit_lto_kernels/pairwise_matrix_planner.hpp | 3 +-- 4 files changed, 9 insertions(+), 14 deletions(-) rename cpp/{src/distance/detail/pairwise_matrix/jit_lto_kernels/registration_tags.hpp => include/cuvs/detail/jit_lto/pairwise_matrix/pairwise_matrix_fragments.hpp} (93%) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index e67d569115..3aa700ccbe 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -781,7 +781,7 @@ if(NOT BUILD_CPU_ONLY) FRAGMENT_TAG_FORMAT "${distance_ns}::fragment_tag_pairwise_matrix<${distance_ns}::tag_distance_@distance_abbrev@, ${distance_ns}::tag_@data_abbrev@, ${distance_ns}::tag_@acc_abbrev@, ${distance_ns}::tag_@out_abbrev@, ${distance_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_fin_op_identity, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" FRAGMENT_TAG_HEADER_FILES - "" + "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/kernel" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements ) @@ -794,7 +794,7 @@ if(NOT BUILD_CPU_ONLY) FRAGMENT_TAG_FORMAT "${distance_ns}::fragment_tag_pairwise_matrix<${distance_ns}::tag_distance_@distance_abbrev@, ${distance_ns}::tag_@data_abbrev@, ${distance_ns}::tag_@acc_abbrev@, ${distance_ns}::tag_@out_abbrev@, ${distance_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_fin_op_rbf, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" FRAGMENT_TAG_HEADER_FILES - "" + "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/kernel_rbf" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements @@ -808,7 +808,7 @@ if(NOT BUILD_CPU_ONLY) FRAGMENT_TAG_FORMAT "${distance_ns}::fragment_tag_compute_distance<${distance_ns}::tag_distance_@distance_abbrev@, ${distance_ns}::tag_@data_abbrev@, ${distance_ns}::tag_@acc_abbrev@, ${distance_ns}::tag_index_@index_abbrev@>" FRAGMENT_TAG_HEADER_FILES - "" + "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements @@ -822,7 +822,7 @@ if(NOT BUILD_CPU_ONLY) FRAGMENT_TAG_FORMAT "${distance_ns}::fragment_tag_compute_distance<${distance_ns}::tag_distance_@distance_abbrev@, ${distance_ns}::tag_@data_abbrev@, ${distance_ns}::tag_@acc_abbrev@, ${distance_ns}::tag_index_@index_abbrev@>" FRAGMENT_TAG_HEADER_FILES - "" + "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance_rbf" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements @@ -836,7 +836,7 @@ if(NOT BUILD_CPU_ONLY) FRAGMENT_TAG_FORMAT "${distance_ns}::fragment_tag_compute_distance_epilog<${distance_ns}::tag_distance_@distance_abbrev@, ${distance_ns}::tag_@data_abbrev@, ${distance_ns}::tag_@acc_abbrev@, ${distance_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" FRAGMENT_TAG_HEADER_FILES - "" + "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance_epilog" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements @@ -850,7 +850,7 @@ if(NOT BUILD_CPU_ONLY) FRAGMENT_TAG_FORMAT "${distance_ns}::fragment_tag_compute_distance_epilog<${distance_ns}::tag_distance_@distance_abbrev@, ${distance_ns}::tag_@data_abbrev@, ${distance_ns}::tag_@acc_abbrev@, ${distance_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" FRAGMENT_TAG_HEADER_FILES - "" + "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance_epilog_rbf" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements diff --git a/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp index cb33e4109b..c3b66d7c06 100644 --- a/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp +++ b/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp @@ -9,6 +9,7 @@ namespace cuvs::neighbors::detail { struct tag_f {}; struct tag_h {}; +struct tag_d {}; struct tag_i8 {}; struct tag_u8 {}; struct tag_filter_none {}; @@ -16,6 +17,7 @@ struct tag_filter_bitset {}; struct tag_bitset_u32 {}; +struct tag_index_i {}; struct tag_index_u32 {}; struct tag_index_i64 {}; diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/registration_tags.hpp b/cpp/include/cuvs/detail/jit_lto/pairwise_matrix/pairwise_matrix_fragments.hpp similarity index 93% rename from cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/registration_tags.hpp rename to cpp/include/cuvs/detail/jit_lto/pairwise_matrix/pairwise_matrix_fragments.hpp index 3e56bcc5b3..c52fd3e4ef 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/registration_tags.hpp +++ b/cpp/include/cuvs/detail/jit_lto/pairwise_matrix/pairwise_matrix_fragments.hpp @@ -7,12 +7,6 @@ namespace cuvs::distance::detail { -struct tag_f {}; -struct tag_h {}; -struct tag_d {}; - -struct tag_index_i {}; -struct tag_index_i64 {}; struct tag_layout_row {}; struct tag_layout_col {}; diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp index 6d8191de75..037fe9e8e8 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp @@ -6,8 +6,7 @@ #pragma once #include - -#include "registration_tags.hpp" +#include #include #include From e1f6a5c6e8e5f74ed0b990e54e9268c53eb212eb Mon Sep 17 00:00:00 2001 From: tarang-jain Date: Mon, 18 May 2026 18:25:20 -0700 Subject: [PATCH 05/18] update template --- .../jit_lto_kernels/pairwise_matrix_jit.cuh | 28 ++++++------- .../pairwise_matrix_planner.hpp | 41 ++++++++----------- 2 files changed, 30 insertions(+), 39 deletions(-) diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh index 5941e7f247..e522049916 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh @@ -179,22 +179,18 @@ void pairwise_matrix_jit_dispatch(OpT distance_op, using FinOpTag = pairwise_fin_op_tag_t; using LayoutTag = pairwise_layout_tag_t; - PairwiseMatrixPlanner planner; - planner.add_entrypoint(); - planner.add_compute_distance_function(); - planner.add_compute_distance_epilog_function(); + PairwiseMatrixPlanner + planner; + planner.add_entrypoint(); + planner.add_compute_distance_function(); + planner.add_compute_distance_epilog_function(); auto launcher = planner.get_launcher(); diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp index 037fe9e8e8..0d00b3eca6 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp @@ -8,31 +8,33 @@ #include #include -#include -#include - namespace cuvs::distance::detail { inline constexpr char kPairwiseMatrixJitEntrypoint[] = "pairwise_matrix_kernel"; +template struct PairwiseMatrixPlanner : AlgorithmPlanner { + using DistanceTag = DistanceTag_; + using DataTag = DataTag_; + using AccTag = AccTag_; + using OutTag = OutTag_; + using IndexTag = IndexTag_; + using FinOpTag = FinOpTag_; + using LayoutTag = LayoutTag_; + + static constexpr int Veclen = Veclen_; + inline static LauncherJitCache launcher_jit_cache{}; PairwiseMatrixPlanner() : AlgorithmPlanner(kPairwiseMatrixJitEntrypoint, launcher_jit_cache) {} - explicit PairwiseMatrixPlanner(std::string entrypoint) - : AlgorithmPlanner(std::move(entrypoint), launcher_jit_cache) - { - } - - template void add_entrypoint() { this->add_static_fragment>(); } - template void add_compute_distance_function() { this->add_static_fragment< fragment_tag_compute_distance>(); } - template void add_compute_distance_epilog_function() { this->add_static_fragment Date: Mon, 18 May 2026 18:50:06 -0700 Subject: [PATCH 06/18] reapply cmake comment --- cpp/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 3aa700ccbe..c326262cf2 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -857,6 +857,9 @@ if(NOT BUILD_CPU_ONLY) ) endblock() + # Note that this matrix contains an `arch_includes` placeholder, since we don't currently have a + # way to do an item-wise transform on a list after computing the matrix product and before + # configuring the file generate_inst_matrix( pairwise_matrix_dispatch_inst_files MATRIX_JSON_FILE From 8fcb1bbe6b75771b19b44289675eb29bdcc107f1 Mon Sep 17 00:00:00 2001 From: tarang-jain Date: Mon, 18 May 2026 19:11:40 -0700 Subject: [PATCH 07/18] rm jit boolean --- .../detail/pairwise_distance_base.cuh | 27 +++++-------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/cpp/src/distance/detail/pairwise_distance_base.cuh b/cpp/src/distance/detail/pairwise_distance_base.cuh index 4300901179..b3340c393a 100644 --- a/cpp/src/distance/detail/pairwise_distance_base.cuh +++ b/cpp/src/distance/detail/pairwise_distance_base.cuh @@ -55,8 +55,7 @@ template , - bool useJitDeviceFunctions = false> + typename BaseClass = raft::linalg::Contractions_NT> struct PairwiseDistances : public BaseClass { // Get accumulation type from distance_op using AccT = typename OpT::AccT; @@ -152,12 +151,8 @@ struct PairwiseDistances : public BaseClass { // Calculate distance_op epilog. // Use .template to disambiguate (See: // https://en.cppreference.com/w/cpp/language/dependent_name) - if constexpr (useJitDeviceFunctions) { - compute_distance_epilog( - distance_op, acc, regxn, regyn, tile_idx_n, tile_idx_m); - } else { - distance_op.template epilog(acc, regxn, regyn, tile_idx_n, tile_idx_m); - } + compute_distance_epilog( + distance_op, acc, regxn, regyn, tile_idx_n, tile_idx_m); // And any possible additional epilogs epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); } else { @@ -166,12 +161,8 @@ struct PairwiseDistances : public BaseClass { // Calculate distance_op epilog. // Use .template to disambiguate (See: // https://en.cppreference.com/w/cpp/language/dependent_name) - if constexpr (useJitDeviceFunctions) { - compute_distance_epilog( - distance_op, acc, nullptr, nullptr, tile_idx_n, tile_idx_m); - } else { - distance_op.template epilog(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); - } + compute_distance_epilog( + distance_op, acc, nullptr, nullptr, tile_idx_n, tile_idx_m); // And any possible additional epilogs epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); } @@ -215,12 +206,8 @@ struct PairwiseDistances : public BaseClass { for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - if constexpr (useJitDeviceFunctions) { - compute_distance( - distance_op, acc[i][j], reg_x[i][v], reg_y[j][v]); - } else { - distance_op.core(acc[i][j], reg_x[i][v], reg_y[j][v]); - } + compute_distance( + distance_op, acc[i][j], reg_x[i][v], reg_y[j][v]); } } } From 4ae75070919674ce6ba03433e7e9334e2083d331 Mon Sep 17 00:00:00 2001 From: tarang-jain Date: Tue, 19 May 2026 11:43:38 -0700 Subject: [PATCH 08/18] address pr reviews --- cpp/CMakeLists.txt | 18 ++++++++++++------ .../cuvs/detail/jit_lto/common_fragments.hpp | 2 +- cpp/src/distance/detail/distance_ops/l_inf.cuh | 1 + .../jit_lto_kernels/pairwise_matrix_jit.cuh | 11 ++++++----- .../pairwise_matrix_kernel.cu.in | 9 +++------ .../pairwise_matrix_rbf_kernel.cu.in | 9 +++------ 6 files changed, 26 insertions(+), 24 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index c326262cf2..9df3f4294a 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -779,9 +779,10 @@ if(NOT BUILD_CPU_ONLY) MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_matrix.json" KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/pairwise_matrix_kernel.cu.in" FRAGMENT_TAG_FORMAT - "${distance_ns}::fragment_tag_pairwise_matrix<${distance_ns}::tag_distance_@distance_abbrev@, ${distance_ns}::tag_@data_abbrev@, ${distance_ns}::tag_@acc_abbrev@, ${distance_ns}::tag_@out_abbrev@, ${distance_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_fin_op_identity, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" + "${distance_ns}::fragment_tag_pairwise_matrix<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_@out_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_fin_op_identity, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" FRAGMENT_TAG_HEADER_FILES "" + "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/kernel" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements ) @@ -792,9 +793,10 @@ if(NOT BUILD_CPU_ONLY) MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_rbf_matrix.json" KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/pairwise_matrix_rbf_kernel.cu.in" FRAGMENT_TAG_FORMAT - "${distance_ns}::fragment_tag_pairwise_matrix<${distance_ns}::tag_distance_@distance_abbrev@, ${distance_ns}::tag_@data_abbrev@, ${distance_ns}::tag_@acc_abbrev@, ${distance_ns}::tag_@out_abbrev@, ${distance_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_fin_op_rbf, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" + "${distance_ns}::fragment_tag_pairwise_matrix<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_@out_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_fin_op_rbf, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" FRAGMENT_TAG_HEADER_FILES "" + "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/kernel_rbf" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements @@ -806,9 +808,10 @@ if(NOT BUILD_CPU_ONLY) MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_matrix.json" KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_kernel.cu.in" FRAGMENT_TAG_FORMAT - "${distance_ns}::fragment_tag_compute_distance<${distance_ns}::tag_distance_@distance_abbrev@, ${distance_ns}::tag_@data_abbrev@, ${distance_ns}::tag_@acc_abbrev@, ${distance_ns}::tag_index_@index_abbrev@>" + "${distance_ns}::fragment_tag_compute_distance<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@>" FRAGMENT_TAG_HEADER_FILES "" + "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements @@ -820,9 +823,10 @@ if(NOT BUILD_CPU_ONLY) MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_rbf_matrix.json" KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_kernel.cu.in" FRAGMENT_TAG_FORMAT - "${distance_ns}::fragment_tag_compute_distance<${distance_ns}::tag_distance_@distance_abbrev@, ${distance_ns}::tag_@data_abbrev@, ${distance_ns}::tag_@acc_abbrev@, ${distance_ns}::tag_index_@index_abbrev@>" + "${distance_ns}::fragment_tag_compute_distance<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@>" FRAGMENT_TAG_HEADER_FILES "" + "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance_rbf" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements @@ -834,9 +838,10 @@ if(NOT BUILD_CPU_ONLY) MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_matrix.json" KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_kernel.cu.in" FRAGMENT_TAG_FORMAT - "${distance_ns}::fragment_tag_compute_distance_epilog<${distance_ns}::tag_distance_@distance_abbrev@, ${distance_ns}::tag_@data_abbrev@, ${distance_ns}::tag_@acc_abbrev@, ${distance_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" + "${distance_ns}::fragment_tag_compute_distance_epilog<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" FRAGMENT_TAG_HEADER_FILES "" + "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance_epilog" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements @@ -848,9 +853,10 @@ if(NOT BUILD_CPU_ONLY) MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_rbf_matrix.json" KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_kernel.cu.in" FRAGMENT_TAG_FORMAT - "${distance_ns}::fragment_tag_compute_distance_epilog<${distance_ns}::tag_distance_@distance_abbrev@, ${distance_ns}::tag_@data_abbrev@, ${distance_ns}::tag_@acc_abbrev@, ${distance_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" + "${distance_ns}::fragment_tag_compute_distance_epilog<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" FRAGMENT_TAG_HEADER_FILES "" + "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance_epilog_rbf" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements diff --git a/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp index c3b66d7c06..cbd3f72730 100644 --- a/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp +++ b/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp @@ -17,7 +17,7 @@ struct tag_filter_bitset {}; struct tag_bitset_u32 {}; -struct tag_index_i {}; +struct tag_index_i32 {}; struct tag_index_u32 {}; struct tag_index_i64 {}; diff --git a/cpp/src/distance/detail/distance_ops/l_inf.cuh b/cpp/src/distance/detail/distance_ops/l_inf.cuh index 2b2f09e8d6..d78abcf524 100644 --- a/cpp/src/distance/detail/distance_ops/l_inf.cuh +++ b/cpp/src/distance/detail/distance_ops/l_inf.cuh @@ -5,6 +5,7 @@ #pragma once +#include // raft::abs, raft::max #include // DI namespace cuvs::distance::detail::ops { diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh index e522049916..5ec296e90a 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh @@ -7,6 +7,7 @@ #include "pairwise_matrix_planner.hpp" +#include #include #include #include @@ -31,11 +32,11 @@ template constexpr auto get_pairwise_scalar_type_tag() { if constexpr (std::is_same_v) { - return tag_f{}; + return cuvs::neighbors::detail::tag_f{}; } else if constexpr (std::is_same_v) { - return tag_d{}; + return cuvs::neighbors::detail::tag_d{}; } else if constexpr (std::is_same_v || std::is_same_v) { - return tag_h{}; + return cuvs::neighbors::detail::tag_h{}; } else { static_assert(pairwise_matrix_jit_always_false_v, "Pairwise matrix JIT LTO does not have a scalar tag for this type"); @@ -46,9 +47,9 @@ template constexpr auto get_pairwise_index_type_tag() { if constexpr (std::is_same_v) { - return tag_index_i{}; + return cuvs::neighbors::detail::tag_index_i32{}; } else if constexpr (std::is_same_v) { - return tag_index_i64{}; + return cuvs::neighbors::detail::tag_index_i64{}; } else { static_assert(pairwise_matrix_jit_always_false_v, "Pairwise matrix JIT LTO does not have an index tag for this type"); diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in index 976c108758..af2b442e06 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in @@ -16,18 +16,17 @@ namespace { -// clang-format off using data_t = @data_type@; using acc_t = @acc_type@; using out_t = @out_type@; using index_t = @index_type@; using fin_op_t = raft::identity_op; using op_t = @op_type@; -using policies_t = raft::linalg::Policy4x4; +constexpr int veclen = @veclen@; +using policies_t = raft::linalg::Policy4x4; using policy_t = typename policies_t::@policy_type@; constexpr bool row_major = std::is_same_v; using base_t = raft::linalg::Contractions_NT; -// clang-format on } // namespace @@ -42,7 +41,6 @@ extern "C" __global__ __launch_bounds__(policy_t::Nthreads, 2) void pairwise_mat auto row_epilog_op = raft::void_op(); constexpr bool write_out = true; - constexpr bool use_jit_device_functions = true; PairwiseDistances + base_t> obj(params.x, params.y, params.m, diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in index c112594b85..fb33f842d9 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in @@ -17,18 +17,17 @@ namespace { -// clang-format off using data_t = @data_type@; using acc_t = @acc_type@; using out_t = @out_type@; using index_t = @index_type@; using fin_op_t = cuvs::distance::kernels::rbf_fin_op; using op_t = @op_type@; -using policies_t = raft::linalg::Policy4x4; +constexpr int veclen = @veclen@; +using policies_t = raft::linalg::Policy4x4; using policy_t = typename policies_t::@policy_type@; constexpr bool row_major = std::is_same_v; using base_t = raft::linalg::Contractions_NT; -// clang-format on } // namespace @@ -43,7 +42,6 @@ extern "C" __global__ __launch_bounds__(policy_t::Nthreads, 2) void pairwise_mat auto row_epilog_op = raft::void_op(); constexpr bool write_out = true; - constexpr bool use_jit_device_functions = true; PairwiseDistances + base_t> obj(params.x, params.y, params.m, From e465980cbf1a18a27482b82d19cfe19b9cf34db6 Mon Sep 17 00:00:00 2001 From: tarang-jain Date: Tue, 19 May 2026 11:51:41 -0700 Subject: [PATCH 09/18] update index type in json --- .../jit_lto_kernels/compute_distance_epilog_matrix.json | 2 +- .../jit_lto_kernels/compute_distance_matrix.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_matrix.json b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_matrix.json index afe37ce712..dfd88f39cd 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_matrix.json +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_matrix.json @@ -195,7 +195,7 @@ "_index": [ { "index_type": "int", - "index_abbrev": "i" + "index_abbrev": "i32" } ] }, diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_matrix.json b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_matrix.json index 2dbdd12efa..0c3b0c4164 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_matrix.json +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_matrix.json @@ -112,7 +112,7 @@ "_index": [ { "index_type": "int", - "index_abbrev": "i" + "index_abbrev": "i32" } ] }, From 7df2f95924b7fc016bd923a8c37cb5712aed1c49 Mon Sep 17 00:00:00 2001 From: tarang-jain Date: Tue, 19 May 2026 11:55:43 -0700 Subject: [PATCH 10/18] do not switch off clang --- .../jit_lto_kernels/compute_distance_kernel.cu.in | 2 -- 1 file changed, 2 deletions(-) diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_kernel.cu.in b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_kernel.cu.in index 3a79c66480..4d39801d40 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_kernel.cu.in +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_kernel.cu.in @@ -12,12 +12,10 @@ namespace { -// clang-format off using data_t = @data_type@; using acc_t = @acc_type@; using index_t = @index_type@; using op_t = @op_type@; -// clang-format on } // namespace From 15bb587cf27360f553eed434b6c3cbb018063646 Mon Sep 17 00:00:00 2001 From: tarang-jain Date: Tue, 19 May 2026 12:38:26 -0700 Subject: [PATCH 11/18] style check and header includes --- cpp/CMakeLists.txt | 30 ++++++++----------- .../pairwise_matrix_fragments.hpp | 1 - cpp/src/distance/detail/distance_ops/l1.cuh | 3 +- .../distance/detail/distance_ops/l_inf.cuh | 2 +- .../detail/pairwise_distance_base.cuh | 2 +- .../detail/pairwise_matrix/dispatch-inl.cuh | 8 ++--- .../compute_distance_epilog_kernel.cu.in | 10 +++---- .../pairwise_matrix_kernel.cu.in | 20 ++++++------- .../pairwise_matrix_rbf_kernel.cu.in | 22 +++++++------- 9 files changed, 46 insertions(+), 52 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 9df3f4294a..2e41f684a0 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -780,9 +780,8 @@ if(NOT BUILD_CPU_ONLY) KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/pairwise_matrix_kernel.cu.in" FRAGMENT_TAG_FORMAT "${distance_ns}::fragment_tag_pairwise_matrix<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_@out_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_fin_op_identity, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" - FRAGMENT_TAG_HEADER_FILES - "" - "" + FRAGMENT_TAG_HEADER_FILES "" + "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/kernel" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements ) @@ -794,9 +793,8 @@ if(NOT BUILD_CPU_ONLY) KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/pairwise_matrix_rbf_kernel.cu.in" FRAGMENT_TAG_FORMAT "${distance_ns}::fragment_tag_pairwise_matrix<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_@out_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_fin_op_rbf, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" - FRAGMENT_TAG_HEADER_FILES - "" - "" + FRAGMENT_TAG_HEADER_FILES "" + "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/kernel_rbf" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements @@ -809,9 +807,8 @@ if(NOT BUILD_CPU_ONLY) KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_kernel.cu.in" FRAGMENT_TAG_FORMAT "${distance_ns}::fragment_tag_compute_distance<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@>" - FRAGMENT_TAG_HEADER_FILES - "" - "" + FRAGMENT_TAG_HEADER_FILES "" + "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements @@ -824,9 +821,8 @@ if(NOT BUILD_CPU_ONLY) KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_kernel.cu.in" FRAGMENT_TAG_FORMAT "${distance_ns}::fragment_tag_compute_distance<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@>" - FRAGMENT_TAG_HEADER_FILES - "" - "" + FRAGMENT_TAG_HEADER_FILES "" + "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance_rbf" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements @@ -839,9 +835,8 @@ if(NOT BUILD_CPU_ONLY) KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_kernel.cu.in" FRAGMENT_TAG_FORMAT "${distance_ns}::fragment_tag_compute_distance_epilog<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" - FRAGMENT_TAG_HEADER_FILES - "" - "" + FRAGMENT_TAG_HEADER_FILES "" + "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance_epilog" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements @@ -854,9 +849,8 @@ if(NOT BUILD_CPU_ONLY) KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_kernel.cu.in" FRAGMENT_TAG_FORMAT "${distance_ns}::fragment_tag_compute_distance_epilog<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" - FRAGMENT_TAG_HEADER_FILES - "" - "" + FRAGMENT_TAG_HEADER_FILES "" + "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance_epilog_rbf" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements diff --git a/cpp/include/cuvs/detail/jit_lto/pairwise_matrix/pairwise_matrix_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/pairwise_matrix/pairwise_matrix_fragments.hpp index c52fd3e4ef..d9b7f4ec8d 100644 --- a/cpp/include/cuvs/detail/jit_lto/pairwise_matrix/pairwise_matrix_fragments.hpp +++ b/cpp/include/cuvs/detail/jit_lto/pairwise_matrix/pairwise_matrix_fragments.hpp @@ -7,7 +7,6 @@ namespace cuvs::distance::detail { - struct tag_layout_row {}; struct tag_layout_col {}; diff --git a/cpp/src/distance/detail/distance_ops/l1.cuh b/cpp/src/distance/detail/distance_ops/l1.cuh index 7c0cbe0b6c..09b4d7b214 100644 --- a/cpp/src/distance/detail/distance_ops/l1.cuh +++ b/cpp/src/distance/detail/distance_ops/l1.cuh @@ -1,9 +1,10 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once +#include // raft::abs #include // DI namespace cuvs::distance::detail::ops { diff --git a/cpp/src/distance/detail/distance_ops/l_inf.cuh b/cpp/src/distance/detail/distance_ops/l_inf.cuh index d78abcf524..03982183af 100644 --- a/cpp/src/distance/detail/distance_ops/l_inf.cuh +++ b/cpp/src/distance/detail/distance_ops/l_inf.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ diff --git a/cpp/src/distance/detail/pairwise_distance_base.cuh b/cpp/src/distance/detail/pairwise_distance_base.cuh index b3340c393a..f0b3c5a452 100644 --- a/cpp/src/distance/detail/pairwise_distance_base.cuh +++ b/cpp/src/distance/detail/pairwise_distance_base.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh index 6bb342733b..02350087db 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -19,10 +19,10 @@ * build. */ -#include "../distance_ops/cutlass.cuh" // ops::has_cutlass_op +#include "../distance_ops/cutlass.cuh" // ops::has_cutlass_op #include "../pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh" // pairwise_matrix_jit_dispatch -#include // RAFT_CUDA_TRY -#include // raft::util::arch::SM_* +#include // RAFT_CUDA_TRY +#include // raft::util::arch::SM_* // NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. // Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_kernel.cu.in b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_kernel.cu.in index 19810611df..ffb3dc8c84 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_kernel.cu.in +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_kernel.cu.in @@ -13,12 +13,12 @@ namespace { -using data_t = @data_type@; -using acc_t = @acc_type@; -using index_t = @index_type@; -using op_t = @op_type@; +using data_t = @data_type@; +using acc_t = @acc_type@; +using index_t = @index_type@; +using op_t = @op_type@; constexpr int veclen = @veclen@; -using policy_t = typename raft::linalg::Policy4x4::@policy_type@; +using policy_t = typename raft::linalg::Policy4x4::@policy_type@; } // namespace diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in index af2b442e06..cf60efce73 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in @@ -16,17 +16,17 @@ namespace { -using data_t = @data_type@; -using acc_t = @acc_type@; -using out_t = @out_type@; -using index_t = @index_type@; -using fin_op_t = raft::identity_op; -using op_t = @op_type@; -constexpr int veclen = @veclen@; -using policies_t = raft::linalg::Policy4x4; -using policy_t = typename policies_t::@policy_type@; +using data_t = @data_type@; +using acc_t = @acc_type@; +using out_t = @out_type@; +using index_t = @index_type@; +using fin_op_t = raft::identity_op; +using op_t = @op_type@; +constexpr int veclen = @veclen@; +using policies_t = raft::linalg::Policy4x4; +using policy_t = typename policies_t::@policy_type@; constexpr bool row_major = std::is_same_v; -using base_t = raft::linalg::Contractions_NT; +using base_t = raft::linalg::Contractions_NT; } // namespace diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in index fb33f842d9..8b2f532c23 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in @@ -17,17 +17,17 @@ namespace { -using data_t = @data_type@; -using acc_t = @acc_type@; -using out_t = @out_type@; -using index_t = @index_type@; -using fin_op_t = cuvs::distance::kernels::rbf_fin_op; -using op_t = @op_type@; -constexpr int veclen = @veclen@; -using policies_t = raft::linalg::Policy4x4; -using policy_t = typename policies_t::@policy_type@; +using data_t = @data_type@; +using acc_t = @acc_type@; +using out_t = @out_type@; +using index_t = @index_type@; +using fin_op_t = cuvs::distance::kernels::rbf_fin_op; +using op_t = @op_type@; +constexpr int veclen = @veclen@; +using policies_t = raft::linalg::Policy4x4; +using policy_t = typename policies_t::@policy_type@; constexpr bool row_major = std::is_same_v; -using base_t = raft::linalg::Contractions_NT; +using base_t = raft::linalg::Contractions_NT; } // namespace @@ -41,7 +41,7 @@ extern "C" __global__ __launch_bounds__(policy_t::Nthreads, 2) void pairwise_mat auto epilog_op = raft::void_op(); auto row_epilog_op = raft::void_op(); - constexpr bool write_out = true; + constexpr bool write_out = true; PairwiseDistances Date: Tue, 19 May 2026 12:39:26 -0700 Subject: [PATCH 12/18] style check and header includes --- cpp/src/distance/detail/distance_ops/l2_unexp.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/distance/detail/distance_ops/l2_unexp.cuh b/cpp/src/distance/detail/distance_ops/l2_unexp.cuh index c7094cc1b8..49b8ad23d0 100644 --- a/cpp/src/distance/detail/distance_ops/l2_unexp.cuh +++ b/cpp/src/distance/detail/distance_ops/l2_unexp.cuh @@ -1,10 +1,11 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once +#include // raft::sqrt #include // DI #include From 5bb0293da47b49d5e7647782b4a4af7465e95217 Mon Sep 17 00:00:00 2001 From: tarang-jain Date: Tue, 19 May 2026 14:19:17 -0700 Subject: [PATCH 13/18] compilation errors --- .../pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh | 2 +- cpp/src/distance/detail/sparse/l2_distance.cuh | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh index 5ec296e90a..17526e75b7 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh @@ -200,7 +200,7 @@ void pairwise_matrix_jit_dispatch(OpT distance_op, dim3 grid = launchConfigGenerator(params.m, params.n, smem_size, launcher->get_kernel()); - launcher->dispatch>( + launcher->template dispatch>( stream, grid, block, smem_size, distance_op, params); }; diff --git a/cpp/src/distance/detail/sparse/l2_distance.cuh b/cpp/src/distance/detail/sparse/l2_distance.cuh index 89d7c978a4..3b06dcb895 100644 --- a/cpp/src/distance/detail/sparse/l2_distance.cuh +++ b/cpp/src/distance/detail/sparse/l2_distance.cuh @@ -9,6 +9,7 @@ #include "ip_distance.cuh" #include +#include // raft::sqrt #include #include #include From 5da125a5173062f4746cbe4b06ab25c7a2b520d5 Mon Sep 17 00:00:00 2001 From: tarang-jain Date: Tue, 19 May 2026 14:28:27 -0700 Subject: [PATCH 14/18] fix header includes --- cpp/src/distance/detail/distance_ops/correlation.cuh | 1 + cpp/src/distance/detail/distance_ops/hellinger.cuh | 1 + 2 files changed, 2 insertions(+) diff --git a/cpp/src/distance/detail/distance_ops/correlation.cuh b/cpp/src/distance/detail/distance_ops/correlation.cuh index 39cfe4b8d2..a1fee4f792 100644 --- a/cpp/src/distance/detail/distance_ops/correlation.cuh +++ b/cpp/src/distance/detail/distance_ops/correlation.cuh @@ -5,6 +5,7 @@ #pragma once +#include // raft::sqrt #include // DI #include diff --git a/cpp/src/distance/detail/distance_ops/hellinger.cuh b/cpp/src/distance/detail/distance_ops/hellinger.cuh index 44e3e83375..83b0389d30 100644 --- a/cpp/src/distance/detail/distance_ops/hellinger.cuh +++ b/cpp/src/distance/detail/distance_ops/hellinger.cuh @@ -4,6 +4,7 @@ */ #pragma once +#include // raft::sqrt #include // DI namespace cuvs::distance::detail::ops { From d5e942df8cfd253779c51a47e4c7d2d58b8c48c8 Mon Sep 17 00:00:00 2001 From: tarang-jain Date: Tue, 19 May 2026 15:52:38 -0700 Subject: [PATCH 15/18] restore old version of dispatch-inl --- .../detail/pairwise_matrix/dispatch-inl.cuh | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh index 02350087db..ce00a809c5 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh @@ -21,7 +21,6 @@ #include "../distance_ops/cutlass.cuh" // ops::has_cutlass_op #include "../pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh" // pairwise_matrix_jit_dispatch -#include // RAFT_CUDA_TRY #include // raft::util::arch::SM_* // NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. @@ -48,16 +47,14 @@ void pairwise_matrix_sm80_dispatch(OpT, SM_compat_t, cudaStream_t); -inline auto current_device_arch() +template +__global__ void pairwise_matrix_arch_probe_kernel() { - int device = 0; - RAFT_CUDA_TRY(cudaGetDevice(&device)); - - int major = 0; - int minor = 0; - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); - return raft::util::arch::SM(major * 10 + minor); } template ; + void* kernel_ptr = reinterpret_cast(kernel); + auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); // TODO: CUTLASS does not support odd `k` with half DataT. bool unsupported_half = (sizeof(DataT) == 2) && ((k % 2) != 0); From 60d3344b50eb70fd935c7ff9ba90ebf974d75158 Mon Sep 17 00:00:00 2001 From: tarang-jain Date: Tue, 19 May 2026 15:58:06 -0700 Subject: [PATCH 16/18] style and docs --- cpp/src/distance/detail/distance_ops/correlation.cuh | 2 +- cpp/src/distance/detail/distance_ops/hellinger.cuh | 2 +- cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/cpp/src/distance/detail/distance_ops/correlation.cuh b/cpp/src/distance/detail/distance_ops/correlation.cuh index a1fee4f792..7276845ddf 100644 --- a/cpp/src/distance/detail/distance_ops/correlation.cuh +++ b/cpp/src/distance/detail/distance_ops/correlation.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ diff --git a/cpp/src/distance/detail/distance_ops/hellinger.cuh b/cpp/src/distance/detail/distance_ops/hellinger.cuh index 83b0389d30..35d8034c30 100644 --- a/cpp/src/distance/detail/distance_ops/hellinger.cuh +++ b/cpp/src/distance/detail/distance_ops/hellinger.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh index ce00a809c5..aff78d87f9 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh @@ -47,6 +47,9 @@ void pairwise_matrix_sm80_dispatch(OpT, SM_compat_t, cudaStream_t); +// This kernel is never launched. It only gives arch::kernel_virtual_arch a static kernel pointer +// from this TU/fatbin, without forcing JIT compilation just to decide whether the CUTLASS path is +// usable. template Date: Tue, 19 May 2026 19:46:06 -0700 Subject: [PATCH 17/18] add the ifdef --- cpp/src/distance/detail/pairwise_distance_base.cuh | 14 ++++++++++++++ .../jit_lto_kernels/pairwise_matrix_kernel.cu.in | 2 ++ .../pairwise_matrix_rbf_kernel.cu.in | 2 ++ 3 files changed, 18 insertions(+) diff --git a/cpp/src/distance/detail/pairwise_distance_base.cuh b/cpp/src/distance/detail/pairwise_distance_base.cuh index f0b3c5a452..a6fea8a017 100644 --- a/cpp/src/distance/detail/pairwise_distance_base.cuh +++ b/cpp/src/distance/detail/pairwise_distance_base.cuh @@ -3,7 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ #pragma once +#ifdef CUVS_DISTANCE_PAIRWISE_USE_JIT #include "pairwise_matrix/jit_lto_kernels/device_functions.cuh" +#endif #include // raft::linalg::Contractions_NT #include // ceildiv #include // RAFT_CUDA_TRY @@ -151,8 +153,12 @@ struct PairwiseDistances : public BaseClass { // Calculate distance_op epilog. // Use .template to disambiguate (See: // https://en.cppreference.com/w/cpp/language/dependent_name) +#ifdef CUVS_DISTANCE_PAIRWISE_USE_JIT compute_distance_epilog( distance_op, acc, regxn, regyn, tile_idx_n, tile_idx_m); +#else + distance_op.template epilog(acc, regxn, regyn, tile_idx_n, tile_idx_m); +#endif // And any possible additional epilogs epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); } else { @@ -161,8 +167,12 @@ struct PairwiseDistances : public BaseClass { // Calculate distance_op epilog. // Use .template to disambiguate (See: // https://en.cppreference.com/w/cpp/language/dependent_name) +#ifdef CUVS_DISTANCE_PAIRWISE_USE_JIT compute_distance_epilog( distance_op, acc, nullptr, nullptr, tile_idx_n, tile_idx_m); +#else + distance_op.template epilog(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); +#endif // And any possible additional epilogs epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); } @@ -206,8 +216,12 @@ struct PairwiseDistances : public BaseClass { for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { +#ifdef CUVS_DISTANCE_PAIRWISE_USE_JIT compute_distance( distance_op, acc[i][j], reg_x[i][v], reg_y[j][v]); +#else + distance_op.core(acc[i][j], reg_x[i][v], reg_y[j][v]); +#endif } } } diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in index cf60efce73..5432afb5c4 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in @@ -9,7 +9,9 @@ #include #include <@header_file@> +#define CUVS_DISTANCE_PAIRWISE_USE_JIT #include +#undef CUVS_DISTANCE_PAIRWISE_USE_JIT #include #include #include diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in index 8b2f532c23..30baae0dc3 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in @@ -10,7 +10,9 @@ #include <@header_file@> #include +#define CUVS_DISTANCE_PAIRWISE_USE_JIT #include +#undef CUVS_DISTANCE_PAIRWISE_USE_JIT #include #include #include From 884caf2afc7bc5a67d4b2bd8d01d75face05091c Mon Sep 17 00:00:00 2001 From: tarang-jain Date: Fri, 5 Jun 2026 20:08:36 -0700 Subject: [PATCH 18/18] simplify json --- .../compute_distance_epilog_matrix.json | 66 ++++++------------- .../compute_distance_epilog_rbf_matrix.json | 58 +++++----------- 2 files changed, 34 insertions(+), 90 deletions(-) diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_matrix.json b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_matrix.json index dfd88f39cd..3c03a381f7 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_matrix.json +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_matrix.json @@ -80,10 +80,18 @@ "header_file": "distance/detail/distance_ops/russel_rao.cuh" } ], - "_data": [ + "_acc_out": [ { - "data_type": "float", - "data_abbrev": "f", + "_data_input": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "half", + "data_abbrev": "h" + } + ], "acc_type": "float", "acc_abbrev": "f", "out_type": "float", @@ -122,42 +130,16 @@ ] }, { - "data_type": "double", - "data_abbrev": "d", "acc_type": "double", "acc_abbrev": "d", "out_type": "double", "out_abbrev": "d", - "_policy": [ + "_data_input": [ { - "policy_type": "Policy", - "layout_abbrev": "row", - "veclen": "1" - }, - { - "policy_type": "ColPolicy", - "layout_abbrev": "col", - "veclen": "1" - }, - { - "policy_type": "Policy", - "layout_abbrev": "row", - "veclen": "2" - }, - { - "policy_type": "ColPolicy", - "layout_abbrev": "col", - "veclen": "2" + "data_type": "double", + "data_abbrev": "d" } - ] - }, - { - "data_type": "half", - "data_abbrev": "h", - "acc_type": "float", - "acc_abbrev": "f", - "out_type": "float", - "out_abbrev": "f", + ], "_policy": [ { "policy_type": "Policy", @@ -178,16 +160,6 @@ "policy_type": "ColPolicy", "layout_abbrev": "col", "veclen": "2" - }, - { - "policy_type": "Policy", - "layout_abbrev": "row", - "veclen": "4" - }, - { - "policy_type": "ColPolicy", - "layout_abbrev": "col", - "veclen": "4" } ] } @@ -210,12 +182,12 @@ ], "_data": [ { - "data_type": "float", - "data_abbrev": "f", "acc_type": "float", "acc_abbrev": "f", "out_type": "float", "out_abbrev": "f", + "data_type": "float", + "data_abbrev": "f", "_policy": [ { "policy_type": "Policy", @@ -250,12 +222,12 @@ ] }, { - "data_type": "double", - "data_abbrev": "d", "acc_type": "double", "acc_abbrev": "d", "out_type": "double", "out_abbrev": "d", + "data_type": "double", + "data_abbrev": "d", "_policy": [ { "policy_type": "Policy", diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_rbf_matrix.json b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_rbf_matrix.json index 69959a8326..cf0157a633 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_rbf_matrix.json +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_rbf_matrix.json @@ -7,14 +7,22 @@ "header_file": "distance/detail/distance_ops/l2_unexp.cuh" } ], - "_data": [ + "_acc_out": [ { - "data_type": "float", - "data_abbrev": "f", "acc_type": "float", "acc_abbrev": "f", "out_type": "float", "out_abbrev": "f", + "_data_input": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "half", + "data_abbrev": "h" + } + ], "_policy": [ { "policy_type": "Policy", @@ -49,42 +57,16 @@ ] }, { - "data_type": "double", - "data_abbrev": "d", "acc_type": "double", "acc_abbrev": "d", "out_type": "double", "out_abbrev": "d", - "_policy": [ - { - "policy_type": "Policy", - "layout_abbrev": "row", - "veclen": "1" - }, - { - "policy_type": "ColPolicy", - "layout_abbrev": "col", - "veclen": "1" - }, - { - "policy_type": "Policy", - "layout_abbrev": "row", - "veclen": "2" - }, + "_data_input": [ { - "policy_type": "ColPolicy", - "layout_abbrev": "col", - "veclen": "2" + "data_type": "double", + "data_abbrev": "d" } - ] - }, - { - "data_type": "half", - "data_abbrev": "h", - "acc_type": "float", - "acc_abbrev": "f", - "out_type": "float", - "out_abbrev": "f", + ], "_policy": [ { "policy_type": "Policy", @@ -105,16 +87,6 @@ "policy_type": "ColPolicy", "layout_abbrev": "col", "veclen": "2" - }, - { - "policy_type": "Policy", - "layout_abbrev": "row", - "veclen": "4" - }, - { - "policy_type": "ColPolicy", - "layout_abbrev": "col", - "veclen": "4" } ] }