diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 7f9f88695c..227c2906cc 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -848,6 +848,93 @@ if(NOT BUILD_CPU_ONLY) OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/filter" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements ) + set(distance_ns "cuvs::distance::detail") + set(pairwise_matrix_jit_dir + "${CMAKE_CURRENT_SOURCE_DIR}/src/distance/detail/pairwise_matrix/jit_lto_kernels" + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "pairwise_matrix_@distance_abbrev@_data_@data_abbrev@_acc_@acc_abbrev@_out_@out_abbrev@_index_@index_abbrev@_layout_@layout_abbrev@_veclen_@veclen@" + MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_matrix.json" + KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/pairwise_matrix_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${distance_ns}::fragment_tag_pairwise_matrix<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_@out_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_fin_op_identity, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/kernel" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "pairwise_matrix_@distance_abbrev@_data_@data_abbrev@_acc_@acc_abbrev@_out_@out_abbrev@_index_@index_abbrev@_fin_op_rbf_layout_@layout_abbrev@_veclen_@veclen@" + MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_rbf_matrix.json" + KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/pairwise_matrix_rbf_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${distance_ns}::fragment_tag_pairwise_matrix<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_@out_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_fin_op_rbf, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/kernel_rbf" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "pairwise_matrix_compute_distance_@distance_abbrev@_data_@data_abbrev@_acc_@acc_abbrev@_index_@index_abbrev@" + MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_matrix.json" + KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${distance_ns}::fragment_tag_compute_distance<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "pairwise_matrix_compute_distance_@distance_abbrev@_data_@data_abbrev@_acc_@acc_abbrev@_index_@index_abbrev@" + MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_rbf_matrix.json" + KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${distance_ns}::fragment_tag_compute_distance<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance_rbf" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "pairwise_matrix_compute_distance_epilog_@distance_abbrev@_data_@data_abbrev@_acc_@acc_abbrev@_index_@index_abbrev@_layout_@layout_abbrev@_veclen_@veclen@" + MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_matrix.json" + KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${distance_ns}::fragment_tag_compute_distance_epilog<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance_epilog" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "pairwise_matrix_compute_distance_epilog_@distance_abbrev@_data_@data_abbrev@_acc_@acc_abbrev@_index_@index_abbrev@_layout_@layout_abbrev@_veclen_@veclen@" + MATRIX_JSON_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_rbf_matrix.json" + KERNEL_INPUT_FILE "${pairwise_matrix_jit_dir}/compute_distance_epilog_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${distance_ns}::fragment_tag_compute_distance_epilog<${distance_ns}::tag_distance_@distance_abbrev@, ${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${distance_ns}::tag_layout_@layout_abbrev@, @veclen@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/pairwise_matrix/compute_distance_epilog_rbf" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) endblock() # Note that this matrix contains an `arch_includes` placeholder, since we don't currently have a diff --git a/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp index cb33e4109b..cbd3f72730 100644 --- a/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp +++ b/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp @@ -9,6 +9,7 @@ namespace cuvs::neighbors::detail { struct tag_f {}; struct tag_h {}; +struct tag_d {}; struct tag_i8 {}; struct tag_u8 {}; struct tag_filter_none {}; @@ -16,6 +17,7 @@ struct tag_filter_bitset {}; struct tag_bitset_u32 {}; +struct tag_index_i32 {}; struct tag_index_u32 {}; struct tag_index_i64 {}; diff --git a/cpp/include/cuvs/detail/jit_lto/pairwise_matrix/pairwise_matrix_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/pairwise_matrix/pairwise_matrix_fragments.hpp new file mode 100644 index 0000000000..d9b7f4ec8d --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/pairwise_matrix/pairwise_matrix_fragments.hpp @@ -0,0 +1,51 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +namespace cuvs::distance::detail { + +struct tag_layout_row {}; +struct tag_layout_col {}; + +struct tag_fin_op_identity {}; +struct tag_fin_op_rbf {}; + +struct tag_distance_canberra {}; +struct tag_distance_correlation {}; +struct tag_distance_cosine {}; +struct tag_distance_hamming_unexpanded {}; +struct tag_distance_hellinger_expanded {}; +struct tag_distance_jensen_shannon {}; +struct tag_distance_kl_divergence {}; +struct tag_distance_l1 {}; +struct tag_distance_l2_expanded {}; +struct tag_distance_l2_unexpanded {}; +struct tag_distance_l_inf {}; +struct tag_distance_lp_unexpanded {}; +struct tag_distance_russel_rao {}; + +template +struct fragment_tag_pairwise_matrix {}; + +template +struct fragment_tag_compute_distance {}; + +template +struct fragment_tag_compute_distance_epilog {}; + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/distance.cuh b/cpp/src/distance/detail/distance.cuh index ab06bc3ba1..a8b73d58e3 100644 --- a/cpp/src/distance/detail/distance.cuh +++ b/cpp/src/distance/detail/distance.cuh @@ -7,8 +7,6 @@ #include "distance_ops/all_ops.cuh" #include "pairwise_matrix/dispatch.cuh" -#include "pairwise_matrix/dispatch_sm60.cuh" -#include "pairwise_matrix/dispatch_sm80.cuh" #include #include #include diff --git a/cpp/src/distance/detail/distance_ops/correlation.cuh b/cpp/src/distance/detail/distance_ops/correlation.cuh index 39cfe4b8d2..7276845ddf 100644 --- a/cpp/src/distance/detail/distance_ops/correlation.cuh +++ b/cpp/src/distance/detail/distance_ops/correlation.cuh @@ -1,10 +1,11 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once +#include // raft::sqrt #include // DI #include diff --git a/cpp/src/distance/detail/distance_ops/hellinger.cuh b/cpp/src/distance/detail/distance_ops/hellinger.cuh index 44e3e83375..35d8034c30 100644 --- a/cpp/src/distance/detail/distance_ops/hellinger.cuh +++ b/cpp/src/distance/detail/distance_ops/hellinger.cuh @@ -1,9 +1,10 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once +#include // raft::sqrt #include // DI namespace cuvs::distance::detail::ops { diff --git a/cpp/src/distance/detail/distance_ops/l1.cuh b/cpp/src/distance/detail/distance_ops/l1.cuh index 7c0cbe0b6c..09b4d7b214 100644 --- a/cpp/src/distance/detail/distance_ops/l1.cuh +++ b/cpp/src/distance/detail/distance_ops/l1.cuh @@ -1,9 +1,10 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once +#include // raft::abs #include // DI namespace cuvs::distance::detail::ops { diff --git a/cpp/src/distance/detail/distance_ops/l2_unexp.cuh b/cpp/src/distance/detail/distance_ops/l2_unexp.cuh index c7094cc1b8..49b8ad23d0 100644 --- a/cpp/src/distance/detail/distance_ops/l2_unexp.cuh +++ b/cpp/src/distance/detail/distance_ops/l2_unexp.cuh @@ -1,10 +1,11 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once +#include // raft::sqrt #include // DI #include diff --git a/cpp/src/distance/detail/distance_ops/l_inf.cuh b/cpp/src/distance/detail/distance_ops/l_inf.cuh index 2b2f09e8d6..03982183af 100644 --- a/cpp/src/distance/detail/distance_ops/l_inf.cuh +++ b/cpp/src/distance/detail/distance_ops/l_inf.cuh @@ -1,10 +1,11 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once +#include // raft::abs, raft::max #include // DI namespace cuvs::distance::detail::ops { diff --git a/cpp/src/distance/detail/pairwise_distance_base.cuh b/cpp/src/distance/detail/pairwise_distance_base.cuh index b2e886807d..a6fea8a017 100644 --- a/cpp/src/distance/detail/pairwise_distance_base.cuh +++ b/cpp/src/distance/detail/pairwise_distance_base.cuh @@ -1,8 +1,11 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once +#ifdef CUVS_DISTANCE_PAIRWISE_USE_JIT +#include "pairwise_matrix/jit_lto_kernels/device_functions.cuh" +#endif #include // raft::linalg::Contractions_NT #include // ceildiv #include // RAFT_CUDA_TRY @@ -150,7 +153,12 @@ struct PairwiseDistances : public BaseClass { // Calculate distance_op epilog. // Use .template to disambiguate (See: // https://en.cppreference.com/w/cpp/language/dependent_name) +#ifdef CUVS_DISTANCE_PAIRWISE_USE_JIT + compute_distance_epilog( + distance_op, acc, regxn, regyn, tile_idx_n, tile_idx_m); +#else distance_op.template epilog(acc, regxn, regyn, tile_idx_n, tile_idx_m); +#endif // And any possible additional epilogs epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); } else { @@ -159,7 +167,12 @@ struct PairwiseDistances : public BaseClass { // Calculate distance_op epilog. // Use .template to disambiguate (See: // https://en.cppreference.com/w/cpp/language/dependent_name) +#ifdef CUVS_DISTANCE_PAIRWISE_USE_JIT + compute_distance_epilog( + distance_op, acc, nullptr, nullptr, tile_idx_n, tile_idx_m); +#else distance_op.template epilog(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); +#endif // And any possible additional epilogs epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); } @@ -203,7 +216,12 @@ struct PairwiseDistances : public BaseClass { for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { +#ifdef CUVS_DISTANCE_PAIRWISE_USE_JIT + compute_distance( + distance_op, acc[i][j], reg_x[i][v], reg_y[j][v]); +#else distance_op.core(acc[i][j], reg_x[i][v], reg_y[j][v]); +#endif } } } diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh index 79cbb78f43..aff78d87f9 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -9,27 +9,25 @@ * 1. Dispatch to the correct implementation of a kernel based on the * architecture of the device on which the kernel will be launched. For * instance, the cosine distance has a CUTLASS-based implementation that can - * be used on SM80+ and the normal implementation that is used on older + * be used on SM80+ and the JIT implementation that is used on older * architectures. * * 2. Provide concise function templates that can be instantiated in * src/distance/detail/pairwise_matrix/. Previously, * cuvs::distance::detail::distance was instantiated. The function * necessarily required a large set of include files, which slowed down the - * build. The cuvs::distance::detail::pairwise_matrix_arch_dispatch functions - * do not require as large an include files set, which speeds up the build. + * build. */ -#include "../distance_ops/cutlass.cuh" // ops::has_cutlass_op -#include "../pairwise_matrix/dispatch_sm60.cuh" // dispatch_sm60 -#include "../pairwise_matrix/params.cuh" // pairwise_matrix_params -#include // raft::util::arch::SM_* +#include "../distance_ops/cutlass.cuh" // ops::has_cutlass_op +#include "../pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh" // pairwise_matrix_jit_dispatch +#include // raft::util::arch::SM_* // NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. // Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). -// Therefore, it is the including file's responsibility to include the correct -// dispatch_smXX.cuh headers, as is done in cuvs/distance/detail/distance.cuh -// and src/distance/detail/pairwise_matrix/dispatch_*.cu. +// Therefore, it is the including file's responsibility to include +// dispatch_sm80.cuh for CUTLASS-backed distance ops, as is done in +// src/distance/detail/pairwise_matrix/dispatch_*.cu. namespace cuvs::distance::detail { @@ -49,6 +47,19 @@ void pairwise_matrix_sm80_dispatch(OpT, SM_compat_t, cudaStream_t); +// This kernel is never launched. It only gives arch::kernel_virtual_arch a static kernel pointer +// from this TU/fatbin, without forcing JIT compilation just to decide whether the CUTLASS path is +// usable. +template +__global__ void pairwise_matrix_arch_probe_kernel() +{ +} + template (); + constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op::value; if constexpr (cutlass_op_unavailable) { - // Always execute legacy kernels when no cutlass op is available - auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); - pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); + pairwise_matrix_jit_dispatch(distance_op, params, stream); } else { auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); - auto legacy_range = arch::SM_range(arch::SM_min(), arch::SM_80()); - - // Get pointer to SM60 kernel to determine the best compute architecture - // out of all for which the kernel was compiled for that matches closely - // to the current device. Other methods to determine the architecture (that do not - // require a pointer) can be error prone. See: - // https://github.com/NVIDIA/cub/issues/545 - auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range); - void* kernel_ptr = reinterpret_cast(sm60_wrapper.kernel_ptr); - auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); - - // TODO: the cutlass doesn't support the odd `k` on half DataT. - bool if_unsupported_on_half = (sizeof(DataT) == 2) && ((k % 2) != 0); - - if (if_unsupported_on_half) { - auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); - pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); - } else if (cutlass_range.contains(runtime_arch) && !if_unsupported_on_half) { + auto kernel = pairwise_matrix_arch_probe_kernel; + void* kernel_ptr = reinterpret_cast(kernel); + auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); + + // TODO: CUTLASS does not support odd `k` with half DataT. + bool unsupported_half = (sizeof(DataT) == 2) && ((k % 2) != 0); + + if (!unsupported_half && cutlass_range.contains(runtime_arch)) { // If device is SM_80 or later, use CUTLASS-based kernel. pairwise_matrix_sm80_dispatch(distance_op, params, cutlass_range, stream); - } else { - // Reuse kernel wrapper that we obtained above. This avoids performing the - // dispatch twice. - sm60_wrapper.launch(distance_op, params, stream); + return; } + + pairwise_matrix_jit_dispatch(distance_op, params, stream); } } diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_matrix.json b/cpp/src/distance/detail/pairwise_matrix/dispatch_matrix.json index 5412fc897f..bf6be6bed6 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_matrix.json +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_matrix.json @@ -30,67 +30,67 @@ { "op_type": "cuvs::distance::detail::ops::canberra_distance_op", "op_abbrev": "canberra", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::correlation_distance_op", "op_abbrev": "correlation", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::cosine_distance_op", "op_abbrev": "cosine", - "arch_includes": "#include \n#include " + "arch_includes": "#include " }, { "op_type": "cuvs::distance::detail::ops::hamming_distance_op", "op_abbrev": "hamming_unexpanded", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::hellinger_distance_op", "op_abbrev": "hellinger_expanded", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::jensen_shannon_distance_op", "op_abbrev": "jensen_shannon", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::kl_divergence_op", "op_abbrev": "kl_divergence", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::l1_distance_op", "op_abbrev": "l1", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::l2_exp_distance_op", "op_abbrev": "l2_expanded", - "arch_includes": "#include \n#include " + "arch_includes": "#include " }, { "op_type": "cuvs::distance::detail::ops::l2_unexp_distance_op", "op_abbrev": "l2_unexpanded", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::l_inf_distance_op", "op_abbrev": "l_inf", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::lp_unexp_distance_op", "op_abbrev": "lp_unexpanded", - "arch_includes": "#include " + "arch_includes": "" }, { "op_type": "cuvs::distance::detail::ops::russel_rao_distance_op", "op_abbrev": "russel_rao", - "arch_includes": "#include " + "arch_includes": "" } ], "index_type": "int", @@ -119,7 +119,7 @@ { "op_type": "cuvs::distance::detail::ops::l2_exp_distance_op", "op_abbrev": "l2_expanded", - "arch_includes": "#include \n#include " + "arch_includes": "#include " } ], "index_type": "int64_t", diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf_inst.cu.in b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf_inst.cu.in index 879b221e4c..0e65689603 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf_inst.cu.in +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf_inst.cu.in @@ -14,7 +14,7 @@ using data_t = @data_type@; using acc_t = @acc_type@; using out_t = @out_type@; - using fin_op_t = cuvs::distance::kernels::rbf_fin_op; + using fin_op_t = cuvs::distance::kernels::rbf_fin_op; using index_t = @index_type@; using op_t = @op_type@; } diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf_matrix.json b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf_matrix.json index 876e824733..12213602f1 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf_matrix.json +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf_matrix.json @@ -31,7 +31,7 @@ { "op_type": "cuvs::distance::detail::ops::l2_unexp_distance_op", "op_abbrev": "l2_unexpanded", - "arch_includes": "#include " + "arch_includes": "" } ] } diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_sm60.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch_sm60.cuh deleted file mode 100644 index 034aabef3e..0000000000 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_sm60.cuh +++ /dev/null @@ -1,74 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ -#pragma once - -#include "dispatch_layout.cuh" // dispatch_layout -#include "kernel_sm60.cuh" // pairwise_matrix_sm60_wrapper -#include // raft::linalg::Policy4x4 - -#include // std::min - -namespace cuvs::distance::detail { - -template -pairwise_matrix_sm60_wrapper pairwise_matrix_sm60_get_wrapper( - OpT distance_op, - pairwise_matrix_params params, - SM_compat_t sm_compat_range) -{ - int vec_len = determine_vec_len(params); - - // f takes compile-time constants row_major and vec_len aligned and returns - // the corresponding kernel wrapper. The wrapper contains the launch - // parameters of the kernel: a pointer to the kernel function, grid size, - // block size, and shared memory size. - auto f = [&](auto row_major, auto vec_len_aligned) { - // row_major and vec_len are std::integral_constants of type bool and int - // respectively. - - // To keep compile times in check, we only specialize on veclen > 1 when - // the inner loop is relatively cheap (< 5 flops). - constexpr int vec_len_op = distance_op.expensive_inner_loop ? 1 : vec_len_aligned(); - - // Prevent double, vec_len=4 combination (this is not supported) - constexpr int vec_len = std::min(vec_len_op, static_cast(16 / sizeof(DataT))); - - using RowPolicy = typename raft::linalg::Policy4x4::Policy; - using ColPolicy = typename raft::linalg::Policy4x4::ColPolicy; - using Policy = typename std::conditional::type; - - auto wrapper = - make_pairwise_matrix_sm60_wrapper(distance_op, params, sm_compat_range); - - return wrapper; - }; - - // Dispatch_layout calls f with appropriate compile time constants based on - // the runtime values of params.is_row_major and vec_len. - return dispatch_layout(params.is_row_major, vec_len, f); -} - -template -void pairwise_matrix_sm60_dispatch(OpT distance_op, - pairwise_matrix_params params, - SM_compat_t sm_compat_range, - cudaStream_t stream) -{ - auto wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, sm_compat_range); - - wrapper.launch(distance_op, params, stream); -} - -} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_kernel.cu.in b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_kernel.cu.in new file mode 100644 index 0000000000..ffb3dc8c84 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_kernel.cu.in @@ -0,0 +1,39 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include + +#include <@header_file@> +#include +#include + +namespace { + +using data_t = @data_type@; +using acc_t = @acc_type@; +using index_t = @index_type@; +using op_t = @op_type@; +constexpr int veclen = @veclen@; +using policy_t = typename raft::linalg::Policy4x4::@policy_type@; + +} // namespace + +namespace cuvs::distance::detail { + +template <> +__device__ void compute_distance_epilog( + op_t distance_op, + acc_t acc[policy_t::AccRowsPerTh][policy_t::AccColsPerTh], + acc_t* regxn, + acc_t* regyn, + index_t grid_stride_x, + index_t grid_stride_y) +{ + distance_op.template epilog(acc, regxn, regyn, grid_stride_x, grid_stride_y); +} + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_matrix.json b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_matrix.json new file mode 100644 index 0000000000..3c03a381f7 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_matrix.json @@ -0,0 +1,262 @@ +[ + { + "_distance": [ + { + "distance_name": "canberra", + "distance_abbrev": "canberra", + "op_type": "cuvs::distance::detail::ops::canberra_distance_op", + "header_file": "distance/detail/distance_ops/canberra.cuh" + }, + { + "distance_name": "correlation", + "distance_abbrev": "correlation", + "op_type": "cuvs::distance::detail::ops::correlation_distance_op", + "header_file": "distance/detail/distance_ops/correlation.cuh" + }, + { + "distance_name": "cosine", + "distance_abbrev": "cosine", + "op_type": "cuvs::distance::detail::ops::cosine_distance_op", + "header_file": "distance/detail/distance_ops/cosine.cuh" + }, + { + "distance_name": "hamming_unexpanded", + "distance_abbrev": "hamming_unexpanded", + "op_type": "cuvs::distance::detail::ops::hamming_distance_op", + "header_file": "distance/detail/distance_ops/hamming.cuh" + }, + { + "distance_name": "hellinger_expanded", + "distance_abbrev": "hellinger_expanded", + "op_type": "cuvs::distance::detail::ops::hellinger_distance_op", + "header_file": "distance/detail/distance_ops/hellinger.cuh" + }, + { + "distance_name": "jensen_shannon", + "distance_abbrev": "jensen_shannon", + "op_type": "cuvs::distance::detail::ops::jensen_shannon_distance_op", + "header_file": "distance/detail/distance_ops/jensen_shannon.cuh" + }, + { + "distance_name": "kl_divergence", + "distance_abbrev": "kl_divergence", + "op_type": "cuvs::distance::detail::ops::kl_divergence_op", + "header_file": "distance/detail/distance_ops/kl_divergence.cuh" + }, + { + "distance_name": "l1", + "distance_abbrev": "l1", + "op_type": "cuvs::distance::detail::ops::l1_distance_op", + "header_file": "distance/detail/distance_ops/l1.cuh" + }, + { + "distance_name": "l2_expanded", + "distance_abbrev": "l2_expanded", + "op_type": "cuvs::distance::detail::ops::l2_exp_distance_op", + "header_file": "distance/detail/distance_ops/l2_exp.cuh" + }, + { + "distance_name": "l2_unexpanded", + "distance_abbrev": "l2_unexpanded", + "op_type": "cuvs::distance::detail::ops::l2_unexp_distance_op", + "header_file": "distance/detail/distance_ops/l2_unexp.cuh" + }, + { + "distance_name": "l_inf", + "distance_abbrev": "l_inf", + "op_type": "cuvs::distance::detail::ops::l_inf_distance_op", + "header_file": "distance/detail/distance_ops/l_inf.cuh" + }, + { + "distance_name": "lp_unexpanded", + "distance_abbrev": "lp_unexpanded", + "op_type": "cuvs::distance::detail::ops::lp_unexp_distance_op", + "header_file": "distance/detail/distance_ops/lp_unexp.cuh" + }, + { + "distance_name": "russel_rao", + "distance_abbrev": "russel_rao", + "op_type": "cuvs::distance::detail::ops::russel_rao_distance_op", + "header_file": "distance/detail/distance_ops/russel_rao.cuh" + } + ], + "_acc_out": [ + { + "_data_input": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "half", + "data_abbrev": "h" + } + ], + "acc_type": "float", + "acc_abbrev": "f", + "out_type": "float", + "out_abbrev": "f", + "_policy": [ + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "1" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "1" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "2" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "2" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "4" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "4" + } + ] + }, + { + "acc_type": "double", + "acc_abbrev": "d", + "out_type": "double", + "out_abbrev": "d", + "_data_input": [ + { + "data_type": "double", + "data_abbrev": "d" + } + ], + "_policy": [ + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "1" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "1" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "2" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "2" + } + ] + } + ], + "_index": [ + { + "index_type": "int", + "index_abbrev": "i32" + } + ] + }, + { + "_distance": [ + { + "distance_name": "l2_expanded", + "distance_abbrev": "l2_expanded", + "op_type": "cuvs::distance::detail::ops::l2_exp_distance_op", + "header_file": "distance/detail/distance_ops/l2_exp.cuh" + } + ], + "_data": [ + { + "acc_type": "float", + "acc_abbrev": "f", + "out_type": "float", + "out_abbrev": "f", + "data_type": "float", + "data_abbrev": "f", + "_policy": [ + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "1" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "1" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "2" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "2" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "4" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "4" + } + ] + }, + { + "acc_type": "double", + "acc_abbrev": "d", + "out_type": "double", + "out_abbrev": "d", + "data_type": "double", + "data_abbrev": "d", + "_policy": [ + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "1" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "1" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "2" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "2" + } + ] + } + ], + "_index": [ + { + "index_type": "int64_t", + "index_abbrev": "i64" + } + ] + } +] diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_rbf_matrix.json b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_rbf_matrix.json new file mode 100644 index 0000000000..cf0157a633 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_epilog_rbf_matrix.json @@ -0,0 +1,100 @@ +{ + "_distance": [ + { + "distance_name": "l2_unexpanded", + "distance_abbrev": "l2_unexpanded", + "op_type": "cuvs::distance::detail::ops::l2_unexp_distance_op", + "header_file": "distance/detail/distance_ops/l2_unexp.cuh" + } + ], + "_acc_out": [ + { + "acc_type": "float", + "acc_abbrev": "f", + "out_type": "float", + "out_abbrev": "f", + "_data_input": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "half", + "data_abbrev": "h" + } + ], + "_policy": [ + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "1" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "1" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "2" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "2" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "4" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "4" + } + ] + }, + { + "acc_type": "double", + "acc_abbrev": "d", + "out_type": "double", + "out_abbrev": "d", + "_data_input": [ + { + "data_type": "double", + "data_abbrev": "d" + } + ], + "_policy": [ + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "1" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "1" + }, + { + "policy_type": "Policy", + "layout_abbrev": "row", + "veclen": "2" + }, + { + "policy_type": "ColPolicy", + "layout_abbrev": "col", + "veclen": "2" + } + ] + } + ], + "_index": [ + { + "index_type": "int64_t", + "index_abbrev": "i64" + } + ] +} diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_kernel.cu.in b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_kernel.cu.in new file mode 100644 index 0000000000..4d39801d40 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_kernel.cu.in @@ -0,0 +1,33 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include + +#include <@header_file@> +#include + +namespace { + +using data_t = @data_type@; +using acc_t = @acc_type@; +using index_t = @index_type@; +using op_t = @op_type@; + +} // namespace + +namespace cuvs::distance::detail { + +template <> +__device__ void compute_distance(op_t distance_op, + acc_t& acc, + data_t x, + data_t y) +{ + distance_op.core(acc, x, y); +} + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_matrix.json b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_matrix.json new file mode 100644 index 0000000000..0c3b0c4164 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_matrix.json @@ -0,0 +1,155 @@ +[ + { + "_distance": [ + { + "distance_name": "canberra", + "distance_abbrev": "canberra", + "op_type": "cuvs::distance::detail::ops::canberra_distance_op", + "header_file": "distance/detail/distance_ops/canberra.cuh" + }, + { + "distance_name": "correlation", + "distance_abbrev": "correlation", + "op_type": "cuvs::distance::detail::ops::correlation_distance_op", + "header_file": "distance/detail/distance_ops/correlation.cuh" + }, + { + "distance_name": "cosine", + "distance_abbrev": "cosine", + "op_type": "cuvs::distance::detail::ops::cosine_distance_op", + "header_file": "distance/detail/distance_ops/cosine.cuh" + }, + { + "distance_name": "hamming_unexpanded", + "distance_abbrev": "hamming_unexpanded", + "op_type": "cuvs::distance::detail::ops::hamming_distance_op", + "header_file": "distance/detail/distance_ops/hamming.cuh" + }, + { + "distance_name": "hellinger_expanded", + "distance_abbrev": "hellinger_expanded", + "op_type": "cuvs::distance::detail::ops::hellinger_distance_op", + "header_file": "distance/detail/distance_ops/hellinger.cuh" + }, + { + "distance_name": "jensen_shannon", + "distance_abbrev": "jensen_shannon", + "op_type": "cuvs::distance::detail::ops::jensen_shannon_distance_op", + "header_file": "distance/detail/distance_ops/jensen_shannon.cuh" + }, + { + "distance_name": "kl_divergence", + "distance_abbrev": "kl_divergence", + "op_type": "cuvs::distance::detail::ops::kl_divergence_op", + "header_file": "distance/detail/distance_ops/kl_divergence.cuh" + }, + { + "distance_name": "l1", + "distance_abbrev": "l1", + "op_type": "cuvs::distance::detail::ops::l1_distance_op", + "header_file": "distance/detail/distance_ops/l1.cuh" + }, + { + "distance_name": "l2_expanded", + "distance_abbrev": "l2_expanded", + "op_type": "cuvs::distance::detail::ops::l2_exp_distance_op", + "header_file": "distance/detail/distance_ops/l2_exp.cuh" + }, + { + "distance_name": "l2_unexpanded", + "distance_abbrev": "l2_unexpanded", + "op_type": "cuvs::distance::detail::ops::l2_unexp_distance_op", + "header_file": "distance/detail/distance_ops/l2_unexp.cuh" + }, + { + "distance_name": "l_inf", + "distance_abbrev": "l_inf", + "op_type": "cuvs::distance::detail::ops::l_inf_distance_op", + "header_file": "distance/detail/distance_ops/l_inf.cuh" + }, + { + "distance_name": "lp_unexpanded", + "distance_abbrev": "lp_unexpanded", + "op_type": "cuvs::distance::detail::ops::lp_unexp_distance_op", + "header_file": "distance/detail/distance_ops/lp_unexp.cuh" + }, + { + "distance_name": "russel_rao", + "distance_abbrev": "russel_rao", + "op_type": "cuvs::distance::detail::ops::russel_rao_distance_op", + "header_file": "distance/detail/distance_ops/russel_rao.cuh" + } + ], + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "type_abbrev": "f", + "acc_type": "float", + "acc_abbrev": "f", + "out_type": "float", + "out_abbrev": "f" + }, + { + "data_type": "double", + "data_abbrev": "d", + "type_abbrev": "d", + "acc_type": "double", + "acc_abbrev": "d", + "out_type": "double", + "out_abbrev": "d" + }, + { + "data_type": "half", + "data_abbrev": "h", + "type_abbrev": "h", + "acc_type": "float", + "acc_abbrev": "f", + "out_type": "float", + "out_abbrev": "f" + } + ], + "_index": [ + { + "index_type": "int", + "index_abbrev": "i32" + } + ] + }, + { + "_distance": [ + { + "distance_name": "l2_expanded", + "distance_abbrev": "l2_expanded", + "op_type": "cuvs::distance::detail::ops::l2_exp_distance_op", + "header_file": "distance/detail/distance_ops/l2_exp.cuh" + } + ], + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "type_abbrev": "f", + "acc_type": "float", + "acc_abbrev": "f", + "out_type": "float", + "out_abbrev": "f" + }, + { + "data_type": "double", + "data_abbrev": "d", + "type_abbrev": "d", + "acc_type": "double", + "acc_abbrev": "d", + "out_type": "double", + "out_abbrev": "d" + } + ], + "_index": [ + { + "index_type": "int64_t", + "index_abbrev": "i64" + } + ] + } +] diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_rbf_matrix.json b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_rbf_matrix.json new file mode 100644 index 0000000000..d140025078 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/compute_distance_rbf_matrix.json @@ -0,0 +1,45 @@ +{ + "_distance": [ + { + "distance_name": "l2_unexpanded", + "distance_abbrev": "l2_unexpanded", + "op_type": "cuvs::distance::detail::ops::l2_unexp_distance_op", + "header_file": "distance/detail/distance_ops/l2_unexp.cuh" + } + ], + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "type_abbrev": "f", + "acc_type": "float", + "acc_abbrev": "f", + "out_type": "float", + "out_abbrev": "f" + }, + { + "data_type": "double", + "data_abbrev": "d", + "type_abbrev": "d", + "acc_type": "double", + "acc_abbrev": "d", + "out_type": "double", + "out_abbrev": "d" + }, + { + "data_type": "half", + "data_abbrev": "h", + "type_abbrev": "h", + "acc_type": "float", + "acc_abbrev": "f", + "out_type": "float", + "out_abbrev": "f" + } + ], + "_index": [ + { + "index_type": "int64_t", + "index_abbrev": "i64" + } + ] +} diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/device_functions.cuh b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/device_functions.cuh new file mode 100644 index 0000000000..ff6479a05f --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/device_functions.cuh @@ -0,0 +1,21 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +namespace cuvs::distance::detail { + +template +extern __device__ void compute_distance(OpT distance_op, AccT& acc, DataT x, DataT y); + +template +extern __device__ void compute_distance_epilog(OpT distance_op, + AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + AccT* regxn, + AccT* regyn, + IdxT grid_stride_x, + IdxT grid_stride_y); + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh new file mode 100644 index 0000000000..17526e75b7 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_jit.cuh @@ -0,0 +1,210 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "pairwise_matrix_planner.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace cuvs::distance::detail { + +template +inline constexpr bool pairwise_matrix_jit_always_false_v = false; + +template +constexpr auto get_pairwise_scalar_type_tag() +{ + if constexpr (std::is_same_v) { + return cuvs::neighbors::detail::tag_f{}; + } else if constexpr (std::is_same_v) { + return cuvs::neighbors::detail::tag_d{}; + } else if constexpr (std::is_same_v || std::is_same_v) { + return cuvs::neighbors::detail::tag_h{}; + } else { + static_assert(pairwise_matrix_jit_always_false_v, + "Pairwise matrix JIT LTO does not have a scalar tag for this type"); + } +} + +template +constexpr auto get_pairwise_index_type_tag() +{ + if constexpr (std::is_same_v) { + return cuvs::neighbors::detail::tag_index_i32{}; + } else if constexpr (std::is_same_v) { + return cuvs::neighbors::detail::tag_index_i64{}; + } else { + static_assert(pairwise_matrix_jit_always_false_v, + "Pairwise matrix JIT LTO does not have an index tag for this type"); + } +} + +template +using pairwise_layout_tag_t = std::conditional_t; + +template +struct pairwise_fin_op_tag { + static_assert(pairwise_matrix_jit_always_false_v, + "Pairwise matrix JIT LTO does not have a final-op tag for this type"); +}; + +template <> +struct pairwise_fin_op_tag { + using type = tag_fin_op_identity; +}; + +template +struct pairwise_fin_op_tag> { + using type = tag_fin_op_rbf; +}; + +template +using pairwise_fin_op_tag_t = typename pairwise_fin_op_tag::type; + +template +struct pairwise_distance_op_tag { + static_assert(pairwise_matrix_jit_always_false_v, + "Pairwise matrix JIT LTO does not have a distance-op tag for this type"); +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_canberra; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_correlation; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_cosine; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_hamming_unexpanded; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_hellinger_expanded; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_jensen_shannon; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_kl_divergence; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_l1; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_l2_expanded; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_l2_unexpanded; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_l_inf; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_lp_unexpanded; +}; + +template +struct pairwise_distance_op_tag> { + using type = tag_distance_russel_rao; +}; + +template +using pairwise_distance_op_tag_t = typename pairwise_distance_op_tag::type; + +template +using pairwise_matrix_jit_kernel_t = void(OpT, pairwise_matrix_params); + +template +void pairwise_matrix_jit_dispatch(OpT distance_op, + pairwise_matrix_params params, + cudaStream_t stream) +{ + using AccT = typename OpT::AccT; + + int vec_len = determine_vec_len(params); + + auto launch = [&](auto row_major, auto vec_len_aligned) { + constexpr int vec_len_op = OpT::expensive_inner_loop ? 1 : vec_len_aligned(); + constexpr int veclen = std::min(vec_len_op, static_cast(16 / sizeof(DataT))); + + using RowPolicy = typename raft::linalg::Policy4x4::Policy; + using ColPolicy = typename raft::linalg::Policy4x4::ColPolicy; + using Policy = std::conditional_t; + + using DistanceTag = pairwise_distance_op_tag_t; + using DataTag = decltype(get_pairwise_scalar_type_tag()); + using AccTag = decltype(get_pairwise_scalar_type_tag()); + using OutTag = decltype(get_pairwise_scalar_type_tag()); + using IndexTag = decltype(get_pairwise_index_type_tag()); + using FinOpTag = pairwise_fin_op_tag_t; + using LayoutTag = pairwise_layout_tag_t; + + PairwiseMatrixPlanner + planner; + planner.add_entrypoint(); + planner.add_compute_distance_function(); + planner.add_compute_distance_epilog_function(); + + auto launcher = planner.get_launcher(); + + dim3 block(Policy::Nthreads); + int smem_size = OpT::template shared_mem_size(); + dim3 grid = + launchConfigGenerator(params.m, params.n, smem_size, launcher->get_kernel()); + + launcher->template dispatch>( + stream, grid, block, smem_size, distance_op, params); + }; + + dispatch_layout(params.is_row_major, vec_len, launch); +} + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in new file mode 100644 index 0000000000..5432afb5c4 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_kernel.cu.in @@ -0,0 +1,76 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +#include + +#include <@header_file@> +#define CUVS_DISTANCE_PAIRWISE_USE_JIT +#include +#undef CUVS_DISTANCE_PAIRWISE_USE_JIT +#include +#include +#include + +namespace { + +using data_t = @data_type@; +using acc_t = @acc_type@; +using out_t = @out_type@; +using index_t = @index_type@; +using fin_op_t = raft::identity_op; +using op_t = @op_type@; +constexpr int veclen = @veclen@; +using policies_t = raft::linalg::Policy4x4; +using policy_t = typename policies_t::@policy_type@; +constexpr bool row_major = std::is_same_v; +using base_t = raft::linalg::Contractions_NT; + +} // namespace + +namespace cuvs::distance::detail { + +extern "C" __global__ __launch_bounds__(policy_t::Nthreads, 2) void pairwise_matrix_kernel( + op_t distance_op, pairwise_matrix_params params) +{ + extern __shared__ char smem[]; + + auto epilog_op = raft::void_op(); + auto row_epilog_op = raft::void_op(); + + constexpr bool write_out = true; + PairwiseDistances + obj(params.x, + params.y, + params.m, + params.n, + params.k, + params.ldx, + params.ldy, + params.ld_out, + params.x_norm, + params.y_norm, + params.out, + smem, + distance_op, + epilog_op, + params.fin_op, + row_epilog_op); + obj.run(); +} + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp new file mode 100644 index 0000000000..0d00b3eca6 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp @@ -0,0 +1,67 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include + +namespace cuvs::distance::detail { + +inline constexpr char kPairwiseMatrixJitEntrypoint[] = "pairwise_matrix_kernel"; + +template +struct PairwiseMatrixPlanner : AlgorithmPlanner { + using DistanceTag = DistanceTag_; + using DataTag = DataTag_; + using AccTag = AccTag_; + using OutTag = OutTag_; + using IndexTag = IndexTag_; + using FinOpTag = FinOpTag_; + using LayoutTag = LayoutTag_; + + static constexpr int Veclen = Veclen_; + + inline static LauncherJitCache launcher_jit_cache{}; + + PairwiseMatrixPlanner() : AlgorithmPlanner(kPairwiseMatrixJitEntrypoint, launcher_jit_cache) {} + + void add_entrypoint() + { + this->add_static_fragment>(); + } + + void add_compute_distance_function() + { + this->add_static_fragment< + fragment_tag_compute_distance>(); + } + + void add_compute_distance_epilog_function() + { + this->add_static_fragment>(); + } +}; + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in new file mode 100644 index 0000000000..30baae0dc3 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_rbf_kernel.cu.in @@ -0,0 +1,77 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +#include + +#include <@header_file@> +#include +#define CUVS_DISTANCE_PAIRWISE_USE_JIT +#include +#undef CUVS_DISTANCE_PAIRWISE_USE_JIT +#include +#include +#include + +namespace { + +using data_t = @data_type@; +using acc_t = @acc_type@; +using out_t = @out_type@; +using index_t = @index_type@; +using fin_op_t = cuvs::distance::kernels::rbf_fin_op; +using op_t = @op_type@; +constexpr int veclen = @veclen@; +using policies_t = raft::linalg::Policy4x4; +using policy_t = typename policies_t::@policy_type@; +constexpr bool row_major = std::is_same_v; +using base_t = raft::linalg::Contractions_NT; + +} // namespace + +namespace cuvs::distance::detail { + +extern "C" __global__ __launch_bounds__(policy_t::Nthreads, 2) void pairwise_matrix_kernel( + op_t distance_op, pairwise_matrix_params params) +{ + extern __shared__ char smem[]; + + auto epilog_op = raft::void_op(); + auto row_epilog_op = raft::void_op(); + + constexpr bool write_out = true; + PairwiseDistances + obj(params.x, + params.y, + params.m, + params.n, + params.k, + params.ldx, + params.ldy, + params.ld_out, + params.x_norm, + params.y_norm, + params.out, + smem, + distance_op, + epilog_op, + params.fin_op, + row_epilog_op); + obj.run(); +} + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/src/distance/detail/pairwise_matrix/kernel_sm60.cuh deleted file mode 100644 index e7c9facef6..0000000000 --- a/cpp/src/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ /dev/null @@ -1,144 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ -#pragma once - -#include "../pairwise_distance_base.cuh" // PairwiseDistances -#include "params.cuh" // pairwise_matrix_params -#include // raft::void_op -#include // raft::util::arch::SM_compute_arch - -#include // assert - -namespace cuvs::distance::detail { - -template -__launch_bounds__(Policy::Nthreads, 2) RAFT_KERNEL - pairwise_matrix_kernel(OpT distance_op, pairwise_matrix_params params) -{ - // Early exit to minimize the size of the kernel when it is not supposed to be compiled. - constexpr SM_compat_t sm_compat_range{}; - if constexpr (!sm_compat_range.contains(raft::util::arch::SM_compute_arch())) { - assert(false); - return; - } - - extern __shared__ char smem[]; - - // The epilog is already provided by distance_op. Do not provide additional - // epilogs. - auto epilog_op = raft::void_op(); - // No support for row_epilog_op. - auto row_epilog_op = raft::void_op(); - - // Always write output - constexpr bool write_out = true; - PairwiseDistances - obj(params.x, - params.y, - params.m, - params.n, - params.k, - params.ldx, - params.ldy, - params.ld_out, - params.x_norm, - params.y_norm, - params.out, - smem, - distance_op, - epilog_op, - params.fin_op, - row_epilog_op); - obj.run(); -} - -// The type of a pointer to the pairwise matrix kernel. The following template -// arguments are type-erased: -// -// - The kernel policy -// - row_major -// - SM_compat_t -template -using pairwise_matrix_kernel_t = void (*)(OpT, pairwise_matrix_params); - -// A wrapper for the pairwise matrix kernel launch. Includes kernel launch -// parameters. -template -struct pairwise_matrix_sm60_wrapper { - dim3 grid; - dim3 block; - int smem_size; - pairwise_matrix_kernel_t kernel_ptr; - - void launch(OpT distance_op, - pairwise_matrix_params params, - cudaStream_t stream) - { - kernel_ptr<<>>(distance_op, params); - RAFT_CUDA_TRY(cudaGetLastError()); - } -}; - -/** @brief: Create kernel launch wrapper for pairwise matrix kernel - * - * This can be used to type-erase the kernel execution policy, row_major, and SM - * compatibility range. - * - * @tparam Policy: Kernel execution policy - * @tparam row_major: Indicates whether input matrices are row major - * @tparam OpT: Type of distance operation - * @tparam IdxT: Index type - * @tparam DataT: Data type - * @tparam OutT: Output data type - * @tparam FinOpT: Final operation type - * @tparam SM_compat_t: Type of the SM architecture compatibility - * - * @param distance_op: Distance operation - * @param params: Parameters - * @param sm_compat_range: Which SM architectures to compile for. - */ -template -pairwise_matrix_sm60_wrapper make_pairwise_matrix_sm60_wrapper( - OpT distance_op, - pairwise_matrix_params params, - SM_compat_t sm_compat_range) -{ - dim3 block(Policy::Nthreads); - // Use ::template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - int smem_size = OpT::template shared_mem_size(); - // Obtain function pointer to kernel - auto kernel = - pairwise_matrix_kernel; - dim3 grid = launchConfigGenerator(params.m, params.n, smem_size, kernel); - - return pairwise_matrix_sm60_wrapper{ - grid, block, smem_size, kernel}; -} - -}; // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/sparse/l2_distance.cuh b/cpp/src/distance/detail/sparse/l2_distance.cuh index 89d7c978a4..3b06dcb895 100644 --- a/cpp/src/distance/detail/sparse/l2_distance.cuh +++ b/cpp/src/distance/detail/sparse/l2_distance.cuh @@ -9,6 +9,7 @@ #include "ip_distance.cuh" #include +#include // raft::sqrt #include #include #include