diff --git a/.gitignore b/.gitignore index 3627558ff5..0066d2b89a 100644 --- a/.gitignore +++ b/.gitignore @@ -72,7 +72,9 @@ docs/source/_static/rust # clang tooling compile_commands.json -.clangd/ + + + # serialized ann indexes brute_force_index @@ -86,5 +88,8 @@ ivf_pq_index /datasets/ /*.json +# clangd +*/.clangd + # java .classpath diff --git a/cpp/.clangd b/cpp/.clangd deleted file mode 100644 index 7c4fe036dd..0000000000 --- a/cpp/.clangd +++ /dev/null @@ -1,65 +0,0 @@ -# https://clangd.llvm.org/config - -# Apply a config conditionally to all C files -If: - PathMatch: .*\.(c|h)$ - ---- - -# Apply a config conditionally to all C++ files -If: - PathMatch: .*\.(c|h)pp - ---- - -# Apply a config conditionally to all CUDA files -If: - PathMatch: .*\.cuh? -CompileFlags: - Add: - - "-x" - - "cuda" - # No error on unknown CUDA versions - - "-Wno-unknown-cuda-version" - # Allow variadic CUDA functions - - "-Xclang=-fcuda-allow-variadic-functions" -Diagnostics: - Suppress: - - "variadic_device_fn" - - "attributes_not_allowed" - ---- - -# Tweak the clangd parse settings for all files -CompileFlags: - Add: - # report all errors - - "-ferror-limit=0" - - "-fmacro-backtrace-limit=0" - - "-ftemplate-backtrace-limit=0" - # Skip the CUDA version check - - "--no-cuda-version-check" - Remove: - # remove gcc's -fcoroutines - - -fcoroutines - # remove nvc++ flags unknown to clang - - "-gpu=*" - - "-stdpar*" - # remove nvcc flags unknown to clang - - "-arch*" - - "-gencode*" - - "--generate-code*" - - "-ccbin*" - - "-t=*" - - "--threads*" - - "-Xptxas*" - - "-Xcudafe*" - - "-Xfatbin*" - - "-Xcompiler*" - - "--diag-suppress*" - - "--diag_suppress*" - - "--compiler-options*" - - "--expt-extended-lambda" - - "--expt-relaxed-constexpr" - - "-forward-unknown-to-host-compiler" - - "-Werror=cross-execution-space-call" diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 227c2906cc..f958edf022 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -268,12 +268,12 @@ if(NOT BUILD_CPU_ONLY) INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in" OUTPUT_FILE_FORMAT - "${CMAKE_CURRENT_BINARY_DIR}/src/neighbors/detail/cagra/compute_distance_vpq_inst_data_@data_abbrev@_index_@index_abbrev@_distance_@distance_abbrev@_codebook_@codebook_abbrev@_metric_@metric@_team_@team_size@_dim_@dim@_pq_bits_@pq_bits@_pq_len_@pq_len@.cu" + "${CMAKE_CURRENT_BINARY_DIR}/src/neighbors/detail/cagra/compute_distance_vpq_inst_data_@data_abbrev@_index_@index_abbrev@_distance_@distance_abbrev@_codebook_@codebook_abbrev@_metric_@metric@_team_@team_size@_dim_@dim@_pq_bits_@pq_bits@_pq_len_@pq_len@_smem_@smem_abbrev@.cu" ) generate_string_matrix( cagra_compute_distance_vpq_selector_template_params ITEM_FORMAT - "\nvpq_descriptor_spec" + "\nvpq_descriptor_spec" GLUE "," MATRIX_JSON_FILE @@ -282,7 +282,7 @@ if(NOT BUILD_CPU_ONLY) generate_string_matrix( cagra_compute_distance_vpq_template_inst ITEM_FORMAT - "extern template struct vpq_descriptor_spec@semicolon@" + "extern template struct vpq_descriptor_spec@semicolon@" GLUE "\n" MATRIX_JSON_FILE @@ -688,13 +688,13 @@ if(NOT BUILD_CPU_ONLY) generate_jit_lto_kernels( jit_lto_files NAME_FORMAT - "cagra_setup_workspace@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@" + "cagra_setup_workspace@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@_smem_@smem_abbrev@" MATRIX_JSON_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json" KERNEL_INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in" FRAGMENT_TAG_FORMAT - "${cagra_ns}::fragment_tag_setup_workspace<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_codebook_@codebook_abbrev@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@>" + "${cagra_ns}::fragment_tag_setup_workspace<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_codebook_@codebook_abbrev@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, ${cagra_ns}::tag_smem_@smem_abbrev@>" FRAGMENT_TAG_HEADER_FILES "" "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/setup_workspace" @@ -704,13 +704,13 @@ if(NOT BUILD_CPU_ONLY) generate_jit_lto_kernels( jit_lto_files NAME_FORMAT - "cagra_compute_distance@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@" + "cagra_compute_distance@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@_smem_@smem_abbrev@" MATRIX_JSON_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json" KERNEL_INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in" FRAGMENT_TAG_FORMAT - "${cagra_ns}::fragment_tag_compute_distance<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_codebook_@codebook_abbrev@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@>" + "${cagra_ns}::fragment_tag_compute_distance<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_codebook_@codebook_abbrev@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, ${cagra_ns}::tag_smem_@smem_abbrev@>" FRAGMENT_TAG_HEADER_FILES "" "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/compute_distance" diff --git a/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h b/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h index ea26ceed78..9977484742 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h @@ -320,10 +320,12 @@ void parse_build_param(const nlohmann::json& conf, cuvs::neighbors::cagra::index } // Parse build-algo-specific parameters and use them to decide on the algo type - nlohmann::json ivf_pq_build_conf = collect_conf_with_prefix(conf, "ivf_pq_build_"); - nlohmann::json ivf_pq_search_conf = collect_conf_with_prefix(conf, "ivf_pq_search_"); - nlohmann::json nn_descent_conf = collect_conf_with_prefix(conf, "nn_descent_"); - nlohmann::json ace_conf = collect_conf_with_prefix(conf, "ace_"); + nlohmann::json ivf_pq_build_conf = collect_conf_with_prefix(conf, "ivf_pq_build_"); + nlohmann::json ivf_pq_search_conf = collect_conf_with_prefix(conf, "ivf_pq_search_"); + nlohmann::json nn_descent_conf = collect_conf_with_prefix(conf, "nn_descent_"); + nlohmann::json ace_conf = collect_conf_with_prefix(conf, "ace_"); + nlohmann::json build_compression_conf = collect_conf_with_prefix(conf, "build_compression_"); + nlohmann::json build_search_conf = collect_conf_with_prefix(conf, "build_search_"); // When graph_build_algo is not specified, leave graph_build_params as monostate so the // CAGRA build uses AUTO selection (NN_DESCENT or IVF_PQ based on dataset/heuristics). @@ -354,6 +356,94 @@ void parse_build_param(const nlohmann::json& conf, cuvs::neighbors::cagra::index } else if constexpr (std::is_same_v) { parse_build_param(nn_descent_conf, arg); + } else if constexpr (std::is_same_v< + U, + cuvs::neighbors::graph_build_params::iterative_search_params>) { + if (!build_compression_conf.empty()) { + auto vpq_pams = arg.build_compression.value_or(cuvs::neighbors::vpq_params{}); + parse_build_param(build_compression_conf, vpq_pams); + arg.build_compression.emplace(vpq_pams); + } + if (build_search_conf.contains("width")) { + arg.search_width = build_search_conf.at("width"); + } + if (build_search_conf.contains("max_iterations")) { + arg.max_iterations = build_search_conf.at("max_iterations"); + } + if (build_search_conf.contains("min_iterations")) { + arg.min_iterations = build_search_conf.at("min_iterations"); + } + if (build_search_conf.contains("itopk")) { arg.itopk_size = build_search_conf.at("itopk"); } + if (build_search_conf.contains("max_queries")) { + arg.max_queries = build_search_conf.at("max_queries"); + } + if (build_search_conf.contains("team_size")) { + arg.team_size = build_search_conf.at("team_size"); + } + if (build_search_conf.contains("thread_block_size")) { + arg.thread_block_size = build_search_conf.at("thread_block_size"); + } + if (build_search_conf.contains("hashmap_min_bitlen")) { + arg.hashmap_min_bitlen = build_search_conf.at("hashmap_min_bitlen"); + } + if (build_search_conf.contains("hashmap_max_fill_rate")) { + arg.hashmap_max_fill_rate = build_search_conf.at("hashmap_max_fill_rate"); + } + if (build_search_conf.contains("num_random_samplings")) { + arg.num_random_samplings = build_search_conf.at("num_random_samplings"); + } + if (build_search_conf.contains("persistent")) { + arg.persistent = build_search_conf.at("persistent"); + } + if (build_search_conf.contains("persistent_lifetime")) { + arg.persistent_lifetime = build_search_conf.at("persistent_lifetime"); + } + if (build_search_conf.contains("persistent_device_usage")) { + arg.persistent_device_usage = build_search_conf.at("persistent_device_usage"); + } + if (build_search_conf.contains("algo")) { + std::string algo = build_search_conf.at("algo"); + if (algo == "single_cta") { + arg.algo = cuvs::neighbors::cagra::search_algo::SINGLE_CTA; + } else if (algo == "multi_cta") { + arg.algo = cuvs::neighbors::cagra::search_algo::MULTI_CTA; + } else if (algo == "multi_kernel") { + arg.algo = cuvs::neighbors::cagra::search_algo::MULTI_KERNEL; + } else if (algo == "auto") { + arg.algo = cuvs::neighbors::cagra::search_algo::AUTO; + } + } + if (build_search_conf.contains("hashmap_mode")) { + std::string mode = build_search_conf.at("hashmap_mode"); + if (mode == "hash") { + arg.hashmap_mode = cuvs::neighbors::cagra::hash_mode::HASH; + } else if (mode == "small") { + arg.hashmap_mode = cuvs::neighbors::cagra::hash_mode::SMALL; + } else if (mode == "auto") { + arg.hashmap_mode = cuvs::neighbors::cagra::hash_mode::AUTO; + } + } + // Whether to shuffle the (compressed) dataset before the iterative build loop. + if (build_search_conf.contains("shuffle_dataset")) { + arg.shuffle_dataset = build_search_conf.at("shuffle_dataset").get(); + } + // Precision of the codebook/query in shared memory for the VPQ search used during + // the iterative build. Accepts an integer code (0=F16, 1=E5M2) or a string. + if (build_search_conf.contains("smem_dtype")) { + const auto& sd = build_search_conf.at("smem_dtype"); + if (sd.is_number_integer()) { + arg.smem_dtype = static_cast(sd.get()); + } else { + std::string s = sd.get(); + if (s == "f16" || s == "F16" || s == "fp16" || s == "half") { + arg.smem_dtype = cuvs::neighbors::cagra::internal_dtype::F16; + } else if (s == "e5m2" || s == "E5M2" || s == "fp8") { + arg.smem_dtype = cuvs::neighbors::cagra::internal_dtype::E5M2; + } else { + throw std::runtime_error("invalid value for build_search smem_dtype: " + s); + } + } + } } }, params.graph_build_params); diff --git a/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp index 0b42d79379..67cdd38783 100644 --- a/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp +++ b/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp @@ -16,6 +16,8 @@ struct tag_metric_cosine {}; struct tag_metric_hamming {}; struct tag_codebook_none {}; struct tag_codebook_half {}; +struct tag_smem_f16 {}; +struct tag_smem_e5m2 {}; struct tag_metric_l1 {}; struct tag_norm_noop {}; struct tag_norm_cosine {}; @@ -33,7 +35,8 @@ template + uint32_t PqLen, + typename SmemTag> struct fragment_tag_setup_workspace {}; template + uint32_t PqLen, + typename SmemTag> struct fragment_tag_compute_distance {}; template diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 471807c0c7..1e6e87b2e5 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -31,10 +31,172 @@ #include #include +namespace CUVS_EXPORT cuvs { +namespace neighbors { +namespace cagra { + +/** + * @defgroup cagra_cpp_search_params CAGRA index search parameters + * @{ + */ + +enum class search_algo { + /** For large batch sizes. */ + SINGLE_CTA = 0, + /** For small batch sizes. */ + MULTI_CTA = 1, + MULTI_KERNEL = 2, + AUTO = 100 +}; + +enum class hash_mode { HASH = 0, SMALL = 1, AUTO = 100 }; + +enum class internal_dtype { F16 = 0, E5M2 = 1 }; + +struct search_params : cuvs::neighbors::search_params { + /** Maximum number of queries to search at the same time (batch size). Auto select when 0.*/ + size_t max_queries = 0; + + /** Number of intermediate search results retained during the search. + * + * This is the main knob to adjust trade off between accuracy and search speed. + * Higher values improve the search accuracy. + */ + size_t itopk_size = 64; + + /** Upper limit of search iterations. Auto select when 0.*/ + size_t max_iterations = 0; + + // In the following we list additional search parameters for fine tuning. + // Reasonable default values are automatically chosen. + + /** Which search implementation to use. */ + search_algo algo = search_algo::AUTO; + + /** Number of threads used to calculate a single distance. 4, 8, 16, or 32. */ + size_t team_size = 0; + + /** Number of graph nodes to select as the starting point for the search in each iteration. aka + * search width?*/ + size_t search_width = 1; + /** Lower limit of search iterations. */ + size_t min_iterations = 0; + + /** Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0. */ + size_t thread_block_size = 0; + /** Hashmap type. Auto selection when AUTO. */ + hash_mode hashmap_mode = hash_mode::AUTO; + /** Lower limit of hashmap bit length. More than 8. */ + size_t hashmap_min_bitlen = 0; + /** Upper limit of hashmap fill rate. More than 0.1, less than 0.9.*/ + float hashmap_max_fill_rate = 0.5; + + /** Number of iterations of initial random seed node selection. 1 or more. */ + uint32_t num_random_samplings = 1; + /** Bit mask used for initial random seed node selection. */ + uint64_t rand_xor_mask = 0x128394; + + /** Whether to use the persistent version of the kernel (only SINGLE_CTA is supported a.t.m.) */ + bool persistent = false; + /** Persistent kernel: time in seconds before the kernel stops if no requests received. */ + float persistent_lifetime = 2; + /** + * Set the fraction of maximum grid size used by persistent kernel. + * Value 1.0 means the kernel grid size is maximum possible for the selected device. + * The value must be greater than 0.0 and not greater than 1.0. + * + * One may need to run other kernels alongside this persistent kernel. This parameter can + * be used to reduce the grid size of the persistent kernel to leave a few SMs idle. + * Note: running any other work on GPU alongside with the persistent kernel makes the setup + * fragile. + * - Running another kernel in another thread usually works, but no progress guaranteed + * - Any CUDA allocations block the context (this issue may be obscured by using pools) + * - Memory copies to not-pinned host memory may block the context + * + * Even when we know there are no other kernels working at the same time, setting + * kDeviceUsage to 1.0 surprisingly sometimes hurts performance. Proceed with care. + * If you suspect this is an issue, you can reduce this number to ~0.9 without a significant + * impact on the throughput. + */ + float persistent_device_usage = 1.0; + + /** + * A parameter indicating the rate of nodes to be filtered-out, when filtering is used. + * The value must be equal to or greater than 0.0 and less than 1.0. Default value is + * negative, in which case the filtering rate is automatically calculated when possible. + * For `filtering::udf_filter`, CAGRA uses `udf_filter::filtering_rate` when this value is + * negative. If both values are negative, CAGRA assumes 0.0 because a UDF's selectivity cannot be + * inferred from the source string. + */ + float filtering_rate = -1.0; + + /** Data type of the query vector and codebook table on shared memory. Currently, only VPQ + * supports FP8. **/ + internal_dtype smem_dtype = internal_dtype::F16; +}; + +/** + * @} + */ + +} // namespace cagra +} // namespace neighbors +} // namespace CUVS_EXPORT cuvs + namespace CUVS_EXPORT cuvs { namespace neighbors { namespace graph_build_params { -using iterative_search_params = cuvs::neighbors::search_params; +/** + * Parameters for the iterative CAGRA graph build algorithm. + * + * Inherits from cagra::search_params so that all search tuning knobs + * (search_width, max_iterations, itopk_size, etc.) are available for + * controlling the search-and-optimize loop during graph construction. + * The defaults are tuned for the build loop (e.g. search_width=1, + * max_iterations=8) and may differ from the regular search defaults. + * + * `build_compression` controls the VPQ parameters applied to the dataset + * *while building the graph*. This is independent of `index_params::compression`, + * which controls the compression of the dataset stored in the final index. + */ +struct iterative_search_params : cuvs::neighbors::cagra::search_params { + /** + * Optional VPQ compression parameters used during iterative graph construction. + * + * When set, the dataset is compressed with these parameters for the + * search-and-optimize loop. When std::nullopt (default), the builder + * falls back to `index_params::compression` (original behaviour). + */ + std::optional build_compression = std::nullopt; + + /** + * Whether to shuffle the dataset before building the graph. + * + * When enabled, the compressed dataset is randomly permuted before graph + * construction begins. This can improve graph quality by breaking any + * spatial locality in the original dataset ordering that might cause + * the iterative builder to get stuck in local optima during early + * iterations. + * + * After graph construction, the node indices in the graph are remapped + * back to the original dataset ordering. + * + * Only applies when compression is enabled (build_compression or + * index_params::compression is set). + */ + bool shuffle_dataset = true; + + iterative_search_params() + { + this->search_width = 1; + this->max_iterations = 8; + // itopk_size controls the search during the *growing* iterations of the build loop. + // 0 (default) means auto-select per iteration (max(graph_degree + 32, 128)); a nonzero + // value overrides it for the growing iterations. The final iteration always uses a fixed + // itopk tied to the output topk, regardless of this value. + this->itopk_size = 0; + } +}; /** Specialized parameters for ACE (Augmented Core Extraction) graph build */ struct ace_params { @@ -191,6 +353,14 @@ struct index_params : cuvs::neighbors::index_params { */ bool guarantee_connectivity = false; + /** + * Whether to skip graph optimization (pruning, reverse edges, MST) during non-final iterations + * of iterative graph building. When true, search results are copied directly into the device + * graph without host round-trips. Only applies to iterative_search_params graph builds; the + * final iteration always runs full optimization. + */ + bool skip_graph_optimization = false; + /** * Whether to add the dataset content to the index, i.e.: * @@ -256,104 +426,6 @@ struct index_params : cuvs::neighbors::index_params { cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded); }; -/** - * @} - */ - -/** - * @defgroup cagra_cpp_search_params CAGRA index search parameters - * @{ - */ - -enum class search_algo { - /** For large batch sizes. */ - SINGLE_CTA = 0, - /** For small batch sizes. */ - MULTI_CTA = 1, - MULTI_KERNEL = 2, - AUTO = 100 -}; - -enum class hash_mode { HASH = 0, SMALL = 1, AUTO = 100 }; - -struct search_params : cuvs::neighbors::search_params { - /** Maximum number of queries to search at the same time (batch size). Auto select when 0.*/ - size_t max_queries = 0; - - /** Number of intermediate search results retained during the search. - * - * This is the main knob to adjust trade off between accuracy and search speed. - * Higher values improve the search accuracy. - */ - size_t itopk_size = 64; - - /** Upper limit of search iterations. Auto select when 0.*/ - size_t max_iterations = 0; - - // In the following we list additional search parameters for fine tuning. - // Reasonable default values are automatically chosen. - - /** Which search implementation to use. */ - search_algo algo = search_algo::AUTO; - - /** Number of threads used to calculate a single distance. 4, 8, 16, or 32. */ - size_t team_size = 0; - - /** Number of graph nodes to select as the starting point for the search in each iteration. aka - * search width?*/ - size_t search_width = 1; - /** Lower limit of search iterations. */ - size_t min_iterations = 0; - - /** Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0. */ - size_t thread_block_size = 0; - /** Hashmap type. Auto selection when AUTO. */ - hash_mode hashmap_mode = hash_mode::AUTO; - /** Lower limit of hashmap bit length. More than 8. */ - size_t hashmap_min_bitlen = 0; - /** Upper limit of hashmap fill rate. More than 0.1, less than 0.9.*/ - float hashmap_max_fill_rate = 0.5; - - /** Number of iterations of initial random seed node selection. 1 or more. */ - uint32_t num_random_samplings = 1; - /** Bit mask used for initial random seed node selection. */ - uint64_t rand_xor_mask = 0x128394; - - /** Whether to use the persistent version of the kernel (only SINGLE_CTA is supported a.t.m.) */ - bool persistent = false; - /** Persistent kernel: time in seconds before the kernel stops if no requests received. */ - float persistent_lifetime = 2; - /** - * Set the fraction of maximum grid size used by persistent kernel. - * Value 1.0 means the kernel grid size is maximum possible for the selected device. - * The value must be greater than 0.0 and not greater than 1.0. - * - * One may need to run other kernels alongside this persistent kernel. This parameter can - * be used to reduce the grid size of the persistent kernel to leave a few SMs idle. - * Note: running any other work on GPU alongside with the persistent kernel makes the setup - * fragile. - * - Running another kernel in another thread usually works, but no progress guaranteed - * - Any CUDA allocations block the context (this issue may be obscured by using pools) - * - Memory copies to not-pinned host memory may block the context - * - * Even when we know there are no other kernels working at the same time, setting - * kDeviceUsage to 1.0 surprisingly sometimes hurts performance. Proceed with care. - * If you suspect this is an issue, you can reduce this number to ~0.9 without a significant - * impact on the throughput. - */ - float persistent_device_usage = 1.0; - - /** - * A parameter indicating the rate of nodes to be filtered-out, when filtering is used. - * The value must be equal to or greater than 0.0 and less than 1.0. Default value is - * negative, in which case the filtering rate is automatically calculated when possible. - * For `filtering::udf_filter`, CAGRA uses `udf_filter::filtering_rate` when this value is - * negative. If both values are negative, CAGRA assumes 0.0 because a UDF's selectivity cannot be - * inferred from the source string. - */ - float filtering_rate = -1.0; -}; - /** * @} */ diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 2fd804f115..1e5ca5a159 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -98,6 +98,18 @@ struct vpq_params { * The max number of data points to use per VQ cluster during training. */ uint32_t max_train_points_per_vq_cluster = 1024; + + friend bool operator==(const vpq_params& a, const vpq_params& b) + { + return a.pq_bits == b.pq_bits && a.pq_dim == b.pq_dim && a.vq_n_centers == b.vq_n_centers && + a.kmeans_n_iters == b.kmeans_n_iters && + a.vq_kmeans_trainset_fraction == b.vq_kmeans_trainset_fraction && + a.pq_kmeans_trainset_fraction == b.pq_kmeans_trainset_fraction && + a.pq_kmeans_type == b.pq_kmeans_type && + a.max_train_points_per_pq_code == b.max_train_points_per_pq_code && + a.max_train_points_per_vq_cluster == b.max_train_points_per_vq_cluster; + } + friend bool operator!=(const vpq_params& a, const vpq_params& b) { return !(a == b); } }; /** @} */ // end group cagra_cpp_index_params diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index 96ff8344d3..10e44377e3 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -17,8 +17,17 @@ #include #include #include +#include +#include +#include +#include #include +#include +#include +#include +#include + #include #include #include @@ -51,6 +60,32 @@ namespace cuvs::neighbors::cagra::detail { constexpr double to_mib(size_t bytes) { return static_cast(bytes) / (1 << 20); } constexpr double to_gib(size_t bytes) { return static_cast(bytes) / (1 << 30); } +// Functor to remap indices using a permutation lookup table +template +struct remap_indices_op { + const IdxT* perm; + __host__ __device__ IdxT operator()(IdxT idx) const { return perm[idx]; } +}; + +// Functor to compute scattered output index for graph row reordering +template +struct graph_scatter_index_op { + const IdxT* perm; + int64_t degree; + __host__ __device__ int64_t operator()(int64_t idx) const + { + int64_t row = idx / degree; + int64_t col = idx % degree; + return static_cast(perm[row]) * degree + col; + } +}; + +// Functor to convert int64_t to IdxT +template +struct cast_to_idx_op { + __host__ __device__ IdxT operator()(int64_t v) const { return static_cast(v); } +}; + template void check_graph_degree(size_t& intermediate_degree, size_t& graph_degree, size_t dataset_size) { @@ -1977,6 +2012,110 @@ struct mmap_owner { size_t size_; }; +template +__global__ void kern_reconstruct_vpq_queries(const uint8_t* encoded_data, + uint32_t encoded_row_len, + const MathT* vq_codebook, + const MathT* pq_codebook, + uint32_t dim, + uint32_t pq_len, + uint64_t offset, + uint32_t batch_size, + T* output) +{ + const uint64_t batch_idx = blockIdx.x; + if (batch_idx >= batch_size) return; + const uint64_t vec_idx = offset + batch_idx; + const uint8_t* vec_data = encoded_data + vec_idx * encoded_row_len; + const uint32_t vq_code = *reinterpret_cast(vec_data); + const uint8_t* pq_codes = vec_data + sizeof(uint32_t); + const MathT* vq_centroid_ptr = vq_codebook + static_cast(vq_code) * dim; + + for (uint32_t d = threadIdx.x; d < dim; d += blockDim.x) { + uint32_t j = d / pq_len; + uint32_t k = d % pq_len; + float val = static_cast(vq_centroid_ptr[d]) + + static_cast(pq_codebook[static_cast(pq_codes[j]) * pq_len + k]); + output[batch_idx * dim + d] = static_cast(val); + } +} + +template +void reconstruct_vpq_queries(raft::resources const& res, + const vpq_dataset& vpq_dset, + uint64_t offset, + uint32_t batch_size, + raft::device_matrix_view output) +{ + const uint32_t dim = vpq_dset.dim(); + const uint32_t pq_len = vpq_dset.pq_len(); + const uint32_t threads = std::min(dim, 256u); + + kern_reconstruct_vpq_queries + <<>>( + vpq_dset.data.data_handle(), + vpq_dset.encoded_row_length(), + vpq_dset.vq_code_book.data_handle(), + vpq_dset.pq_code_book.data_handle(), + dim, + pq_len, + offset, + batch_size, + output.data_handle()); +} + +template +void search_and_optimize(raft::resources const& res, + const cuvs::neighbors::cagra::search_params& search_params, + const index& idx, + raft::device_matrix_view dev_query_view, + raft::device_matrix_view dev_neighbors, + raft::device_matrix_view dev_distances, + raft::device_matrix& dev_output_graph, + size_t curr_query_size, + size_t next_graph_degree, + size_t curr_topk, + uint64_t max_chunk_size) +{ + auto stream = raft::resource::get_cuda_stream(res); + + auto dev_knn_graph = raft::make_device_matrix(res, curr_query_size, curr_topk); + + auto query_batch = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + res, + dev_query_view.data_handle(), + static_cast(curr_query_size), + static_cast(dev_query_view.extent(1)), + max_chunk_size, + stream, + raft::resource::get_workspace_resource_ref(res)); + for (const auto& batch : query_batch) { + auto batch_dev_query_view = raft::make_device_matrix_view( + batch.data(), batch.size(), dev_query_view.extent(1)); + auto batch_dev_neighbors_view = raft::make_device_matrix_view( + dev_neighbors.data_handle(), batch.size(), curr_topk); + auto batch_dev_distances_view = raft::make_device_matrix_view( + dev_distances.data_handle(), batch.size(), curr_topk); + + cuvs::neighbors::cagra::search(res, + search_params, + idx, + batch_dev_query_view, + batch_dev_neighbors_view, + batch_dev_distances_view); + + raft::copy(dev_knn_graph.data_handle() + batch.offset() * curr_topk, + batch_dev_neighbors_view.data_handle(), + batch.size() * curr_topk, + stream); + } + + dev_output_graph = + raft::make_device_matrix(res, curr_query_size, next_graph_degree); + + graph::optimize(res, dev_knn_graph.view(), dev_output_graph.view(), false); +} + template (params.graph_build_params); + const auto& build_compression = + iter_params.build_compression.has_value() ? iter_params.build_compression : params.compression; + + if (build_compression.has_value()) { + const auto& bc = *build_compression; + RAFT_LOG_INFO( + "Build compression params: pq_bits=%u, pq_dim=%u, vq_n_centers=%u, kmeans_n_iters=%u, " + "vq_kmeans_trainset_fraction=%.4f, pq_kmeans_trainset_fraction=%.4f, " + "max_train_points_per_pq_code=%u, max_train_points_per_vq_cluster=%u%s", + bc.pq_bits, + bc.pq_dim, + bc.vq_n_centers, + bc.kmeans_n_iters, + bc.vq_kmeans_trainset_fraction, + bc.pq_kmeans_trainset_fraction, + bc.max_train_points_per_pq_code, + bc.max_train_points_per_vq_cluster, + iter_params.build_compression.has_value() ? " (from build_compression)" + : " (from compression)"); + } else { + RAFT_LOG_INFO("Build compression: disabled (uncompressed build)"); + } + RAFT_LOG_INFO("Build search params: search_width=%zu, max_iterations=%zu", + iter_params.search_width, + iter_params.max_iterations); + auto cagra_graph = raft::make_host_matrix(0, 0); // Iteratively improve the accuracy of the graph by repeatedly running @@ -2039,6 +2209,17 @@ auto iterative_build_graph( RAFT_LOG_DEBUG("# graph_degree = %lu", (uint64_t)graph_degree); RAFT_LOG_DEBUG("# topk = %lu", (uint64_t)topk); + // A fixed itopk_size (0 = auto) governs the growing iterations, which build graphs of degree + // ~graph_degree/2 and thus request topk ~= graph_degree/2 + 1; the search planner requires + // topk <= itopk_size. (The full-size iterations override itopk internally, so they are not + // constrained by this value.) + RAFT_EXPECTS(iter_params.itopk_size == 0 || + iter_params.itopk_size >= graph_degree / 2 + 1, + "iterative build search itopk_size (%zu) must be 0 (auto) or >= " + "graph_degree / 2 + 1 (%zu)", + (size_t)iter_params.itopk_size, + (size_t)(graph_degree / 2 + 1)); + // Create an initial graph. The initial graph created here is not suitable for // searching, but connectivity is guaranteed. auto offset = raft::make_host_vector(small_graph_degree); @@ -2059,28 +2240,131 @@ auto iterative_build_graph( } } - // Allocate memory for neighbors list using Transparent HugePage - constexpr size_t thp_size = 2 * 1024 * 1024; - size_t byte_size = sizeof(IdxT) * final_graph_size * topk; - if (byte_size % thp_size) { byte_size += thp_size - (byte_size % thp_size); } - mmap_owner neighbors_list(byte_size); - IdxT* neighbors_ptr = (IdxT*)neighbors_list.data(); - memset(neighbors_ptr, 0, byte_size); - bool flag_last = false; auto curr_graph_size = initial_graph_size; + + auto dev_graph = raft::make_device_matrix(res, 0, 0); + bool use_device_graph = false; + + // Generate the compressed index once if compression is enabled + const uint64_t dataset_dim = dev_dataset.extent(1); + std::optional> idx_opt; + + // Optional shuffle permutation for randomizing dataset order during build. + // inverse_perm[shuffled_idx] = original_idx + // perm[shuffled_idx] = original_idx, used to unshuffle the graph after build + auto dev_perm = raft::make_device_vector(res, 0); + bool dataset_shuffled = false; + + // Warn if shuffle is requested but compression is not enabled + if (iter_params.shuffle_dataset && !build_compression.has_value()) { + RAFT_LOG_WARN("shuffle_dataset is only supported with compression enabled; ignoring"); + } + + if (build_compression.has_value()) { + auto start = std::chrono::high_resolution_clock::now(); + RAFT_EXPECTS(params.metric == cuvs::distance::DistanceType::L2Expanded, + "VPQ compression is only supported with L2Expanded distance mertric"); + + // Build the VPQ compressed dataset + auto vpq_dset = + cuvs::preprocessing::quantize::pq::vpq_build(res, *build_compression, dev_dataset); + + // Optionally shuffle the compressed dataset to break spatial locality + if (iter_params.shuffle_dataset) { + auto shuffle_start = std::chrono::high_resolution_clock::now(); + RAFT_LOG_INFO("Shuffling compressed dataset to randomize build order..."); + + auto stream = raft::resource::get_cuda_stream(res); + const auto n_rows = vpq_dset.data.extent(0); + const auto row_len = vpq_dset.data.extent(1); + + // Generate random permutation: perm[i] = source index for output row i + // i.e., shuffled_data[i] = original_data[perm[i]] + // So perm maps: shuffled_idx -> original_idx + // Use int64_t for permutation to match vpq_dataset's index type + auto dev_perm_i64 = raft::make_device_vector(res, n_rows); + + // Use legacy permute API to generate permutation indices only (out=nullptr, in=nullptr) + // This just fills dev_perm_i64 with a random permutation of [0, n_rows) + raft::random::permute(dev_perm_i64.data_handle(), + static_cast(nullptr), + static_cast(nullptr), + static_cast(row_len), + static_cast(n_rows), + true, + stream); + + // Apply permutation to VPQ data: shuffled_data[i] = original_data[perm[i]]. + // NOTE: use an out-of-place device gather into a temporary buffer rather than the + // in-place gather overload. The in-place overload uses a host-orchestrated, + // double-buffered, multi-stream path that races here and triggers an asynchronous + // illegal memory access (the crash disappears under CUDA_LAUNCH_BLOCKING=1). + auto shuffled_data = raft::make_device_matrix( + res, vpq_dset.data.extent(0), vpq_dset.data.extent(1)); + raft::matrix::gather(res, + raft::make_const_mdspan(vpq_dset.data.view()), + raft::make_const_mdspan(dev_perm_i64.view()), + shuffled_data.view()); + vpq_dset.data = std::move(shuffled_data); + + // Store perm as IdxT for graph unshuffling later + // perm[shuffled_idx] = original_idx + // This is used for: + // 1. Remapping neighbor values: neighbor j (shuffled) -> perm[j] (original) + // 2. Reordering rows: row i (for shuffled node i) -> position perm[i] (original node) + dev_perm = raft::make_device_vector(res, n_rows); + cast_to_idx_op cast_op; + thrust::transform(raft::resource::get_thrust_policy(res), + dev_perm_i64.data_handle(), + dev_perm_i64.data_handle() + n_rows, + dev_perm.data_handle(), + cast_op); + + dataset_shuffled = true; + + auto shuffle_end = std::chrono::high_resolution_clock::now(); + auto shuffle_ms = + std::chrono::duration_cast(shuffle_end - shuffle_start).count(); + RAFT_LOG_INFO("# Dataset shuffle time: %.3lf sec", (double)shuffle_ms / 1000); + } + + idx_opt.emplace(res, params.metric); + // Use the (optionally shuffled) compressed dataset built above. + idx_opt->update_dataset(res, std::move(vpq_dset)); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed_ms = std::chrono::duration_cast(end - start).count(); + RAFT_LOG_INFO("# VPQ compression time: %.3lf sec", (double)elapsed_ms / 1000); + + // Free the original dataset -- queries will be reconstructed from VPQ codes. + dev_aligned_dataset.reset(); + RAFT_LOG_INFO( + "# Freed original dataset from device (%.1f MiB); queries will use VPQ reconstruction", + to_mib(final_graph_size * dataset_dim * sizeof(T))); + } while (true) { auto start = std::chrono::high_resolution_clock::now(); auto curr_query_size = std::min(2 * curr_graph_size, final_graph_size); auto next_graph_degree = small_graph_degree; if (curr_graph_size == final_graph_size) { next_graph_degree = graph_degree; } + RAFT_LOG_INFO("Current graph size %lu: # current graph degree = %lu", + (uint64_t)curr_graph_size, + (uint64_t)next_graph_degree); // The search count (topk) is set to the next graph degree + 1, because // pruning is not used except in the last iteration. // (*) The appropriate setting for itopk_size requires careful consideration. - auto curr_topk = next_graph_degree + 1; - auto curr_itopk_size = next_graph_degree + 32; + auto curr_topk = next_graph_degree + 1; + // The configurable itopk (iter_params.itopk_size, 0 = auto) applies only to the true growing + // iterations, where the degree being built is small_graph_degree. When the graph reaches its + // full size the search builds a graph_degree-degree graph (topk = graph_degree + 1); that + // iteration needs a larger itopk, so it overrides the configured value with the auto formula. + // The final iteration (flag_last) uses a fixed itopk tied to the output topk. + auto curr_itopk_size = + (iter_params.itopk_size > 0 && next_graph_degree == small_graph_degree) + ? (uint64_t)iter_params.itopk_size + : std::max(next_graph_degree + 32, (uint64_t)128); if (flag_last) { curr_topk = topk; curr_itopk_size = curr_topk + 32; @@ -2095,71 +2379,135 @@ auto iterative_build_graph( (uint64_t)curr_itopk_size, (uint64_t)curr_topk); - cuvs::neighbors::cagra::search_params search_params; - search_params.algo = cuvs::neighbors::cagra::search_algo::AUTO; - search_params.max_queries = max_chunk_size; - search_params.itopk_size = curr_itopk_size; - - // Create an index (idx), a query view (dev_query_view), and a mdarray for - // search results (neighbors). - auto dev_dataset_view = raft::make_device_matrix_view( - dev_dataset.data_handle(), (int64_t)curr_graph_size, dev_dataset.extent(1)); - - auto idx = index( - res, params.metric, dev_dataset_view, raft::make_const_mdspan(cagra_graph.view())); - - auto dev_query_view = raft::make_device_matrix_view( - dev_dataset.data_handle(), (int64_t)curr_query_size, dev_dataset.extent(1)); - - auto neighbors_view = - raft::make_host_matrix_view(neighbors_ptr, curr_query_size, curr_topk); - - // Search. - // Since there are many queries, divide them into batches and search them. - auto query_batch = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( - res, - dev_query_view.data_handle(), - static_cast(curr_query_size), - static_cast(dev_query_view.extent(1)), - max_chunk_size, - raft::resource::get_cuda_stream(res), - raft::resource::get_workspace_resource_ref(res)); - for (const auto& batch : query_batch) { - auto batch_dev_query_view = raft::make_device_matrix_view( - batch.data(), batch.size(), dev_query_view.extent(1)); - auto batch_dev_neighbors_view = raft::make_device_matrix_view( - dev_neighbors.data_handle(), batch.size(), curr_topk); - auto batch_dev_distances_view = raft::make_device_matrix_view( - dev_distances.data_handle(), batch.size(), curr_topk); - - cuvs::neighbors::cagra::search(res, - search_params, - idx, - batch_dev_query_view, - batch_dev_neighbors_view, - batch_dev_distances_view); - - auto batch_neighbors_view = raft::make_host_matrix_view( - neighbors_view.data_handle() + batch.offset() * curr_topk, batch.size(), curr_topk); - raft::copy(res, batch_neighbors_view, batch_dev_neighbors_view); + cuvs::neighbors::cagra::search_params search_params = iter_params; + search_params.max_queries = max_chunk_size; + search_params.itopk_size = curr_itopk_size; + + // Create index and query views. + if (!build_compression.has_value()) { + auto dev_dataset_view = raft::make_device_matrix_view( + dev_dataset.data_handle(), (int64_t)curr_graph_size, dev_dataset.extent(1)); + if (use_device_graph) { + idx_opt.emplace( + res, params.metric, dev_dataset_view, raft::make_const_mdspan(dev_graph.view())); + } else { + idx_opt.emplace( + res, params.metric, dev_dataset_view, raft::make_const_mdspan(cagra_graph.view())); + } + } else { + if (use_device_graph) { + idx_opt->update_graph(res, raft::make_const_mdspan(dev_graph.view())); + } else { + idx_opt->update_graph(res, raft::make_const_mdspan(cagra_graph.view())); + } } - - // Optimize graph - auto next_graph_size = curr_query_size; - cagra_graph = raft::make_host_matrix(0, 0); // delete existing grahp - cagra_graph = raft::make_host_matrix(next_graph_size, next_graph_degree); - optimize( - res, neighbors_view, cagra_graph.view(), flag_last ? params.guarantee_connectivity : 0); + const auto& idx = *idx_opt; + + // When compression is enabled, reconstruct queries from VPQ codes instead of + // reading from the (freed) original dataset. + auto dev_reconstructed_queries = + build_compression.has_value() + ? raft::make_device_matrix(res, curr_query_size, dataset_dim) + : raft::make_device_matrix(res, 0, 0); + if (build_compression.has_value()) { + auto* vpq_dset = dynamic_cast*>(&idx.data()); + RAFT_EXPECTS(vpq_dset != nullptr, "Expected VPQ dataset in compressed index"); + reconstruct_vpq_queries( + res, *vpq_dset, 0, curr_query_size, dev_reconstructed_queries.view()); + } + auto dev_query_view = + build_compression.has_value() + ? raft::make_device_matrix_view( + dev_reconstructed_queries.data_handle(), (int64_t)curr_query_size, dataset_dim) + : raft::make_device_matrix_view( + dev_dataset.data_handle(), (int64_t)curr_query_size, dev_dataset.extent(1)); + + auto dev_optimized_graph = raft::make_device_matrix(res, 0, 0); + + search_and_optimize(res, + search_params, + idx, + dev_query_view, + dev_neighbors.view(), + dev_distances.view(), + dev_optimized_graph, + curr_query_size, + next_graph_degree, + curr_topk, + max_chunk_size); + + dev_graph = std::move(dev_optimized_graph); + use_device_graph = true; auto end = std::chrono::high_resolution_clock::now(); auto elapsed_ms = std::chrono::duration_cast(end - start).count(); RAFT_LOG_DEBUG("# elapsed time: %.3lf sec", (double)elapsed_ms / 1000); if (flag_last) { break; } - flag_last = (curr_graph_size == final_graph_size); - curr_graph_size = next_graph_size; + flag_last = (curr_graph_size == final_graph_size); + auto next_graph_size = curr_query_size; + curr_graph_size = next_graph_size; } + // TODO: when build_compression matches params.compression, the dataset is compressed twice + // (once for the build loop and once in build()'s shared tail). We could avoid this by returning + // the index directly (with its VPQ dataset and device-side graph) instead of just the host graph. + auto stream = raft::resource::get_cuda_stream(res); + + // If the dataset was shuffled, we need to unshuffle the graph: + // Recall: perm[shuffled_idx] = original_idx (stored in dev_perm) + // 1. Remap neighbor indices from shuffled space to original space + // 2. Reorder rows from shuffled order to original order + if (dataset_shuffled) { + auto unshuffle_start = std::chrono::high_resolution_clock::now(); + RAFT_LOG_INFO("Unshuffling graph to restore original dataset ordering..."); + + const auto n_rows = dev_graph.extent(0); + const auto degree = dev_graph.extent(1); + + // Step 1: Remap all neighbor indices using perm + // graph[i][j] contains shuffled index j; we need original index = perm[j] + remap_indices_op remap_op{dev_perm.data_handle()}; + thrust::transform(raft::resource::get_thrust_policy(res), + dev_graph.data_handle(), + dev_graph.data_handle() + n_rows * degree, + dev_graph.data_handle(), + remap_op); + + // Step 2: Reorder rows back to original order + // Row i in dev_graph is for shuffled node i, which is original node perm[i]. + // We want this row to be at position perm[i] in the final graph. + // scatter: output[map[i]] = input[i], so map[i] = perm[i] + auto dev_unshuffled_graph = raft::make_device_matrix(res, n_rows, degree); + + // Use thrust::scatter to reorder: for each row i, place it at position perm[i] + // We scatter row-by-row conceptually, but do it element-wise with computed output indices + graph_scatter_index_op scatter_idx_op{dev_perm.data_handle(), degree}; + auto output_indices = + thrust::make_transform_iterator(thrust::make_counting_iterator(0), scatter_idx_op); + + thrust::scatter(raft::resource::get_thrust_policy(res), + dev_graph.data_handle(), + dev_graph.data_handle() + n_rows * degree, + output_indices, + dev_unshuffled_graph.data_handle()); + + dev_graph = std::move(dev_unshuffled_graph); + + auto unshuffle_end = std::chrono::high_resolution_clock::now(); + auto unshuffle_ms = + std::chrono::duration_cast(unshuffle_end - unshuffle_start) + .count(); + RAFT_LOG_INFO("# Graph unshuffle time: %.3lf sec", (double)unshuffle_ms / 1000); + } + + cagra_graph = raft::make_host_matrix(dev_graph.extent(0), dev_graph.extent(1)); + raft::copy(cagra_graph.data_handle(), + dev_graph.data_handle(), + dev_graph.extent(0) * dev_graph.extent(1), + stream); + raft::resource::sync_stream(res); + return cagra_graph; } diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index bca8d3314d..dc6fc15266 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -73,11 +73,13 @@ void search_main_core( topk, queries.extent(1)); + RAFT_LOG_DEBUG("search_main_core: creating plan with max_node_id=%u", params.max_node_id); using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector::type; std::unique_ptr< search_plan_impl> plan = factory::create( res, params, dataset_desc, queries.extent(1), graph.extent(0), graph.extent(1), topk); + RAFT_LOG_DEBUG("search_main_core: plan created, plan->max_node_id=%u", plan->max_node_id); plan->check(topk); @@ -153,7 +155,12 @@ void search_main(raft::resources const& res, // Dispatch search parameters based on the dataset kind. if (auto* strided_dset = dynamic_cast*>(&index.data()); strided_dset != nullptr) { + if (params.smem_dtype != cuvs::neighbors::cagra::internal_dtype::F16) { + RAFT_LOG_WARN("In this search mode, smem_dtype supports only F16. Set it to F16."); + params.smem_dtype = cuvs::neighbors::cagra::internal_dtype::F16; + } // Search using a plain (strided) row-major dataset + RAFT_LOG_DEBUG("Searching with strided dataset"); RAFT_EXPECTS(index.metric() != cuvs::distance::DistanceType::CosineExpanded || index.dataset_norms().has_value(), "Dataset norms must be provided for CosineExpanded metric"); @@ -180,6 +187,13 @@ void search_main(raft::resources const& res, RAFT_FAIL("FP32 VPQ dataset support is coming soon"); } else if (auto* vpq_dset = dynamic_cast*>(&index.data()); vpq_dset != nullptr) { + RAFT_LOG_DEBUG("Searching with VPQ dataset"); + if (params.smem_dtype == cuvs::neighbors::cagra::internal_dtype::E5M2 && + raft::getComputeCapability().first < 9) { + RAFT_LOG_WARN( + "CAGRA VPQ E5M2 smem_dtype requires native FP8 support on SM90+. Falling back to F16."); + params.smem_dtype = cuvs::neighbors::cagra::internal_dtype::F16; + } auto desc = dataset_descriptor_init_with_cache( res, params, *vpq_dset, index.metric(), nullptr); search_main_core( diff --git a/cpp/src/neighbors/detail/cagra/compute_distance.hpp b/cpp/src/neighbors/detail/cagra/compute_distance.hpp index 75b56860bb..7957ca9123 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance.hpp @@ -202,11 +202,12 @@ struct dataset_descriptor_host { uint32_t team_size = 0; // JIT LTO metadata - stored when descriptor is created - cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded; - uint32_t dataset_block_dim = 0; - bool is_vpq = false; - uint32_t pq_bits = 0; - uint32_t pq_len = 0; + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded; + uint32_t dataset_block_dim = 0; + bool is_vpq = false; + uint32_t pq_bits = 0; + uint32_t pq_len = 0; + cuvs::neighbors::cagra::internal_dtype smem_dtype = cuvs::neighbors::cagra::internal_dtype::F16; // Codebook type is determined by DataT for VPQ (always half for now) struct state { @@ -221,11 +222,15 @@ struct dataset_descriptor_host { template state(InitF init, size_t size) : ready{false}, value{std::make_tuple(init, size)} { + // RAFT_LOG_INFO("trying to create a descriptor state %p", + // reinterpret_cast(this)); } ~state() noexcept { if (std::holds_alternative(value)) { + // RAFT_LOG_INFO("trying to free descriptor state %p", + // reinterpret_cast(this)); auto& [ptr, stream] = std::get(value); RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(ptr, stream)); } @@ -258,7 +263,9 @@ struct dataset_descriptor_host { uint32_t dataset_block_dim_val, bool is_vpq_val = false, uint32_t pq_bits_val = 0, - uint32_t pq_len_val = 0) + uint32_t pq_len_val = 0, + cuvs::neighbors::cagra::internal_dtype smem_dtype_val = + cuvs::neighbors::cagra::internal_dtype::F16) : value_{std::make_shared(init, sizeof(DescriptorImpl))}, smem_ws_size_in_bytes{dd_host.smem_ws_size_in_bytes()}, team_size{dd_host.team_size()}, @@ -266,7 +273,8 @@ struct dataset_descriptor_host { dataset_block_dim{dataset_block_dim_val}, is_vpq{is_vpq_val}, pq_bits{pq_bits_val}, - pq_len{pq_len_val} + pq_len{pq_len_val}, + smem_dtype{smem_dtype_val} { } diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh b/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh index 6992ae979a..f73f901f95 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh @@ -6,6 +6,7 @@ #pragma once #include "compute_distance_vpq.hpp" +#include "packed_type.hpp" #include #include @@ -14,6 +15,30 @@ namespace cuvs::neighbors::cagra::detail { +template +struct vpq_smem_value_config; + +template +struct vpq_smem_value_config< + PQ_LEN, + SmemDType, + std::enable_if_t> { + using smem_val_pack_t = half2; + using smem_val_t = half; + using smem_val_pack_uint_t = uint32_t; + static constexpr uint32_t num_packed_elements = 2; +}; + +template +struct vpq_smem_value_config> { + using smem_val_pack_t = device::fp8xN; + using smem_val_t = typename smem_val_pack_t::unit_t; + using smem_val_pack_uint_t = typename smem_val_pack_t::uint_t; + static constexpr uint32_t num_packed_elements = smem_val_pack_t::num_elements; +}; + template + typename QueryT, + cuvs::neighbors::cagra::internal_dtype SmemDType> struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t { using base_type = dataset_descriptor_base_t; using CODE_BOOK_T = CodebookT; @@ -38,6 +64,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t, "Only CODE_BOOK_T = `half` is supported now"); @@ -80,8 +107,11 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t; + static constexpr std::uint32_t kSMemCodeBookSizeInBytes = - (1 << PQ_BITS) * PQ_LEN * utils::size_of(); + (1 << PQ_BITS) * PQ_LEN * utils::size_of() / + smem_val_config::num_packed_elements; _RAFT_HOST_DEVICE cagra_q_dataset_descriptor_t(const std::uint8_t* encoded_dataset_ptr, std::uint32_t encoded_dataset_dim, @@ -108,7 +138,9 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t(dim, DatasetBlockDim) * sizeof(QUERY_T); + raft::round_up_safe(dim, DatasetBlockDim) * + utils::size_of() / + smem_val_config::num_packed_elements; } private: @@ -122,7 +154,8 @@ template + typename DistanceT, + cuvs::neighbors::cagra::internal_dtype SmemDType> RAFT_KERNEL __launch_bounds__(1, 1) vpq_dataset_descriptor_init_kernel(dataset_descriptor_base_t* out, const std::uint8_t* encoded_dataset_ptr, @@ -140,7 +173,8 @@ RAFT_KERNEL __launch_bounds__(1, 1) DataT, IndexT, DistanceT, - half>; + half, + SmemDType>; new (out) desc_type( encoded_dataset_ptr, encoded_dataset_dim, vq_code_book_ptr, pq_code_book_ptr, size, dim); } @@ -153,7 +187,8 @@ template + typename DistanceT, + cuvs::neighbors::cagra::internal_dtype SmemDType> dataset_descriptor_host vpq_descriptor_spec::init_(const cagra::search_params& params, + DistanceT, + SmemDType>::init_(const cagra::search_params& params, const std::uint8_t* encoded_dataset_ptr, uint32_t encoded_dataset_dim, const CodebookT* vq_code_book_ptr, @@ -179,7 +215,8 @@ vpq_descriptor_spec; + half, + SmemDType>; return host_type{ desc_type{ @@ -194,7 +231,8 @@ vpq_descriptor_spec<<<1, 1, 0, stream>>>(dev_ptr, + DistanceT, + SmemDType><<<1, 1, 0, stream>>>(dev_ptr, encoded_dataset_ptr, encoded_dataset_dim, vq_code_book_ptr, @@ -205,9 +243,10 @@ vpq_descriptor_spec +#include #include namespace cuvs::neighbors::cagra::detail { +inline auto select_supported_vpq_smem_dtype(const cagra::search_params& params) + -> cuvs::neighbors::cagra::internal_dtype +{ + if (params.smem_dtype == cuvs::neighbors::cagra::internal_dtype::E5M2 && + raft::getComputeCapability().first < 9) { + return cuvs::neighbors::cagra::internal_dtype::F16; + } + return params.smem_dtype; +} + template + typename DistanceT, + cuvs::neighbors::cagra::internal_dtype SmemDType> struct vpq_descriptor_spec : public instance_spec { using base_type = instance_spec; using typename base_type::data_type; @@ -69,6 +81,13 @@ struct vpq_descriptor_spec : public instance_spec { // Match codebook params if (dataset.pq_bits() != PqBits) { return -1.0; } if (dataset.pq_len() != PqLen) { return -1.0; } + if (select_supported_vpq_smem_dtype(params) != SmemDType) { return -1.0; } + // Keep auto-selection on the tuned VPQ diagonal while allowing explicit team_size requests to + // use the expanded team_size / dataset_block_dim grid. + constexpr std::uint32_t auto_dataset_block_dim_per_team = PqLen == 8 ? 32 : 16; + if (params.team_size == 0 && DatasetBlockDim != TeamSize * auto_dataset_block_dim_per_team) { + return -1.0; + } // Otherwise, favor the closest dataset dimensionality. constexpr std::uint32_t preferred_load_elmes_per_thread = 16; /*magic number that is good based on experiments.*/ diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in index c159da3229..25d4732a34 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in @@ -13,6 +13,7 @@ constexpr uint32_t team_size = @team_size@; constexpr uint32_t dim = @dim@; constexpr uint32_t pq_bits = @pq_bits@; constexpr uint32_t pq_len = @pq_len@; +constexpr auto smem_dtype = @smem_dtype@; using codebook_t = @codebook_type@; using data_t = @data_type@; using index_t = @index_type@; @@ -30,6 +31,7 @@ template struct vpq_descriptor_spec; + distance_t, + smem_dtype>; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json index cf6e060d33..1241b2346c 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json @@ -36,15 +36,53 @@ "_mxdim_team": [ { "dim": "128", - "team_size": "8" + "team_size": "8", + "pq_len": "2" }, { "dim": "256", - "team_size": "16" + "team_size": "16", + "pq_len": "2" }, { "dim": "512", - "team_size": "32" + "team_size": "32", + "pq_len": "2" + }, + { + "dim": "128", + "team_size": "8", + "pq_len": "4" + }, + { + "dim": "256", + "team_size": "16", + "pq_len": "4" + }, + { + "dim": "512", + "team_size": "32", + "pq_len": "4" + }, + { + "dim": "128", + "team_size": "4", + "pq_len": "8" + }, + { + "dim": "256", + "team_size": "8", + "pq_len": "8" + }, + { + "dim": "512", + "team_size": "16", + "pq_len": "8" + }, + { + "dim": "1024", + "team_size": "32", + "pq_len": "8" } ], "_codebook": [ @@ -53,7 +91,20 @@ "codebook_abbrev": "h" } ], - "pq_bits": ["8"], - "pq_len": ["2", "4"], - "metric": ["L2Expanded"] + "pq_bits": [ + "8" + ], + "metric": [ + "L2Expanded" + ], + "_smem": [ + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::E5M2", + "smem_abbrev": "e5m2" + }, + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", + "smem_abbrev": "f16" + } + ] } diff --git a/cpp/src/neighbors/detail/cagra/device_memory_ops.hpp b/cpp/src/neighbors/detail/cagra/device_memory_ops.hpp index cc164994ea..1bcf6f8fbd 100644 --- a/cpp/src/neighbors/detail/cagra/device_memory_ops.hpp +++ b/cpp/src/neighbors/detail/cagra/device_memory_ops.hpp @@ -54,6 +54,11 @@ RAFT_DEVICE_INLINE_FUNCTION void lds(uint32_t& x, uint32_t addr) asm volatile("ld.shared.u32 {%0}, [%1];" : "=r"(x) : "r"(addr)); } +RAFT_DEVICE_INLINE_FUNCTION void lds(uint64_t& x, uint32_t addr) +{ + asm volatile("ld.shared.u64 {%0}, [%1];" : "=l"(x) : "r"(addr)); +} + RAFT_DEVICE_INLINE_FUNCTION void lds(uint32_t& x, const uint32_t* addr) { lds(x, uint32_t(__cvta_generic_to_shared(addr))); @@ -71,6 +76,16 @@ RAFT_DEVICE_INLINE_FUNCTION void lds(uint4& x, const uint4* addr) lds(x, uint32_t(__cvta_generic_to_shared(addr))); } +RAFT_DEVICE_INLINE_FUNCTION void sts(uint32_t addr, const uint32_t& x) +{ + asm volatile("st.shared.u32 [%0], %1;" : : "r"(addr), "r"(reinterpret_cast(x))); +} + +RAFT_DEVICE_INLINE_FUNCTION void sts(uint32_t addr, const uint64_t& x) +{ + asm volatile("st.shared.u64 [%0], %1;" : : "r"(addr), "l"(reinterpret_cast(x))); +} + RAFT_DEVICE_INLINE_FUNCTION void sts(uint32_t addr, const half2& x) { asm volatile("st.shared.v2.u16 [%0], {%1, %2};" diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index 26cd13bab8..a1e2f6be9c 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -87,6 +87,7 @@ struct key { uint32_t extra_val; // this one has different meanings for different descriptor types uint32_t team_size; uint32_t metric; + uint32_t smem_dtype; }; template @@ -100,7 +101,8 @@ auto make_key(const cagra::search_params& params, dataset.dim(), dataset.stride(), uint32_t(params.team_size), - uint32_t(metric)}; + uint32_t(metric), + uint32_t(params.smem_dtype)}; } template @@ -114,20 +116,22 @@ auto make_key(const cagra::search_params& params, dataset.dim(), uint32_t(reinterpret_cast(dataset.pq_code_book.data_handle()) >> 6), uint32_t(params.team_size), - uint32_t(metric)}; + uint32_t(metric), + uint32_t(params.smem_dtype)}; } inline auto operator==(const key& a, const key& b) -> bool { return a.data_ptr == b.data_ptr && a.n_rows == b.n_rows && a.dim == b.dim && - a.extra_val == b.extra_val && a.team_size == b.team_size && a.metric == b.metric; + a.extra_val == b.extra_val && a.team_size == b.team_size && a.metric == b.metric && + a.smem_dtype == b.smem_dtype; } struct key_hash { inline auto operator()(const key& x) const noexcept -> std::size_t { return size_t{x.data_ptr} + size_t{x.n_rows} * size_t{x.dim} * size_t{x.extra_val} + - (size_t{x.team_size} ^ size_t{x.metric}); + (size_t{x.team_size} ^ size_t{x.metric}) + size_t{x.smem_dtype}; } }; diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp index 896817d869..9b7e44bd8e 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp @@ -59,10 +59,14 @@ std::shared_ptr build_single_cta_launcher( persistent); if constexpr (std::is_same_v) { - planner.add_setup_workspace_device_function( - dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); - planner.add_compute_distance_device_function( - dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); + planner.add_setup_workspace_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.pq_len, + dataset_desc.smem_dtype); + planner.add_compute_distance_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.pq_len, + dataset_desc.smem_dtype); } else { planner.add_setup_workspace_device_function(dataset_desc.team_size, dataset_desc.dataset_block_dim); @@ -105,10 +109,14 @@ std::shared_ptr build_multi_cta_launcher( dataset_desc.pq_len); if constexpr (std::is_same_v) { - planner.add_setup_workspace_device_function( - dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); - planner.add_compute_distance_device_function( - dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); + planner.add_setup_workspace_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.pq_len, + dataset_desc.smem_dtype); + planner.add_compute_distance_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.pq_len, + dataset_desc.smem_dtype); } else { planner.add_setup_workspace_device_function(dataset_desc.team_size, dataset_desc.dataset_block_dim); @@ -151,10 +159,14 @@ std::shared_ptr build_multi_kernel_launcher( dataset_desc.pq_bits, dataset_desc.pq_len); if constexpr (std::is_same_v) { - planner.add_setup_workspace_device_function( - dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); - planner.add_compute_distance_device_function( - dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); + planner.add_setup_workspace_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.pq_len, + dataset_desc.smem_dtype); + planner.add_compute_distance_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.pq_len, + dataset_desc.smem_dtype); } else { planner.add_setup_workspace_device_function(dataset_desc.team_size, dataset_desc.dataset_block_dim); diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp index 0c3ed64d13..8539efd004 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp @@ -52,11 +52,13 @@ struct CagraPlannerBase : AlgorithmPlanner { TeamSz, Dim, PqBitsV, - PqLenV>>(); + PqLenV, + tag_smem_f16>>(); }; - dispatch_cagra_team_dim(team_size, dataset_block_dim, [&add]() { - add.template operator()(); - }); + dispatch_cagra_standard_team_dim( + team_size, dataset_block_dim, [&add]() { + add.template operator()(); + }); } /// VPQ (`tag_codebook_half`): JIT matrix fixes `pq_bits=8`; only `pq_len` is selected at runtime. @@ -64,30 +66,35 @@ struct CagraPlannerBase : AlgorithmPlanner { std::enable_if_t, int> = 0> void add_setup_workspace_device_function(uint32_t team_size, uint32_t dataset_block_dim, - uint32_t pq_len) + uint32_t pq_len, + cuvs::neighbors::cagra::internal_dtype smem_dtype) { - if (pq_len != 2 && pq_len != 4) { - RAFT_FAIL("CAGRA JIT VPQ setup_workspace expects pq_len in {2,4} (matrix uses pq_bits=8)"); + if (pq_len != 2 && pq_len != 4 && pq_len != 8) { + RAFT_FAIL("CAGRA JIT VPQ setup_workspace expects pq_len in {2,4,8} (matrix uses pq_bits=8)"); } - auto add = [&]() { - this->add_static_fragment>(); + auto add = + [&]() { + this->add_static_fragment>(); + }; + auto dispatch_smem = [&]() { + dispatch_cagra_vpq_team_dim( + team_size, + dataset_block_dim, + pq_len, + [&add]() { + add.template operator()(); + }); }; - dispatch_cagra_team_dim( - team_size, dataset_block_dim, [&add, pq_len]() { - if (pq_len == 2) { - add.template operator()(); - } else { - add.template operator()(); - } - }); + dispatch_cagra_smem_dtype(smem_dtype, dispatch_smem); } /// Registers dist_op + normalization + `compute_distance` for standard layout. @@ -108,11 +115,13 @@ struct CagraPlannerBase : AlgorithmPlanner { TeamSz, Dim, PqBitsV, - PqLenV>>(); + PqLenV, + tag_smem_f16>>(); }; - dispatch_cagra_team_dim(team_size, dataset_block_dim, [&add]() { - add.template operator()(); - }); + dispatch_cagra_standard_team_dim( + team_size, dataset_block_dim, [&add]() { + add.template operator()(); + }); } /// VPQ: only the `compute_distance` fragment (no standard dist_op / normalization in this path). @@ -120,33 +129,77 @@ struct CagraPlannerBase : AlgorithmPlanner { std::enable_if_t, int> = 0> void add_compute_distance_device_function(uint32_t team_size, uint32_t dataset_block_dim, - uint32_t pq_len) + uint32_t pq_len, + cuvs::neighbors::cagra::internal_dtype smem_dtype) { - if (pq_len != 2 && pq_len != 4) { - RAFT_FAIL("CAGRA JIT VPQ compute_distance expects pq_len in {2,4} (matrix uses pq_bits=8)"); + if (pq_len != 2 && pq_len != 4 && pq_len != 8) { + RAFT_FAIL("CAGRA JIT VPQ compute_distance expects pq_len in {2,4,8} (matrix uses pq_bits=8)"); } - auto add = [&]() { - this->add_static_fragment>(); + auto add = + [&]() { + this->add_static_fragment>(); + }; + auto dispatch_smem = [&]() { + dispatch_cagra_vpq_team_dim( + team_size, + dataset_block_dim, + pq_len, + [&add]() { + add.template operator()(); + }); }; - dispatch_cagra_team_dim( - team_size, dataset_block_dim, [&add, pq_len]() { - if (pq_len == 2) { - add.template operator()(); - } else { - add.template operator()(); - } - }); + dispatch_cagra_smem_dtype(smem_dtype, dispatch_smem); } private: + template + static void dispatch_cagra_smem_dtype(cuvs::neighbors::cagra::internal_dtype smem_dtype, + Lambda&& l) + { + switch (smem_dtype) { + case cuvs::neighbors::cagra::internal_dtype::F16: + std::forward(l).template operator()(); + return; + case cuvs::neighbors::cagra::internal_dtype::E5M2: + std::forward(l).template operator()(); + return; + default: break; + } + RAFT_FAIL("Unsupported CAGRA JIT smem_dtype: %u", static_cast(smem_dtype)); + } + + template + static void dispatch_cagra_vpq_team_dim(uint32_t team_size, + uint32_t dataset_block_dim, + uint32_t pq_len, + Lambda&& l) + { + switch (pq_len) { + case 2: + dispatch_cagra_vpq_pq2_4_team_dim<2u>( + team_size, dataset_block_dim, std::forward(l)); + return; + case 4: + dispatch_cagra_vpq_pq2_4_team_dim<4u>( + team_size, dataset_block_dim, std::forward(l)); + return; + case 8: + dispatch_cagra_vpq_pq8_team_dim(team_size, dataset_block_dim, std::forward(l)); + return; + default: break; + } + RAFT_FAIL("CAGRA JIT VPQ expects pq_len in {2,4,8}; got %u", static_cast(pq_len)); + } + void add_dist_op_device_function(cuvs::distance::DistanceType metric) { // dist_op_matrix.json pairs tag_metric_hamming with uint8 query (tag_u8) only; L2/IP/L1 use @@ -191,15 +244,16 @@ struct CagraPlannerBase : AlgorithmPlanner { uint32_t dataset_block_dim) { auto go = [&]() { - dispatch_cagra_team_dim(team_size, dataset_block_dim, [&]() { - this->add_static_fragment>(); - }); + dispatch_cagra_standard_team_dim( + team_size, dataset_block_dim, [&]() { + this->add_static_fragment>(); + }); }; // tag_u8 is only used for BitwiseHamming query layout; cosine norm fragments are built for // float query tag. Use if constexpr so we do not instantiate tag_norm_cosine with tag_u8 @@ -218,7 +272,9 @@ struct CagraPlannerBase : AlgorithmPlanner { // template parameters; CAGRA reads team_size / dataset_block_dim from the host descriptor at // planning time. template - static void dispatch_cagra_team_dim(uint32_t team_size, uint32_t dataset_block_dim, Lambda&& l) + static void dispatch_cagra_standard_team_dim(uint32_t team_size, + uint32_t dataset_block_dim, + Lambda&& l) { switch (team_size) { case 8: @@ -247,11 +303,88 @@ struct CagraPlannerBase : AlgorithmPlanner { break; default: break; } - RAFT_FAIL("Unsupported team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", + RAFT_FAIL("Unsupported standard team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", static_cast(team_size), static_cast(dataset_block_dim)); } + template + static void dispatch_cagra_vpq_pq2_4_team_dim(uint32_t team_size, + uint32_t dataset_block_dim, + Lambda&& l) + { + switch (team_size) { + case 8: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<8u, 128u, 8u, PqLenV>(); return; + case 256: std::forward(l).template operator()<8u, 256u, 8u, PqLenV>(); return; + case 512: std::forward(l).template operator()<8u, 512u, 8u, PqLenV>(); return; + default: break; + } + break; + case 16: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<16u, 128u, 8u, PqLenV>(); return; + case 256: std::forward(l).template operator()<16u, 256u, 8u, PqLenV>(); return; + case 512: std::forward(l).template operator()<16u, 512u, 8u, PqLenV>(); return; + default: break; + } + break; + case 32: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<32u, 128u, 8u, PqLenV>(); return; + case 256: std::forward(l).template operator()<32u, 256u, 8u, PqLenV>(); return; + case 512: std::forward(l).template operator()<32u, 512u, 8u, PqLenV>(); return; + default: break; + } + break; + default: break; + } + RAFT_FAIL( + "Unsupported VPQ pq_len=%u team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", + static_cast(PqLenV), + static_cast(team_size), + static_cast(dataset_block_dim)); + } + + template + static void dispatch_cagra_vpq_pq8_team_dim(uint32_t team_size, + uint32_t dataset_block_dim, + Lambda&& l) + { + switch (team_size) { + case 4: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<4u, 128u, 8u, 8u>(); return; + default: break; + } + break; + case 8: + switch (dataset_block_dim) { + case 256: std::forward(l).template operator()<8u, 256u, 8u, 8u>(); return; + default: break; + } + break; + case 16: + switch (dataset_block_dim) { + case 512: std::forward(l).template operator()<16u, 512u, 8u, 8u>(); return; + default: break; + } + break; + case 32: + switch (dataset_block_dim) { + case 1024: std::forward(l).template operator()<32u, 1024u, 8u, 8u>(); return; + default: break; + } + break; + default: break; + } + RAFT_FAIL( + "Unsupported VPQ pq_len=8 team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", + static_cast(team_size), + static_cast(dataset_block_dim)); + } + void add_sample_filter_device_function(std::unique_ptr udf_fragment = nullptr) { if constexpr (std::is_same_v) { diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh index 92a014bd2f..a6dd0495fd 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh @@ -100,10 +100,17 @@ _RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker_impl( constexpr auto DatasetBlockDim = DescriptorT::kDatasetBlockDim; constexpr auto PQ_BITS = DescriptorT::kPqBits; constexpr auto PQ_LEN = DescriptorT::kPqLen; + constexpr auto SmemDType = DescriptorT::kSmemDType; + using PQ_CODEBOOK_LOAD_T = uint32_t; + + using smem_val_config = vpq_smem_value_config; + using smem_val_pack_t = typename smem_val_config::smem_val_pack_t; + using smem_val_pack_uint_t = typename smem_val_config::smem_val_pack_uint_t; + constexpr uint32_t num_packed_elements = smem_val_config::num_packed_elements; const uint32_t query_ptr = pq_codebook_ptr + DescriptorT::kSMemCodeBookSizeInBytes; static_assert(PQ_BITS == 8, "Only pq_bits == 8 is supported at the moment."); - constexpr uint32_t vlen = 4; // **** DO NOT CHANGE **** + constexpr uint32_t vlen = utils::size_of() / utils::size_of(); constexpr uint32_t nelem = raft::div_rounding_up_unsafe(DatasetBlockDim / PQ_LEN, TeamSize * vlen); @@ -115,12 +122,12 @@ _RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker_impl( DISTANCE_T norm = 0; for (uint32_t elem_offset = 0; elem_offset * PQ_LEN < dim; elem_offset += DatasetBlockDim / PQ_LEN) { - uint32_t pq_codes[nelem]; + PQ_CODEBOOK_LOAD_T pq_codes[nelem]; #pragma unroll for (std::uint32_t e = 0; e < nelem; e++) { const std::uint32_t k = e * kTeamVLen + elem_offset + laneId * vlen; if (k >= n_subspace) break; - device::ldg_cg(pq_codes[e], reinterpret_cast(dataset_ptr + 4 + k)); + device::ldg_cg(pq_codes[e], reinterpret_cast(dataset_ptr + 4 + k)); } // if constexpr (PQ_LEN % 2 == 0) { @@ -135,24 +142,63 @@ _RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker_impl( if (d >= dim) break; device::ldg_ca(vq_vals[m], vq_code_book_ptr + d); } - std::uint32_t pq_code = pq_codes[e]; + PQ_CODEBOOK_LOAD_T pq_code = pq_codes[e]; #pragma unroll for (std::uint32_t v = 0; v < vlen; v++) { if (PQ_LEN * (v + k) >= dim) break; #pragma unroll - for (std::uint32_t m = 0; m < PQ_LEN / 2; m++) { - constexpr auto kQueryBlock = DatasetBlockDim / (vlen * PQ_LEN); - const std::uint32_t d1 = m + (PQ_LEN / 2) * v; - const std::uint32_t d = - d1 * kQueryBlock + elem_offset * (PQ_LEN / 2) + e * TeamSize + laneId; - half2 q2, c2; - device::lds(q2, query_ptr + sizeof(half2) * d); - device::lds(c2, - pq_codebook_ptr + - sizeof(CODE_BOOK_T) * ((1 << PQ_BITS) * 2 * m + (2 * (pq_code & 0xff)))); - auto dist = q2 - c2 - reinterpret_cast(vq_vals)[d1]; - dist = dist * dist; - norm += static_cast(dist.x + dist.y); + for (std::uint32_t m = 0; m < PQ_LEN / num_packed_elements; m++) { + constexpr uint32_t vq_val_pack_num_elements = 2; + constexpr auto kQueryBlock = DatasetBlockDim / (vlen * PQ_LEN); + std::uint32_t vq_half2_index = + m * (num_packed_elements / vq_val_pack_num_elements) + (PQ_LEN / 2) * v; + + uint32_t query_val_index; + if constexpr (num_packed_elements == 2) { + query_val_index = + vq_half2_index * kQueryBlock + elem_offset * (PQ_LEN / 2) + e * TeamSize + laneId; + } else if constexpr (PQ_LEN == num_packed_elements) { + query_val_index = elem_offset + v * (DatasetBlockDim / (num_packed_elements * vlen)) + + e * TeamSize + laneId; + } else { + const uint32_t query_vec_element_id = + (elem_offset + e * vlen * TeamSize + v + laneId * vlen) * PQ_LEN / + num_packed_elements; + constexpr auto kStride = vlen * PQ_LEN / num_packed_elements; + query_val_index = + transpose(query_vec_element_id); + } + + if constexpr (num_packed_elements == 2) { + smem_val_pack_t q2, c2; + device::lds(q2, query_ptr + sizeof(smem_val_pack_t) * query_val_index); + device::lds(c2, + pq_codebook_ptr + + sizeof(smem_val_pack_uint_t) * ((1 << PQ_BITS) * m + (pq_code & 0xff))); + auto dist = + q2 - c2 - reinterpret_cast(vq_vals)[vq_half2_index]; + dist = dist * dist; + norm += static_cast(dist.x + dist.y); + } else if constexpr (num_packed_elements == 4 || num_packed_elements == 8) { + smem_val_pack_t q_vec, c_vec; + device::lds(q_vec.as_uint(), + query_ptr + sizeof(smem_val_pack_uint_t) * query_val_index); + device::lds(c_vec.as_uint(), + pq_codebook_ptr + + sizeof(smem_val_pack_uint_t) * ((1 << PQ_BITS) * m + (pq_code & 0xff))); + + half2 q2, c2; +#pragma unroll + for (uint32_t bi = 0; bi < num_packed_elements / 2; bi++) { + q2 = q_vec.as_half2(bi); + c2 = c_vec.as_half2(bi); + auto dist = + q2 - c2 - reinterpret_cast(vq_vals)[vq_half2_index]; + dist = dist * dist; + norm += static_cast(dist.x + dist.y); + vq_half2_index += 1; + } + } } pq_code >>= 8; } @@ -219,7 +265,8 @@ template + typename QueryT, + cuvs::neighbors::cagra::internal_dtype SmemDType> __device__ DistanceT compute_distance_impl( const typename dataset_descriptor_base_t::args_t args, IndexT dataset_index) @@ -238,7 +285,8 @@ __device__ DistanceT compute_distance_impl( DataT, IndexT, DistanceT, - QueryT>; + QueryT, + SmemDType>; return compute_distance_vpq_impl(args, dataset_index); } else { static_assert(sizeof(TeamSize) == 0, diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in index 13cd022918..1856781391 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in @@ -11,6 +11,7 @@ constexpr uint32_t k_team_size = @team_size@u; constexpr uint32_t k_dataset_block_dim = @dataset_block_dim@u; constexpr uint32_t k_pq_bits = @pq_bits@u; constexpr uint32_t k_pq_len = @pq_len@u; +constexpr auto k_smem_dtype = @smem_dtype@; using data_t = @data_type@; using index_t = @index_type@; @@ -38,7 +39,8 @@ __device__ distance_t compute_distance(const args_t data_t, index_t, distance_t, - query_t>(args, dataset_index) + query_t, + k_smem_dtype>(args, dataset_index) : distance_t{}; return device::team_sum(per_thread, team_size_bits); } @@ -55,7 +57,8 @@ compute_distance_per_thread(const args_t args, inde data_t, index_t, distance_t, - query_t>(args, dataset_index); + query_t, + k_smem_dtype>(args, dataset_index); } } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json index 82b8dbdf4e..4d260c5507 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json @@ -81,6 +81,12 @@ "codebook_type": "void", "codebook_abbrev": "none" } + ], + "_smem": [ + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", + "smem_abbrev": "f16" + } ] }, { @@ -149,6 +155,96 @@ "codebook_type": "half", "codebook_abbrev": "half" } + ], + "_smem": [ + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", + "smem_abbrev": "f16" + }, + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::E5M2", + "smem_abbrev": "e5m2" + } + ] + }, + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "u8" + }, + { + "data_type": "int8_t", + "data_abbrev": "i8" + } + ], + "_query": [ + { + "query_type": "half", + "query_abbrev": "h" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "u32" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "_pq": [ + { + "pq_len": "8", + "pq_bits": "8", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_8subd" + } + ], + "_codebook": [ + { + "codebook_type": "half", + "codebook_abbrev": "half" + } + ], + "_mxdim_team": [ + { + "dataset_block_dim": "128", + "team_size": "4" + }, + { + "dataset_block_dim": "256", + "team_size": "8" + }, + { + "dataset_block_dim": "512", + "team_size": "16" + }, + { + "dataset_block_dim": "1024", + "team_size": "32" + } + ], + "_smem": [ + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", + "smem_abbrev": "f16" + }, + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::E5M2", + "smem_abbrev": "e5m2" + } ] } ] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh index 8cdd7febd5..494e0973fe 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh @@ -79,12 +79,17 @@ _RAFT_DEVICE __noinline__ auto setup_workspace_vpq_impl( const typename DescriptorT::DATA_T* queries_ptr, uint32_t query_id) -> const DescriptorT* { - using QUERY_T = typename DescriptorT::QUERY_T; - using CODE_BOOK_T = typename DescriptorT::CODE_BOOK_T; - using word_type = uint32_t; - constexpr auto kDatasetBlockDim = DescriptorT::kDatasetBlockDim; - constexpr auto PQ_BITS = DescriptorT::kPqBits; - constexpr auto PQ_LEN = DescriptorT::kPqLen; + using QUERY_T = typename DescriptorT::QUERY_T; + using word_type = uint32_t; + constexpr auto kDatasetBlockDim = DescriptorT::kDatasetBlockDim; + constexpr auto PQ_BITS = DescriptorT::kPqBits; + constexpr auto PQ_LEN = DescriptorT::kPqLen; + constexpr auto SmemDType = DescriptorT::kSmemDType; + using smem_val_config = vpq_smem_value_config; + using smem_val_t = typename smem_val_config::smem_val_t; + using smem_val_pack_t = typename smem_val_config::smem_val_pack_t; + using smem_val_pack_uint_t = typename smem_val_config::smem_val_pack_uint_t; + constexpr auto num_packed_elements = smem_val_config::num_packed_elements; auto* r = reinterpret_cast(smem_ptr); @@ -105,18 +110,32 @@ _RAFT_DEVICE __noinline__ auto setup_workspace_vpq_impl( } __syncthreads(); - for (unsigned i = threadIdx.x * 2; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x * 2) { - half2 buf2; - buf2.x = r->pq_code_book_ptr()[i]; - buf2.y = r->pq_code_book_ptr()[i + 1]; - - constexpr auto num_elements_per_bank = 4 / utils::size_of(); - constexpr auto num_banks_per_subspace = PQ_LEN / num_elements_per_bank; - const auto j = i / num_elements_per_bank; - const auto smem_index = - (j / num_banks_per_subspace) + (j % num_banks_per_subspace) * (1 << PQ_BITS); - - device::sts(codebook_buf + smem_index * sizeof(half2), buf2); + for (unsigned i = threadIdx.x * num_packed_elements; i < (1 << PQ_BITS) * PQ_LEN; + i += blockDim.x * num_packed_elements) { + constexpr auto num_elements_per_bank = + num_packed_elements / (utils::size_of() / utils::size_of()); + + if constexpr (PQ_LEN >= num_elements_per_bank) { + constexpr auto num_banks_per_subspace = PQ_LEN / num_elements_per_bank; + const auto j = i / num_elements_per_bank; + const auto smem_index = + (j / num_banks_per_subspace) + (j % num_banks_per_subspace) * (1 << PQ_BITS); + + if constexpr (num_packed_elements == 2) { + smem_val_pack_t buf; + buf.x = r->pq_code_book_ptr()[i]; + buf.y = r->pq_code_book_ptr()[i + 1]; + device::sts(codebook_buf + smem_index * sizeof(smem_val_pack_t), buf); + } else if constexpr (num_packed_elements == 4 || num_packed_elements == 8) { + smem_val_pack_t buf; +#pragma unroll + for (uint32_t k = 0; k < num_packed_elements; k++) { + buf.data.x1[k] = + static_cast(static_cast(r->pq_code_book_ptr()[i + k])); + } + device::sts(codebook_buf + smem_index * sizeof(smem_val_pack_uint_t), buf.as_uint()); + } + } } } @@ -125,19 +144,33 @@ _RAFT_DEVICE __noinline__ auto setup_workspace_vpq_impl( constexpr cuvs::spatial::knn::detail::utils::mapping mapping{}; auto smem_query_ptr = - reinterpret_cast(reinterpret_cast(smem_ptr) + sizeof(DescriptorT) + - DescriptorT::kSMemCodeBookSizeInBytes); - for (unsigned i = threadIdx.x * 2; i < dim; i += blockDim.x * 2) { - half2 buf2{0, 0}; - if (i < dim) { buf2.x = mapping(queries_ptr[i]); } - if (i + 1 < dim) { buf2.y = mapping(queries_ptr[i + 1]); } - if constexpr ((PQ_BITS == 8) && (PQ_LEN % 2 == 0)) { + reinterpret_cast(reinterpret_cast(smem_ptr) + sizeof(DescriptorT) + + DescriptorT::kSMemCodeBookSizeInBytes); + for (unsigned i = threadIdx.x * num_packed_elements; i < dim; + i += blockDim.x * num_packed_elements) { + smem_val_pack_t buf; + if constexpr (num_packed_elements == 2) { + buf.x = 0; + buf.y = 0; + if (i < dim) { buf.x = mapping(queries_ptr[i]); } + if (i + 1 < dim) { buf.y = mapping(queries_ptr[i + 1]); } + } else if constexpr (num_packed_elements == 4 || num_packed_elements == 8) { +#pragma unroll + for (uint32_t k = 0; k < num_packed_elements; k++) { + buf.data.x1[k] = static_cast(0.0f); + if (i + k < dim) { + buf.data.x1[k] = static_cast(static_cast(mapping(queries_ptr[i + k]))); + } + } + } + if constexpr ((PQ_BITS == 8) && (PQ_LEN % num_packed_elements == 0)) { constexpr uint32_t vlen = 4; // **** DO NOT CHANGE **** - constexpr auto kStride = vlen * PQ_LEN / 2; - reinterpret_cast(smem_query_ptr)[transpose(i / 2)] = - buf2; + constexpr auto kStride = vlen * PQ_LEN / num_packed_elements; + reinterpret_cast( + smem_query_ptr)[transpose( + i / num_packed_elements)] = buf; } else { - (reinterpret_cast(smem_query_ptr + i))[0] = buf2; + (reinterpret_cast(smem_query_ptr + i))[0] = buf; } } @@ -152,7 +185,8 @@ template + typename QueryT, + cuvs::neighbors::cagra::internal_dtype SmemDType> __device__ const dataset_descriptor_base_t* setup_workspace_impl( const dataset_descriptor_base_t* desc_ptr, void* smem, @@ -176,7 +210,8 @@ __device__ const dataset_descriptor_base_t* setup_work DataT, IndexT, DistanceT, - QueryT>; + QueryT, + SmemDType>; const desc_t* desc = static_cast(desc_ptr); const desc_t* result = setup_workspace_vpq_impl(desc, smem, queries, query_id); diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in index fa17705250..6a54c9f956 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in @@ -12,6 +12,7 @@ constexpr uint32_t k_team_size = @team_size@u; constexpr uint32_t k_dataset_block_dim = @dataset_block_dim@u; constexpr uint32_t k_pq_bits = @pq_bits@u; constexpr uint32_t k_pq_len = @pq_len@u; +constexpr auto k_smem_dtype = @smem_dtype@; using data_t = @data_type@; using index_t = @index_type@; @@ -39,7 +40,8 @@ setup_workspace( data_t, index_t, distance_t, - query_t>(desc, smem, queries, query_id); + query_t, + k_smem_dtype>(desc, smem, queries, query_id); } } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json index 83aa8764bc..567fe3e5a1 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json @@ -81,6 +81,12 @@ "codebook_type": "void", "codebook_abbrev": "none" } + ], + "_smem": [ + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", + "smem_abbrev": "f16" + } ] }, { @@ -149,6 +155,96 @@ "codebook_type": "half", "codebook_abbrev": "half" } + ], + "_smem": [ + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", + "smem_abbrev": "f16" + }, + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::E5M2", + "smem_abbrev": "e5m2" + } + ] + }, + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "u8" + }, + { + "data_type": "int8_t", + "data_abbrev": "i8" + } + ], + "_query": [ + { + "query_type": "half", + "query_abbrev": "h" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "u32" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "_pq": [ + { + "pq_len": "8", + "pq_bits": "8", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_8subd" + } + ], + "_codebook": [ + { + "codebook_type": "half", + "codebook_abbrev": "half" + } + ], + "_mxdim_team": [ + { + "dataset_block_dim": "128", + "team_size": "4" + }, + { + "dataset_block_dim": "256", + "team_size": "8" + }, + { + "dataset_block_dim": "512", + "team_size": "16" + }, + { + "dataset_block_dim": "1024", + "team_size": "32" + } + ], + "_smem": [ + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", + "smem_abbrev": "f16" + }, + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::E5M2", + "smem_abbrev": "e5m2" + } ] } ] diff --git a/cpp/src/neighbors/detail/cagra/packed_type.hpp b/cpp/src/neighbors/detail/cagra/packed_type.hpp new file mode 100644 index 0000000000..f52edc126b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/packed_type.hpp @@ -0,0 +1,49 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once +#include +#include + +#include +#include + +namespace cuvs::neighbors::cagra::detail::device { +template +struct uintN_t {}; +template <> +struct uintN_t<32> { + using type = uint32_t; +}; +template <> +struct uintN_t<64> { + using type = uint64_t; +}; + +template +struct fp8xN {}; + +template +struct fp8xN { + using uint_t = typename uintN_t<8 * NumPacked>::type; + using unit_t = __nv_fp8_e5m2; + using x2_t = __nv_fp8x2_storage_t; + static constexpr uint32_t num_elements = NumPacked; + + union { + unit_t x1[num_elements]; + x2_t x2[num_elements / 2]; + uint_t u; + } data; + + HDI fp8xN() { data.u = 0; } + + HDI uint_t& as_uint() { return data.u; } + HDI uint_t as_uint() const { return data.u; } + HDI half2 as_half2(const uint32_t i) const + { + return __nv_cvt_fp8x2_to_halfraw2(data.x2[i], __NV_E5M2); + } +}; +} // namespace cuvs::neighbors::cagra::detail::device diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh index e09ef82a39..f1c7305833 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh @@ -511,7 +511,11 @@ struct search hashmap.data(), hash_bitlen, stream, - static_cast(this->dataset_size)); + // Bound random seed selection to the graph size, not the dataset size. + // During iterative / CAGRA-Q build the graph is smaller than the dataset, + // so using dataset_size here selects seeds that index past the graph end + // (out-of-bounds access). See https://github.com/rapidsai/cuvs/pull/1780. + static_cast(graph.extent(0))); std::shared_ptr compute_distance_to_child_nodes_launcher = make_cagra_multi_kernel_jit_launcher +bool is_ptr_device_accessible(T* ptr) +{ + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); + return attr.devicePointer != nullptr; +} + +template +bool is_ptr_host_accessible(T* ptr) +{ + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); + return attr.hostPointer != nullptr; +} + /** * Utility to sync memory from a host_matrix_view to a device_matrix_view * @@ -301,4 +317,394 @@ void copy_with_padding( } } +/** + * Utility to create a batched device view from a host view + * + * This utility will create a batched device view from a host view and will handle the prefetch and + * writeback of the data Each batch can be referenced exactlyonce by calling the next_view() + * function + * + * Usage: + * ``` + * batched_device_view_from_host view(res, host_view, batch_size, host_writeback, + * initialize); while (view.next_view().extent(0) > 0) { auto device_view = view.next_view(); + * // use device_view + * } + * ``` + * + * The call to next_view() will + * * synchronize on all previous operations / increments batch_id_ + * * (optionally) write back the data of the previous batch to the host + * * (optionally) prefetch the data of the next batch + * * return the view of the current batch + * + * @tparam T The type of the data + * @tparam IdxT The type of the index + */ +template +class batched_device_view_from_host { + public: + enum class memory_strategy { + device_only, // data is on device only (no copy needed) + copy_device, // data is explicitly moved to/from device buffers + managed_only, // data is on managed memory (system managed) + }; + + /** + * Create a batched device view from a host view and will handle the prefetch and + * writeback of the data. Each batch can be referenced exactly once by calling the next_view() + * method. + * + * @param res The resources to use + * @param host_view The host view to create the batched device view from + * @param batch_size The batch size + * @param host_writeback Whether to write back the data to the host (only for host memory) + * (default: false) + * @param initialize Whether to initialize the data (only for managed memory) (default: true) + */ + batched_device_view_from_host(raft::resources const& res, + raft::host_matrix_view host_view, + uint64_t batch_size, + bool host_writeback = false, + bool initialize = true) + : res_(res), + host_view_(host_view), + batch_size_(batch_size), + offset_(0), + batch_id_(-2), + num_buffers_(2), + host_writeback_(host_writeback), + initialize_(initialize) + { + if (host_view.extent(0) == 0) { + mem_strategy_ = memory_strategy::device_only; + return; + } + + RAFT_EXPECTS(host_writeback_ || initialize_, + "At least one of host_writeback or initialize must be true"); + + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr_, host_view.data_handle())); + switch (attr_.type) { + case cudaMemoryTypeUnregistered: + case cudaMemoryTypeHost: + case cudaMemoryTypeManaged: mem_strategy_ = memory_strategy::copy_device; break; + case cudaMemoryTypeDevice: mem_strategy_ = memory_strategy::device_only; break; + } + + RAFT_LOG_DEBUG("Memory strategy: %d for type %d, size %zu", + static_cast(mem_strategy_), + static_cast(attr_.type), + host_view.extent(0) * host_view.extent(1) * sizeof(T)); + + // buffer allocations + if (mem_strategy_ == memory_strategy::copy_device) { + try { + device_mem_[0].emplace(raft::make_device_mdarray( + res, + raft::resource::get_workspace_resource_ref(res), + raft::make_extents(batch_size, host_view.extent(1)))); + device_ptr[0] = device_mem_[0]->data_handle(); + if (batch_size < static_cast(host_view.extent(0))) { + device_mem_[1].emplace(raft::make_device_mdarray( + res, + raft::resource::get_workspace_resource_ref(res), + raft::make_extents(batch_size, host_view.extent(1)))); + device_ptr[1] = device_mem_[1]->data_handle(); + } + if (host_writeback_ && initialize_ && + batch_size * 2 < static_cast(host_view.extent(0))) { + num_buffers_ = 3; + device_mem_[2].emplace(raft::make_device_mdarray( + res, + raft::resource::get_workspace_resource_ref(res), + raft::make_extents(batch_size, host_view.extent(1)))); + device_ptr[2] = device_mem_[2]->data_handle(); + } + } catch (std::bad_alloc& e) { + if (attr_.devicePointer != nullptr) { + RAFT_LOG_DEBUG("Insufficient memory for device buffers, switching to managed memory"); + mem_strategy_ = memory_strategy::managed_only; + } else { + throw std::bad_alloc(); + } + } catch (raft::logic_error& e) { + if (attr_.devicePointer != nullptr) { + RAFT_LOG_DEBUG( + "Insufficient memory for device buffers (logic error), switching to managed memory"); + mem_strategy_ = memory_strategy::managed_only; + } else { + throw raft::logic_error("Insufficient memory for device buffers (logic error)"); + } + } + } + + // setup stream pool if not already present + size_t required_streams = host_writeback_ && initialize_ ? 2 : 1; + if (!res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) || + raft::resource::get_stream_pool_size(res) < required_streams) { + // always create at least 2 streams to account for subsequent iterator calls + raft::resource::set_cuda_stream_pool(res, std::make_shared(2)); + } + prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); + writeback_stream_ = raft::resource::get_stream_from_stream_pool(res); + + // if data is managed and not for_write_ we can set the attribute on the device ptr + if (mem_strategy_ == memory_strategy::managed_only) { + location_.type = cudaMemLocationTypeDevice; + location_.id = static_cast(raft::resource::get_device_id(res_)); + if (!host_writeback_) { + advise_read_mostly(host_view_.data_handle(), + host_view_.extent(0) * host_view_.extent(1) * sizeof(T)); + // TODO maybe also reset upon destruction + } + } + + // prefetch next batch (0) + prefetch_next_batch(); + } + + ~batched_device_view_from_host() noexcept + { + raft::resource::sync_stream(res_); + + // if data is on host and for_write --> make sure to copy back last active + // if data is managed and evict --> evict last active + + // make sure to sync on prefetch stream & res + switch (mem_strategy_) { + case memory_strategy::managed_only: + if (!host_writeback_) { + uint32_t discard_pos = batch_id_ % num_buffers_; + size_t discard_size_rows = actual_batch_size_[discard_pos]; + if (batch_id_ > 0) { + discard_pos = (batch_id_ - 1) % num_buffers_; + discard_size_rows += batch_size_; + } + discard_managed_region(device_ptr[discard_pos], + discard_size_rows * host_view_.extent(1) * sizeof(T)); + writeback_stream_.synchronize(); + } + break; + case memory_strategy::copy_device: + if (host_writeback_) { + uint32_t writeback_pos_last = batch_id_ % num_buffers_; + if (batch_id_ > 0) { + uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; + uint64_t writeback_offset = (batch_id_ - 1) * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos], writeback_offset, batch_size_); + } + { + uint64_t writeback_offset_last = batch_id_ * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos_last], + writeback_offset_last, + actual_batch_size_[writeback_pos_last]); + } + writeback_stream_.synchronize(); + } + break; + case memory_strategy::device_only: break; + } + } + + /** + * Returns the next view of the batch + * + * This function will ensure the next batch is ready and will trigger the prefetch of the + * subsequent next batch. If writeback is enabled, the last active batch will be written back to + * the host. + * + * @return The next view of the batch + */ + raft::device_matrix_view next_view() + { + bool end_of_data = static_cast((batch_id_ + 1) * batch_size_) >= + static_cast(host_view_.extent(0)); + + // special case for empty host view or last batch surpassed + if (end_of_data) { + return raft::make_device_matrix_view(nullptr, 0, host_view_.extent(1)); + } + + // trigger prefetch of next batch (also increments batch_id_) + prefetch_next_batch(); + + uint32_t current_pos = batch_id_ % num_buffers_; + return raft::make_device_matrix_view( + device_ptr[current_pos], actual_batch_size_[current_pos], host_view_.extent(1)); + } + + private: + /** + * Prefetch the next batch + * + * This function will prefetch the next batch and will handle the writeback of the data. + * + * @return True if the next batch exists, false otherwise + */ + bool prefetch_next_batch() + { + batch_id_++; + + // ensure previous batch at position batch_id_ is ready + if (initialize_) { prefetch_stream_.synchronize(); } + if (host_writeback_) { writeback_stream_.synchronize(); } + + // this step will + // * write back data from batch_id_ - 1 + // * prefetch data for batch_id_ + 1 + + // if data is on host and host_writeback_ is true we will have to copy it back + // if data is on host and initialize_ is true we will have to copy it to the device_ptr + + // if data is managed and !host_writeback_ we can discard the data from device memory + // if data is managed and initialize_ is true we can prefetch it to the device + // if data is managed and !initialize_ we can discard and prefetch the data location + + // if data is on device only this is almost a noop, just prepping the pointers + + RAFT_EXPECTS(static_cast(offset_) <= host_view_.extent(0), "Offset out of bounds"); + + bool next_batch_exists = offset_ < static_cast(host_view_.extent(0)); + + if (next_batch_exists) { + // synchronize to ensure all previous operations are completed + // in particular all work on batch_id_ - 1 + raft::resource::sync_stream(res_); + + int32_t prefetch_pos = (batch_id_ + 1) % num_buffers_; + actual_batch_size_[prefetch_pos] = min(batch_size_, host_view_.extent(0) - offset_); + + switch (mem_strategy_) { + case memory_strategy::managed_only: + if (!host_writeback_ && batch_id_ > 1) { + uint32_t discard_pos = (batch_id_ - 1) % num_buffers_; + size_t discard_size = batch_size_ * host_view_.extent(1) * sizeof(T); + discard_managed_region(device_ptr[discard_pos], discard_size); + } + // prefetch next position + device_ptr[prefetch_pos] = host_view_.data_handle() + offset_ * host_view_.extent(1); + prefetch_managed_region( + device_ptr[prefetch_pos], + actual_batch_size_[prefetch_pos] * host_view_.extent(1) * sizeof(T)); + break; + case memory_strategy::copy_device: + if (host_writeback_ && batch_id_ > 0) { + // copy back last active + uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; + uint64_t writeback_offset = (batch_id_ - 1) * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos], writeback_offset, batch_size_); + } + if (initialize_) { + // prefetch next position + prefetch_from_host_to_device( + device_ptr[prefetch_pos], offset_, actual_batch_size_[prefetch_pos]); + } + + break; + case memory_strategy::device_only: + // just move pointer to next position + device_ptr[prefetch_pos] = host_view_.data_handle() + offset_ * host_view_.extent(1); + break; + } + + offset_ += actual_batch_size_[prefetch_pos]; + } + + return next_batch_exists; + } + + void advise_read_mostly(T* ptr, size_t size) + { +#if CUDA_VERSION >= 13000 + RAFT_CUDA_TRY(cudaMemAdvise(ptr, size, cudaMemAdviseSetReadMostly, location_)); +#else + RAFT_CUDA_TRY(cudaMemAdvise_v2(ptr, size, cudaMemAdviseSetReadMostly, location_)); +#endif + } + + void discard_managed_region(T* dev_ptr, size_t size) + { +#if CUDA_VERSION >= 13000 + void* dptrs[1] = {dev_ptr}; + size_t sizes[1] = {size}; + RAFT_CUDA_TRY(cudaMemDiscardBatchAsync(dptrs, sizes, 1, 0, writeback_stream_)); +#endif + // FIXME: CUDA12 does not support discard + } + + void prefetch_managed_region(T* dev_ptr, size_t size) + { +#if CUDA_VERSION >= 13000 + if (initialize_) { + RAFT_CUDA_TRY(cudaMemPrefetchAsync(dev_ptr, size, location_, 0, prefetch_stream_)); + } else { + void* dptrs[1] = {dev_ptr}; + size_t sizes[1] = {size}; + RAFT_CUDA_TRY( + cudaMemDiscardAndPrefetchBatchAsync(dptrs, sizes, 1, location_, 0, prefetch_stream_)); + } +#else + // FIXME: CUDA12 does not support discard - so we just prefetch + if (initialize_) { + RAFT_CUDA_TRY(cudaMemPrefetchAsync_v2(dev_ptr, size, location_, 0, prefetch_stream_)); + } else { + RAFT_CUDA_TRY(cudaMemPrefetchAsync_v2(dev_ptr, size, location_, 0, prefetch_stream_)); + } +#endif + } + + void prefetch_from_host_to_device(T* dev_ptr, size_t src_row_offset, size_t num_rows) + { + const size_t n_elem = num_rows * host_view_.extent(1); + const size_t n_bytes = n_elem * sizeof(T); + // use memcpy instead of raft::copy to avoid strange behavior with HMM/ATS memory + RAFT_CUDA_TRY(cudaMemcpyAsync(dev_ptr, + host_view_.data_handle() + src_row_offset * host_view_.extent(1), + n_bytes, + cudaMemcpyHostToDevice, + prefetch_stream_)); + } + + void writeback_from_device_to_host(T* dev_ptr, size_t dst_row_offset, size_t num_rows) + { + const size_t n_elem = num_rows * host_view_.extent(1); + const size_t n_bytes = n_elem * sizeof(T); + // use memcpy instead of raft::copy to avoid strange behavior with HMM/ATS memory + RAFT_CUDA_TRY(cudaMemcpyAsync(host_view_.data_handle() + dst_row_offset * host_view_.extent(1), + dev_ptr, + n_bytes, + cudaMemcpyDeviceToHost, + writeback_stream_)); + } + + // stream pool for local streams + std::optional> local_stream_pool_; + rmm::cuda_stream_view prefetch_stream_; + rmm::cuda_stream_view writeback_stream_; + + // configuration + memory_strategy mem_strategy_; + const raft::resources& res_; + bool initialize_; // initialize the data on the device + bool host_writeback_; // write back the data to the host + + // batch position information + uint64_t batch_size_; + int32_t batch_id_; + uint64_t offset_; + + cudaMemLocation location_; + + // input pointer information + raft::host_matrix_view host_view_; + cudaPointerAttributes attr_; + + // internal device buffers + uint64_t num_buffers_; + std::optional> device_mem_[3]; + T* device_ptr[3]; + uint32_t actual_batch_size_[3]; +}; + } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 9b96f94bf0..35506b440f 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -176,6 +176,7 @@ ConfigureTest( neighbors/ann_cagra/bug_graph_smaller_than_dataset.cu neighbors/ann_cagra/bug_iterative_cagra_build.cu neighbors/ann_cagra/bug_issue_93_reproducer.cu + neighbors/ann_cagra/bug_graph_smaller_than_dataset.cu GPUS 1 PERCENT 100 ) @@ -196,7 +197,9 @@ ConfigureTest( ConfigureTest( NAME NEIGHBORS_ANN_CAGRA_HELPERS_TEST - PATH neighbors/ann_cagra/test_optimize_uint32_t.cu neighbors/ann_cagra/test_batch_load_iterator.cu + PATH neighbors/ann_cagra/test_optimize_uint32_t.cu + neighbors/ann_cagra/test_batched_device_view_from_host.cu + neighbors/ann_cagra/test_batch_load_iterator.cu GPUS 1 PERCENT 100 ) diff --git a/cpp/tests/neighbors/ann_cagra.cuh b/cpp/tests/neighbors/ann_cagra.cuh index a6704f892a..fc9a1212c6 100644 --- a/cpp/tests/neighbors/ann_cagra.cuh +++ b/cpp/tests/neighbors/ann_cagra.cuh @@ -6,6 +6,7 @@ #include "../test_utils.cuh" #include "ann_utils.cuh" +#include "vpq_utils.cuh" #include #include "naive_knn.cuh" @@ -275,6 +276,7 @@ struct AnnCagraInputs { std::optional non_owning_memory_buffer_flag = std::nullopt; cuvs::neighbors::MergeStrategy merge_strategy = cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_PHYSICAL; + cuvs::neighbors::cagra::internal_dtype smem_dtype = cuvs::neighbors::cagra::internal_dtype::F16; }; inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p) @@ -298,6 +300,13 @@ inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p) {search_algo::AUTO, "auto"} // }; std::vector build_algo = {"IVF_PQ", "NN_DESCENT", "ITERATIVE_CAGRA_SEARCH", "AUTO"}; + const auto smem_dtype_str = [](cuvs::neighbors::cagra::internal_dtype dtype) { + switch (dtype) { + case cuvs::neighbors::cagra::internal_dtype::F16: return "F16"; + case cuvs::neighbors::cagra::internal_dtype::E5M2: return "E5M2"; + } + return "Unknown"; + }; std::vector merge_strategy = {"PHYSICAL", "LOGICAL"}; os << "{n_queries=" << p.n_queries << ", dataset shape=" << p.n_rows << "x" << p.dim << ", k=" << p.k << ", " << algo_name[p.algo] << ", max_queries=" << p.max_queries @@ -311,7 +320,7 @@ inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p) if (p.compression.has_value()) { auto vpq = p.compression.value(); os << ", pq_bits=" << vpq.pq_bits << ", pq_dim=" << vpq.pq_dim - << ", vq_n_centers=" << vpq.vq_n_centers; + << ", vq_n_centers=" << vpq.vq_n_centers << ", smem_dtype=" << smem_dtype_str(p.smem_dtype); } os << '}' << std::endl; return os; @@ -414,6 +423,7 @@ class AnnCagraTest : public ::testing::TestWithParam { search_params.algo = ps.algo; search_params.max_queries = ps.max_queries; search_params.team_size = ps.team_size; + search_params.smem_dtype = ps.smem_dtype; auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.n_rows, ps.dim); @@ -461,6 +471,48 @@ class AnnCagraTest : public ::testing::TestWithParam { raft::update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_); raft::resource::sync_stream(handle_); + + reference_recall = 1; + if (ps.compression.has_value()) { + auto decoded_dataset = + raft::make_device_matrix(handle_, ps.n_rows, ps.dim); + + using VpqMathT = half; + cuvs::neighbors::decode_vpq_dataset( + decoded_dataset.view(), + dynamic_cast&>(index.data()), + raft::resource::get_cuda_stream(handle_)); + auto indices_out_view = raft::make_device_matrix_view( + indices_dev.data(), ps.n_queries, ps.k); + auto dists_out_view = raft::make_device_matrix_view( + distances_dev.data(), ps.n_queries, ps.k); + + cuvs::neighbors::naive_knn(handle_, + dists_out_view.data_handle(), + indices_out_view.data_handle(), + search_queries.data(), + decoded_dataset.data_handle(), + ps.n_queries, + ps.n_rows, + ps.dim, + ps.k, + ps.metric); + std::vector indices_vpq_dataset(queries_size); + std::vector distances_vpq_dataset(queries_size); + raft::update_host( + distances_vpq_dataset.data(), dists_out_view.data_handle(), queries_size, stream_); + raft::update_host( + indices_vpq_dataset.data(), indices_out_view.data_handle(), queries_size, stream_); + + reference_recall = std::get<1>(calc_recall(indices_naive, + indices_vpq_dataset, + distances_naive, + distances_vpq_dataset, + ps.n_queries, + ps.k, + 0)); + printf("reference_recall = %e\n", reference_recall); + } } // for (int i = 0; i < min(ps.n_queries, 10); i++) { @@ -470,7 +522,7 @@ class AnnCagraTest : public ::testing::TestWithParam { // print_vector("T", distances_naive.data() + i * ps.k, ps.k, std::cout); // print_vector("C", distances_Cagra.data() + i * ps.k, ps.k, std::cout); // } - double min_recall = ps.min_recall; + double min_recall = ps.min_recall * reference_recall; EXPECT_TRUE(eval_neighbours(indices_naive, indices_Cagra, distances_naive, @@ -519,6 +571,7 @@ class AnnCagraTest : public ::testing::TestWithParam { AnnCagraInputs ps; rmm::device_uvector database; rmm::device_uvector search_queries; + double reference_recall; }; template @@ -1494,38 +1547,38 @@ inline std::vector generate_inputs() {cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_PHYSICAL}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - // Corner cases for small datasets - inputs2 = raft::util::itertools::product( - {2}, - {3, 6, 31, 32, 64, 101}, - {1, 10}, - {2}, // k - {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, - {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, - {0}, // query size - {0}, - {256}, - {1}, - {cuvs::distance::DistanceType::L2Expanded}, - {false}, - {true}, - {true}, - {0.995}, - {std::optional{std::nullopt}}, - {std::optional{std::nullopt}}, - {std::optional{std::nullopt}}, - {cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_PHYSICAL, - cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_LOGICAL}); - inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); + // // Corner cases for small datasets + // inputs2 = raft::util::itertools::product( + // {2}, + // {3, 6, 31, 32, 64, 101}, + // {1, 10}, + // {2}, // k + // {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, + // {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, + // {0}, // query size + // {0}, + // {256}, + // {1}, + // {cuvs::distance::DistanceType::L2Expanded}, + // {false}, + // {true}, + // {true}, + // {0.995}, + // {std::optional{std::nullopt}}, + // {std::optional{std::nullopt}}, + // {std::optional{std::nullopt}}, + // {cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_PHYSICAL, + // cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_LOGICAL}); + // inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); // Varying dim and build algo. inputs2 = raft::util::itertools::product( {100}, - {1000}, - {1, 3, 5, 7, 8, 17, 64, 128, 137, 192, 256, 512, 1024}, // dim - {16}, // k - {graph_build_algo::IVF_PQ, - graph_build_algo::NN_DESCENT, + {1000000}, + {768}, // dim + {16}, // k + { // graph_build_algo::IVF_PQ, + // graph_build_algo::NN_DESCENT, graph_build_algo::ITERATIVE_CAGRA_SEARCH}, {search_algo::AUTO}, {10}, @@ -1539,7 +1592,7 @@ inline std::vector generate_inputs() {false}, {true}, {false}, - {0.995}, + {0.01}, {std::optional{std::nullopt}}, {std::optional{std::nullopt}}, {std::optional{std::nullopt}}, @@ -1604,29 +1657,29 @@ inline std::vector generate_inputs() {cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_PHYSICAL}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - // Varying n_rows, host_dataset - inputs2 = raft::util::itertools::product( - {100}, - {10000}, - {32}, - {10}, - {graph_build_algo::AUTO}, - {search_algo::AUTO}, - {10}, - {0}, // team_size - {64}, - {1}, - {cuvs::distance::DistanceType::L2Expanded, cuvs::distance::DistanceType::InnerProduct}, - {false, true}, - {false}, - {true}, - {0.985}, - {std::optional{std::nullopt}}, - {std::optional{std::nullopt}}, - {std::optional{std::nullopt}}, - {cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_PHYSICAL, - cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_LOGICAL}); - inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); + // // Varying n_rows, host_dataset + // inputs2 = raft::util::itertools::product( + // {100}, + // {10000}, + // {32}, + // {10}, + // {graph_build_algo::AUTO}, + // {search_algo::AUTO}, + // {10}, + // {0}, // team_size + // {64}, + // {1}, + // {cuvs::distance::DistanceType::L2Expanded, cuvs::distance::DistanceType::InnerProduct}, + // {false, true}, + // {false}, + // {true}, + // {0.985}, + // {std::optional{std::nullopt}}, + // {std::optional{std::nullopt}}, + // {std::optional{std::nullopt}}, + // {cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_PHYSICAL, + // cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_LOGICAL}); + // inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); // A few PQ configurations. // Varying dim, vq_n_centers @@ -1652,14 +1705,18 @@ inline std::vector generate_inputs() {cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_PHYSICAL, cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_LOGICAL}); // don't demand high recall // without refinement - for (uint32_t pq_len : {2}) { // for now, only pq_len = 2 is supported, more options coming soon + for (uint32_t pq_len : {2, 4, 8}) { for (uint32_t vq_n_centers : {100, 1000}) { - for (auto input : inputs2) { - vpq_params ps{}; - ps.pq_dim = input.dim / pq_len; - ps.vq_n_centers = vq_n_centers; - input.compression.emplace(ps); - inputs.push_back(input); + for (auto internal_smem_dtype : {cuvs::neighbors::cagra::internal_dtype::E5M2, + cuvs::neighbors::cagra::internal_dtype::F16}) { + for (auto input : inputs2) { + vpq_params ps{}; + ps.pq_dim = input.dim / pq_len; + ps.vq_n_centers = vq_n_centers; + input.compression.emplace(ps); + input.smem_dtype = internal_smem_dtype; + inputs.push_back(input); + } } } } diff --git a/cpp/tests/neighbors/ann_cagra/bug_graph_smaller_than_dataset.cu b/cpp/tests/neighbors/ann_cagra/bug_graph_smaller_than_dataset.cu index adeb774a8b..b06c1cba92 100644 --- a/cpp/tests/neighbors/ann_cagra/bug_graph_smaller_than_dataset.cu +++ b/cpp/tests/neighbors/ann_cagra/bug_graph_smaller_than_dataset.cu @@ -38,8 +38,8 @@ class cagra_graph_smaller_than_dataset_test : public ::testing::Test { protected: void run() { - // Create a dataset with 1000 points - constexpr int64_t n_dataset = 1000; + // Create a dataset with 10000 points + constexpr int64_t n_dataset = 10000; constexpr int64_t n_dim = 128; constexpr int64_t n_queries = 100; constexpr int64_t k = 10; @@ -63,9 +63,9 @@ class cagra_graph_smaller_than_dataset_test : public ::testing::Test { // Recreate the bug scenario: LARGE dataset, SMALL graph // (like iterative_build_graph does in intermediate iterations) - constexpr int64_t n_graph = n_dataset / 2; // Only 500 nodes in graph + constexpr int64_t n_graph = n_dataset / 2; // Only 5000 nodes in graph - // Step 1: Build index on SMALL subset (500 points) + // Step 1: Build index on SMALL subset (5000 points) auto small_dataset_view = raft::make_device_matrix_view( dataset.data_handle(), n_graph, n_dim); @@ -74,13 +74,13 @@ class cagra_graph_smaller_than_dataset_test : public ::testing::Test { auto small_index = cagra::build(res, small_index_params, small_dataset_view); raft::resource::sync_stream(res); - // Step 2: Update to FULL dataset (1000 points) but keep small graph (500 nodes) - // This creates the exact bug scenario: dataset.size=1000, graph.extent(0)=500 + // Step 2: Update to FULL dataset (10000 points) but keep small graph (5000 nodes) + // This creates the exact bug scenario: dataset.size=10000, graph.extent(0)=5000 small_index.update_dataset(res, raft::make_const_mdspan(dataset.view())); // Verify the mismatch - THIS IS THE BUG SCENARIO! - ASSERT_EQ(small_index.graph().extent(0), n_graph); // Graph has 500 nodes - ASSERT_EQ(small_index.size(), n_dataset); // Dataset has 1000 points + ASSERT_EQ(small_index.graph().extent(0), n_graph); // Graph has 5000 nodes + ASSERT_EQ(small_index.size(), n_dataset); // Dataset has 10000 points ASSERT_NE(small_index.graph().extent(0), small_index.size()); // Mismatch! // Create queries @@ -100,8 +100,8 @@ class cagra_graph_smaller_than_dataset_test : public ::testing::Test { search_params.algo = cagra::search_algo::SINGLE_CTA; // THIS SHOULD NOT CRASH OR CAUSE OOB ACCESS - // Before fix: random seeds use dataset.size (1000) -> tries to access graph[700] -> CRASH! - // After fix: random seeds use graph.extent(0) (500) -> only accesses graph[0-499] -> SAFE! + // Before fix: random seeds use dataset.size (10000) -> tries to access graph[7000] -> CRASH! + // After fix: random seeds use graph.extent(0) (5000) -> only accesses graph[0-4999] -> SAFE! cagra::search(res, search_params, small_index, diff --git a/cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu b/cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu new file mode 100644 index 0000000000..1e1cc13093 --- /dev/null +++ b/cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu @@ -0,0 +1,205 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../../src/neighbors/detail/cagra/utils.hpp" + +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra { + +using IdxT = uint32_t; + +struct BatchConfig { + bool initialize; + bool host_writeback; +}; + +struct DimsConfig { + int64_t n_rows; + int64_t n_cols; + uint64_t batch_size; +}; + +class BatchedDeviceViewFromHostTest : public ::testing::Test { + protected: + void SetUp() override { raft::resource::sync_stream(res); } + + /** + * Run batched_device_view_from_host over host data, copy device views back, + * and verify against the input. + */ + template + void run_and_verify_batched(InputMatrixView input_view, + uint64_t batch_size, + bool host_writeback, + bool initialize) + { + int64_t n_rows = input_view.extent(0); + int64_t n_cols = input_view.extent(1); + + std::vector readback(n_rows * n_cols); + + int64_t total_processed = 0; + + { + cagra::detail::batched_device_view_from_host batched( + res, + raft::make_host_matrix_view(input_view.data_handle(), n_rows, n_cols), + batch_size, + host_writeback, + initialize); + while (true) { + auto dev_view = batched.next_view(); + if (dev_view.extent(0) == 0) break; + + if (initialize) { + raft::copy(readback.data() + total_processed * n_cols, + dev_view.data_handle(), + dev_view.extent(0) * dev_view.extent(1), + raft::resource::get_cuda_stream(res)); + } + if (host_writeback) { raft::matrix::fill(res, dev_view, IdxT(17)); } + total_processed += dev_view.extent(0); + } + } + raft::resource::sync_stream(res); + + EXPECT_EQ(total_processed, n_rows); + if (initialize) { + for (int64_t i = 0; i < n_rows * n_cols; ++i) { + EXPECT_EQ(readback[i], IdxT(13)) << "Mismatch (initialize) at index " << i; + } + } + if (host_writeback) { + auto readback_view = + raft::make_host_matrix_view(readback.data(), n_rows, n_cols); + raft::copy(res, readback_view, input_view); + raft::resource::sync_stream(res); + for (int64_t i = 0; i < n_rows * n_cols; ++i) { + EXPECT_EQ(readback[i], IdxT(17)) << "Mismatch (host_writeback) at index " << i; + } + } + } + + raft::resources res; +}; + +TEST_F(BatchedDeviceViewFromHostTest, EmptyView) +{ + auto host_empty = raft::make_host_matrix(0, 8); + auto host_view = host_empty.view(); + cagra::detail::batched_device_view_from_host batched( + res, host_view, /*batch_size=*/128, /*host_writeback=*/false, /*initialize=*/true); + + auto view = batched.next_view(); + EXPECT_EQ(view.extent(0), 0); + EXPECT_EQ(view.extent(1), 8); + EXPECT_EQ(view.data_handle(), nullptr); +} + +using BatchDimsParam = std::tuple; + +class BatchedDeviceViewFromHostParameterizedTest + : public BatchedDeviceViewFromHostTest, + public ::testing::WithParamInterface {}; + +TEST_P(BatchedDeviceViewFromHostParameterizedTest, VectorHostData) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + std::vector host_data(n_rows * n_cols); + auto host_view = raft::make_host_matrix_view(host_data.data(), n_rows, n_cols); + + std::fill(host_view.data_handle(), host_view.data_handle() + n_rows * n_cols, IdxT(13)); + + run_and_verify_batched(host_view, batch_size, host_writeback, initialize); +} + +TEST_P(BatchedDeviceViewFromHostParameterizedTest, PinnedMemory) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto host_matrix = raft::make_pinned_matrix(res, n_rows, n_cols); + auto host_view = host_matrix.view(); + + std::fill(host_view.data_handle(), host_view.data_handle() + n_rows * n_cols, IdxT(13)); + + run_and_verify_batched(host_view, batch_size, host_writeback, initialize); +} + +TEST_P(BatchedDeviceViewFromHostParameterizedTest, ManagedMemory) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto host_matrix = raft::make_managed_matrix(res, n_rows, n_cols); + auto host_view = host_matrix.view(); + + std::fill(host_view.data_handle(), host_view.data_handle() + n_rows * n_cols, IdxT(13)); + + run_and_verify_batched(host_view, batch_size, host_writeback, initialize); +} + +TEST_P(BatchedDeviceViewFromHostParameterizedTest, DeviceMemory) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto host_matrix = raft::make_device_matrix(res, n_rows, n_cols); + auto host_view = host_matrix.view(); + + raft::matrix::fill(res, host_view, IdxT(13)); + + run_and_verify_batched(host_view, batch_size, host_writeback, initialize); +} + +static const std::array kBatchConfigs = {{ + {/*initialize=*/true, /*host_writeback=*/false}, + {/*initialize=*/false, /*host_writeback=*/true}, + {/*initialize=*/true, /*host_writeback=*/true}, +}}; + +static const std::array kDimsConfigs = {{ + {/*n_rows=*/64, /*n_cols=*/32, /*batch_size=*/256}, // rows less than batch size, single batch + {/*n_rows=*/64, /*n_cols=*/32, /*batch_size=*/64}, // single batch + {/*n_rows=*/256, /*n_cols=*/32, /*batch_size=*/32}, // multiple batches + {/*n_rows=*/500, + /*n_cols=*/32, + /*batch_size=*/128}, // multiple batches, partial batch in the end +}}; + +INSTANTIATE_TEST_SUITE_P(BatchConfigs, + BatchedDeviceViewFromHostParameterizedTest, + ::testing::Combine(::testing::ValuesIn(kBatchConfigs), + ::testing::ValuesIn(kDimsConfigs))); + +} // namespace cuvs::neighbors::cagra diff --git a/cpp/tests/neighbors/ann_utils.cuh b/cpp/tests/neighbors/ann_utils.cuh index cbc95d7bb7..91449ba56f 100644 --- a/cpp/tests/neighbors/ann_utils.cuh +++ b/cpp/tests/neighbors/ann_utils.cuh @@ -127,10 +127,10 @@ struct idx_dist_pair { /** Calculate recall value using only neighbor indices */ template -auto calc_recall(const std::vector& expected_idx, - const std::vector& actual_idx, - size_t rows, - size_t cols) +std::tuple calc_recall(const std::vector& expected_idx, + const std::vector& actual_idx, + size_t rows, + size_t cols) { size_t match_count = 0; size_t total_count = static_cast(rows) * static_cast(cols); @@ -219,16 +219,17 @@ auto eval_recall(const std::vector& expected_idx, /** Overload of calc_recall to account for distances */ template -auto calc_recall(const std::vector& expected_idx, - const std::vector& actual_idx, - const std::vector& expected_dist, - const std::vector& actual_dist, - size_t rows, - size_t cols, - double eps) +std::tuple calc_recall(const std::vector& expected_idx, + const std::vector& actual_idx, + const std::vector& expected_dist, + const std::vector& actual_dist, + size_t rows, + size_t cols, + double eps) { - size_t match_count = 0; - size_t total_count = static_cast(rows) * static_cast(cols); + size_t match_count = 0; + size_t index_match_count = 0; + size_t total_count = static_cast(rows) * static_cast(cols); for (size_t i = 0; i < rows; ++i) { for (size_t k = 0; k < cols; ++k) { size_t idx_k = i * cols + k; // row major assumption! @@ -247,8 +248,28 @@ auto calc_recall(const std::vector& expected_idx, } } } - return std::make_tuple( - static_cast(match_count) / static_cast(total_count), match_count, total_count); + + // Index based recall + for (size_t i = 0; i < rows; ++i) { + for (size_t k = 0; k < cols; ++k) { + size_t idx_k = i * cols + k; // row major assumption! + auto act_idx = actual_idx[idx_k]; + for (size_t j = 0; j < cols; ++j) { + size_t idx = i * cols + j; // row major assumption! + auto exp_idx = expected_idx[idx]; + + if (act_idx == exp_idx) { + index_match_count++; + break; + } + } + } + } + + return std::make_tuple(static_cast(match_count) / static_cast(total_count), + static_cast(index_match_count) / static_cast(total_count), + match_count, + total_count); } /** same as eval_recall, but in case indices do not match, @@ -265,7 +286,7 @@ auto eval_neighbours(const std::vector& expected_idx, bool test_unique = true, size_t max_duplicates = 0) -> testing::AssertionResult { - auto [actual_recall, match_count, total_count] = + auto [actual_recall, index_based_actual_recall, match_count, total_count] = calc_recall(expected_idx, actual_idx, expected_dist, actual_dist, rows, cols, eps); double error_margin = (actual_recall - min_recall) / std::max(1.0 - min_recall, eps); diff --git a/cpp/tests/neighbors/vpq_utils.cuh b/cpp/tests/neighbors/vpq_utils.cuh new file mode 100644 index 0000000000..44b4d188ee --- /dev/null +++ b/cpp/tests/neighbors/vpq_utils.cuh @@ -0,0 +1,76 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include +#include +#include + +#include +#include + +namespace cuvs::neighbors { +template +__global__ void decode_vpq_dataset_kernel(data_t* const decoded_dataset_ptr, + const uint32_t ldd, + const math_t* const vq_codebook_ptr, + const uint32_t ldv, + const math_t* const pq_codebook_ptr, + const uint32_t pq_subspace_dim, + const uint32_t pq_table_size, + const uint32_t dataset_dim, + const size_t dataset_size, + const uint8_t* const data_ptr, + const uint32_t ldi) +{ + constexpr uint32_t warp_size = 32; + const size_t batch_id = (blockIdx.x * blockDim.x + threadIdx.x) / warp_size; + if (batch_id >= dataset_size) { return; } + + const auto local_data_ptr = data_ptr + ldi * batch_id; + const auto vq_code = *reinterpret_cast(local_data_ptr); + const auto pq_code_ptr = local_data_ptr + sizeof(uint32_t); + const auto vq_vec_ptr = vq_codebook_ptr + vq_code * ldv; + auto local_dst_ptr = decoded_dataset_ptr + batch_id * ldd; + + const auto lane_id = threadIdx.x % warp_size; + for (uint32_t i = lane_id; i < dataset_dim; i += warp_size) { + const auto pq_code = pq_code_ptr[i / pq_subspace_dim]; + const auto pq_v = pq_codebook_ptr[pq_code * pq_subspace_dim + (i % pq_subspace_dim)]; + + local_dst_ptr[i] = static_cast(vq_vec_ptr[i]) + static_cast(pq_v); + } +} + +template +void decode_vpq_dataset(raft::device_matrix_view decoded_dataset, + const cuvs::neighbors::vpq_dataset& vpq_dataset, + cudaStream_t cuda_stream) +{ + const auto dataset_size = decoded_dataset.extent(0); + RAFT_EXPECTS(vpq_dataset.data.extent(0) == dataset_size, "Dataset sizes mismatch"); + RAFT_EXPECTS(vpq_dataset.pq_bits() == 8, + "decode_vpq_dataset currently only supports pq_bits == 8 (got %u)", + vpq_dataset.pq_bits()); + + constexpr uint32_t block_size = 256; + constexpr uint32_t warp_size = 32; + constexpr int64_t vecs_per_cta = block_size / warp_size; + const auto grid_size = raft::div_rounding_up_safe(decoded_dataset.extent(0), vecs_per_cta); + + decode_vpq_dataset_kernel + <<>>(decoded_dataset.data_handle(), + decoded_dataset.stride(0), + vpq_dataset.vq_code_book.data_handle(), + vpq_dataset.vq_code_book.stride(0), + vpq_dataset.pq_code_book.data_handle(), + vpq_dataset.pq_len(), + 1u << vpq_dataset.pq_bits(), + vpq_dataset.dim(), + dataset_size, + vpq_dataset.data.data_handle(), + vpq_dataset.data.stride(0)); +} +} // namespace cuvs::neighbors diff --git a/python/cuvs_bench/cuvs_bench/run/__main__.py b/python/cuvs_bench/cuvs_bench/run/__main__.py index 6950ff7202..58d1b604bd 100644 --- a/python/cuvs_bench/cuvs_bench/run/__main__.py +++ b/python/cuvs_bench/cuvs_bench/run/__main__.py @@ -5,6 +5,7 @@ import json import os +import warnings from pathlib import Path from typing import Optional @@ -257,6 +258,13 @@ def main( and any backend-specific connection parameters (host, port, etc.). """ + warnings.warn( + "The 'cuvs_bench.run' CLI is deprecated and will be removed in a future release. " + "Use BenchmarkOrchestrator from cuvs_bench.orchestrator instead.", + FutureWarning, + stacklevel=2, + ) + if not data_export: # Determine backend type and extra kwargs from --backend-config backend_type = "cpp_gbench"