From f9073c46f5f4edb29291b20c498fd5628815f1da Mon Sep 17 00:00:00 2001 From: ML-dev-crypto Date: Fri, 6 Feb 2026 12:13:22 +0530 Subject: [PATCH] =?UTF-8?q?=EF=BB=BFUse=20bulk=20scoring=20with=20scratch?= =?UTF-8?q?=20buffer=20for=20HNSW=20diversity=20checks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reuse a scratch buffer during HNSW diversity checks and use RandomVectorScorer.bulkScore with early termination to avoid per-call allocations on the hot path. --- lucene/CHANGES.txt | 3 +- .../lucene/util/hnsw/HnswGraphBuilder.java | 4 +- .../lucene/util/hnsw/NeighborArray.java | 60 +++++++++++++------ 3 files changed, 48 insertions(+), 19 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 08c51862bf07..74136b601b81 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -19,7 +19,8 @@ Improvements Optimizations --------------------- -(No changes) +* Improved performance of HNSW graph construction by using bulk vector + scoring in NeighborArray diversity checks. (PR#15667) Bug Fixes --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index bd38843dc687..e113f4e544f4 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -71,6 +71,7 @@ public class HnswGraphBuilder implements HnswBuilder { private final int[] bulkScoreNodes; // for bulk scoring private final float[] bulkScores; // for bulk scoring + private final float[] diversityScratch; private final SplittableRandom random; protected final UpdateableRandomVectorScorer scorer; protected final HnswGraphSearcher graphSearcher; @@ -185,6 +186,7 @@ protected HnswGraphBuilder( // but enough to take advantage of bulk scoring this.bulkScoreNodes = new int[MAX_BULK_SCORE_NODES]; this.bulkScores = new float[MAX_BULK_SCORE_NODES]; + this.diversityScratch = new float[M * 2]; entryCandidates = new GraphBuilderKnnCollector(1); beamCandidates = new GraphBuilderKnnCollector(beamWidth); beamCandidates0 = new GraphBuilderKnnCollector(Math.min(beamWidth / 2, M * 3)); @@ -439,7 +441,7 @@ private void updateNeighbor( if (nbrsOfNbr.nodes()[j] == node) return; } } - nbrsOfNbr.addAndEnsureDiversity(node, score, nbr, scorer); + nbrsOfNbr.addAndEnsureDiversity(node, score, nbr, scorer, diversityScratch); } /** diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java index c43cd681440a..c205d79bc780 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java @@ -119,13 +119,18 @@ private void alertOnHeapMemoryUsageChange(int newLength, int previousLength) { public void addAndEnsureDiversity( int newNode, float newScore, int nodeId, UpdateableRandomVectorScorer scorer) throws IOException { + addAndEnsureDiversity(newNode, newScore, nodeId, scorer, null); + } + + public void addAndEnsureDiversity( + int newNode, float newScore, int nodeId, UpdateableRandomVectorScorer scorer, float[] scratch) + throws IOException { addOutOfOrder(newNode, newScore); if (size < maxSize) { return; } - // we're oversize, need to do diversity check and pop out the least diverse neighbour scorer.setScoringOrdinal(nodeId); - removeIndex(findWorstNonDiverse(scorer)); + removeIndex(findWorstNonDiverse(scorer, scratch)); assert size == maxSize - 1; } @@ -284,17 +289,17 @@ private int descSortFindRightMostInsertionPoint(float newScore, int bound) { * Find first non-diverse neighbour among the list of neighbors starting from the most distant * neighbours */ - private int findWorstNonDiverse(UpdateableRandomVectorScorer scorer) throws IOException { + private int findWorstNonDiverse(UpdateableRandomVectorScorer scorer, float[] scratch) + throws IOException { int[] uncheckedIndexes = sort(scorer); assert uncheckedIndexes != null : "We will always have something unchecked"; int uncheckedCursor = uncheckedIndexes.length - 1; for (int i = size - 1; i > 0; i--) { if (uncheckedCursor < 0) { - // no unchecked node left break; } scorer.setScoringOrdinal(nodes.get(i)); - if (isWorstNonDiverse(i, uncheckedIndexes, uncheckedCursor, scorer)) { + if (isWorstNonDiverse(i, uncheckedIndexes, uncheckedCursor, scorer, scratch)) { return i; } if (i == uncheckedIndexes[uncheckedCursor]) { @@ -305,26 +310,47 @@ private int findWorstNonDiverse(UpdateableRandomVectorScorer scorer) throws IOEx } private boolean isWorstNonDiverse( - int candidateIndex, int[] uncheckedIndexes, int uncheckedCursor, RandomVectorScorer scorer) + int candidateIndex, + int[] uncheckedIndexes, + int uncheckedCursor, + RandomVectorScorer scorer, + float[] scratch) throws IOException { float minAcceptedSimilarity = scores.get(candidateIndex); if (candidateIndex == uncheckedIndexes[uncheckedCursor]) { - // the candidate itself is unchecked - for (int i = candidateIndex - 1; i >= 0; i--) { - float neighborSimilarity = scorer.score(nodes.get(i)); - // candidate node is too similar to node i given its score relative to the base node - if (neighborSimilarity >= minAcceptedSimilarity) { + int numNodesToCheck = candidateIndex; + if (numNodesToCheck == 0) { + return false; + } + + float[] neighborScores = scratch != null ? scratch : new float[numNodesToCheck]; + float maxScore = scorer.bulkScore(nodes.buffer, neighborScores, numNodesToCheck); + if (maxScore < minAcceptedSimilarity) { + return false; + } + + for (int i = 0; i < numNodesToCheck; i++) { + if (neighborScores[i] >= minAcceptedSimilarity) { return true; } } } else { - // else we just need to make sure candidate does not violate diversity with the (newly - // inserted) unchecked nodes assert candidateIndex > uncheckedIndexes[uncheckedCursor]; - for (int i = uncheckedCursor; i >= 0; i--) { - float neighborSimilarity = scorer.score(nodes.get(uncheckedIndexes[i])); - // candidate node is too similar to node i given its score relative to the base node - if (neighborSimilarity >= minAcceptedSimilarity) { + int numNodesToCheck = uncheckedCursor + 1; + + float[] neighborScores = scratch != null ? scratch : new float[numNodesToCheck]; + int[] nodesToCheck = new int[numNodesToCheck]; + for (int i = 0; i <= uncheckedCursor; i++) { + nodesToCheck[i] = nodes.get(uncheckedIndexes[i]); + } + + float maxScore = scorer.bulkScore(nodesToCheck, neighborScores, numNodesToCheck); + if (maxScore < minAcceptedSimilarity) { + return false; + } + + for (int i = 0; i < numNodesToCheck; i++) { + if (neighborScores[i] >= minAcceptedSimilarity) { return true; } }