diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 08c51862bf07..74136b601b81 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -19,7 +19,8 @@ Improvements Optimizations --------------------- -(No changes) +* Improved performance of HNSW graph construction by using bulk vector + scoring in NeighborArray diversity checks. (PR#15667) Bug Fixes --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index bd38843dc687..e113f4e544f4 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -71,6 +71,7 @@ public class HnswGraphBuilder implements HnswBuilder { private final int[] bulkScoreNodes; // for bulk scoring private final float[] bulkScores; // for bulk scoring + private final float[] diversityScratch; private final SplittableRandom random; protected final UpdateableRandomVectorScorer scorer; protected final HnswGraphSearcher graphSearcher; @@ -185,6 +186,7 @@ protected HnswGraphBuilder( // but enough to take advantage of bulk scoring this.bulkScoreNodes = new int[MAX_BULK_SCORE_NODES]; this.bulkScores = new float[MAX_BULK_SCORE_NODES]; + this.diversityScratch = new float[M * 2]; entryCandidates = new GraphBuilderKnnCollector(1); beamCandidates = new GraphBuilderKnnCollector(beamWidth); beamCandidates0 = new GraphBuilderKnnCollector(Math.min(beamWidth / 2, M * 3)); @@ -439,7 +441,7 @@ private void updateNeighbor( if (nbrsOfNbr.nodes()[j] == node) return; } } - nbrsOfNbr.addAndEnsureDiversity(node, score, nbr, scorer); + nbrsOfNbr.addAndEnsureDiversity(node, score, nbr, scorer, diversityScratch); } /** diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java index c43cd681440a..c205d79bc780 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java @@ -119,13 +119,18 @@ private void alertOnHeapMemoryUsageChange(int newLength, int previousLength) { public void addAndEnsureDiversity( int newNode, float newScore, int nodeId, UpdateableRandomVectorScorer scorer) throws IOException { + addAndEnsureDiversity(newNode, newScore, nodeId, scorer, null); + } + + public void addAndEnsureDiversity( + int newNode, float newScore, int nodeId, UpdateableRandomVectorScorer scorer, float[] scratch) + throws IOException { addOutOfOrder(newNode, newScore); if (size < maxSize) { return; } - // we're oversize, need to do diversity check and pop out the least diverse neighbour scorer.setScoringOrdinal(nodeId); - removeIndex(findWorstNonDiverse(scorer)); + removeIndex(findWorstNonDiverse(scorer, scratch)); assert size == maxSize - 1; } @@ -284,17 +289,17 @@ private int descSortFindRightMostInsertionPoint(float newScore, int bound) { * Find first non-diverse neighbour among the list of neighbors starting from the most distant * neighbours */ - private int findWorstNonDiverse(UpdateableRandomVectorScorer scorer) throws IOException { + private int findWorstNonDiverse(UpdateableRandomVectorScorer scorer, float[] scratch) + throws IOException { int[] uncheckedIndexes = sort(scorer); assert uncheckedIndexes != null : "We will always have something unchecked"; int uncheckedCursor = uncheckedIndexes.length - 1; for (int i = size - 1; i > 0; i--) { if (uncheckedCursor < 0) { - // no unchecked node left break; } scorer.setScoringOrdinal(nodes.get(i)); - if (isWorstNonDiverse(i, uncheckedIndexes, uncheckedCursor, scorer)) { + if (isWorstNonDiverse(i, uncheckedIndexes, uncheckedCursor, scorer, scratch)) { return i; } if (i == uncheckedIndexes[uncheckedCursor]) { @@ -305,26 +310,47 @@ private int findWorstNonDiverse(UpdateableRandomVectorScorer scorer) throws IOEx } private boolean isWorstNonDiverse( - int candidateIndex, int[] uncheckedIndexes, int uncheckedCursor, RandomVectorScorer scorer) + int candidateIndex, + int[] uncheckedIndexes, + int uncheckedCursor, + RandomVectorScorer scorer, + float[] scratch) throws IOException { float minAcceptedSimilarity = scores.get(candidateIndex); if (candidateIndex == uncheckedIndexes[uncheckedCursor]) { - // the candidate itself is unchecked - for (int i = candidateIndex - 1; i >= 0; i--) { - float neighborSimilarity = scorer.score(nodes.get(i)); - // candidate node is too similar to node i given its score relative to the base node - if (neighborSimilarity >= minAcceptedSimilarity) { + int numNodesToCheck = candidateIndex; + if (numNodesToCheck == 0) { + return false; + } + + float[] neighborScores = scratch != null ? scratch : new float[numNodesToCheck]; + float maxScore = scorer.bulkScore(nodes.buffer, neighborScores, numNodesToCheck); + if (maxScore < minAcceptedSimilarity) { + return false; + } + + for (int i = 0; i < numNodesToCheck; i++) { + if (neighborScores[i] >= minAcceptedSimilarity) { return true; } } } else { - // else we just need to make sure candidate does not violate diversity with the (newly - // inserted) unchecked nodes assert candidateIndex > uncheckedIndexes[uncheckedCursor]; - for (int i = uncheckedCursor; i >= 0; i--) { - float neighborSimilarity = scorer.score(nodes.get(uncheckedIndexes[i])); - // candidate node is too similar to node i given its score relative to the base node - if (neighborSimilarity >= minAcceptedSimilarity) { + int numNodesToCheck = uncheckedCursor + 1; + + float[] neighborScores = scratch != null ? scratch : new float[numNodesToCheck]; + int[] nodesToCheck = new int[numNodesToCheck]; + for (int i = 0; i <= uncheckedCursor; i++) { + nodesToCheck[i] = nodes.get(uncheckedIndexes[i]); + } + + float maxScore = scorer.bulkScore(nodesToCheck, neighborScores, numNodesToCheck); + if (maxScore < minAcceptedSimilarity) { + return false; + } + + for (int i = 0; i < numNodesToCheck; i++) { + if (neighborScores[i] >= minAcceptedSimilarity) { return true; } }