
[REVIEW] Add L1 support to NN-Descent#1898

Open
yan-zaretskiy wants to merge 5 commits into rapidsai:main from yan-zaretskiy:nn-descent-l1

Conversation

@yan-zaretskiy
Contributor

Use the SIMT local-join path for L1 distance and keep the existing WMMA path for dot-product metrics. This preserves existing behavior while extending NN-Descent to support L1.


(cherry picked from commit 000d69f67dea4ae694c9cbeb4193f1ff5900d034)
@yan-zaretskiy yan-zaretskiy requested a review from a team as a code owner March 9, 2026 18:49
@yan-zaretskiy yan-zaretskiy requested a review from jinsolp March 9, 2026 18:50
@yan-zaretskiy yan-zaretskiy added the non-breaking (Introduces a non-breaking change) and feature request (New feature or request) labels Mar 9, 2026
Contributor

@jinsolp jinsolp left a comment


Thanks @yan-zaretskiy ! Leaving some questions and suggestions below:

Contributor

@jinsolp jinsolp left a comment


LGTM!

Member

@divyegala divyegala left a comment


In general, I'm finding the code complicated to follow now. Please consider a way to clarify when the wmma kernel is used and when the non_wmma kernel is used (meaning for what types and distances).

// For architectures 750 and 860 (890), the values for MAX_RESIDENT_THREAD_PER_SM
// are 1024 and 1536 respectively, which means the bounds don't work anymore
template <typename Index_t, typename ID_t = InternalID_t<Index_t>, typename DistEpilogue_t>
template <typename DistanceOp_t,
Member


Can you pass the DistanceOp_t as a runtime argument? We are trying to be cognizant of our binary sizes.

Contributor Author


I understand what you mean. But a runtime argument would imply some form of virtual dispatch, while this is kernel code.

Member


You could just pass the distance enum, and if-else on that like the other kernels do.

Contributor Author


Fair, not sure why I was overthinking it. Will update.

case cuvs::distance::DistanceType::CosineExpanded:
case cuvs::distance::DistanceType::L2Expanded:
case cuvs::distance::DistanceType::L2SqrtExpanded:
local_join_kernel<dot_dist_op>
Member


This is getting hard for me to follow. Can you rename the one that does not use wmma instructions to local_join_kernel_no_wmma?

@yan-zaretskiy
Contributor Author

In general, I'm finding the code complicated to follow now. Please consider a way to clarify when the wmma kernel is used and when the non_wmma kernel is used (meaning for what types and distances).

I simplified the local_join. Now there are 3 calls:

  • F32 -> SIMT branch
  • else if L1 distance (with any dtype) -> SIMT branch
  • else WMMA

@cjnolet
Member

cjnolet commented Mar 13, 2026

Yan, can you please update the doxygen for the public CAGRA and nn-descent functions? Need to say that nn-descent supports L1 and CAGRA supports L1 when using the nn-descent graph_build_algo.
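The requested doxygen update might read something like the following (hypothetical wording, not the exact text added to the headers; parameter docs elided):

```cpp
/**
 * ...
 * The following distance metrics are supported:
 * - L2
 * - InnerProduct
 * - CosineExpanded
 * - L1 (for CAGRA, only when graph_build_algo is NN_DESCENT)
 * ...
 */
```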

Contributor

@jinsolp jinsolp left a comment


small doc nitpick

* The following distance metrics are supported:
* - L2
* - CosineExpanded (dataset norms are computed as float regardless of input data type)
* - L1
Contributor


Suggested change
* - L1
* - L1 (currently only supported with NN-Descent and Iterative Search as the build algorithm)

Member

@divyegala divyegala left a comment


Pre-emptively approving with request for some comments for added context.

__device__ __forceinline__ float operator()(float a, float b) const
{
if (metric == cuvs::distance::DistanceType::L1) { return raft::abs(a - b); }
return a * b;
Member


Can you add a comment for this fallback mentioning that it is a reusable calculation for IP, cosine, and L2?

l2_norms_.data_handle(),
build_config_.metric,
dist_epilogue);
local_join_kernel_simt<<<nrow_, BLOCK_SIZE, 0, stream>>>(graph_.h_graph_new.data_handle(),
Member


Also a comment here and in the else path detailing what conditions lead to what kernel to be launched - I know you responded in a GH comment, it would just be nice to have that context in the code :)


Labels

feature request: New feature or request
non-breaking: Introduces a non-breaking change


4 participants