diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 4912aa5e01f9..f54010b2a37d 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -113,6 +113,8 @@ New Features * GITHUB#15818: Add BM25 k3 query-term frequency saturation to BM25Similarity. (Sagar Upadhyaya) +* GITHUB#16076: Implement 3-ary priority queue (Vijay) + Improvements --------------------- * GITHUB#15704: Replace LinkedList with more efficient data structure. (Renato Haeberli) diff --git a/lucene/core/src/java/org/apache/lucene/search/FieldValueHitQueue.java b/lucene/core/src/java/org/apache/lucene/search/FieldValueHitQueue.java index 104b438a7218..a7394b3bc710 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FieldValueHitQueue.java +++ b/lucene/core/src/java/org/apache/lucene/search/FieldValueHitQueue.java @@ -29,6 +29,8 @@ */ public class FieldValueHitQueue extends PriorityQueue { + private static final int HEAP_ARITY = 3; + /** Extension of ScoreDoc to also store the {@link FieldComparator} slot. */ public static class Entry extends ScoreDoc { public int slot; @@ -112,7 +114,7 @@ public boolean lessThan(Entry hitA, Entry hitB) { // prevent instantiation and extension. private FieldValueHitQueue(SortField[] fields, int size, EntryLessThan lessThan) { - super(size, lessThan); + super(size, HEAP_ARITY, lessThan); // When we get here, fields.length is guaranteed to be > 0, therefore no // need to check it again. this.fields = fields; diff --git a/lucene/core/src/java/org/apache/lucene/util/PriorityQueue.java b/lucene/core/src/java/org/apache/lucene/util/PriorityQueue.java index c93e12dec746..c22d579a7567 100644 --- a/lucene/core/src/java/org/apache/lucene/util/PriorityQueue.java +++ b/lucene/core/src/java/org/apache/lucene/util/PriorityQueue.java @@ -73,11 +73,20 @@ public static PriorityQueue usingComparator( protected int size = 0; private final int maxSize; private final T[] heap; + private final int arity; private final LessThan lessThan; /** Create an empty priority queue of the configured size using the specified {@link LessThan}. */ public PriorityQueue(int maxSize, LessThan lessThan) { - this(maxSize, lessThan, () -> null); + this(maxSize, 2, lessThan, () -> null); + } + + /** + * Create an empty priority queue of the configured size and heap arity using the specified {@link + * LessThan}. + */ + protected PriorityQueue(int maxSize, int arity, LessThan lessThan) { + this(maxSize, arity, lessThan, () -> null); } /** @@ -112,8 +121,20 @@ public PriorityQueue(int maxSize, LessThan lessThan) { */ public PriorityQueue( int maxSize, LessThan lessThan, Supplier sentinelObjectSupplier) { + this(maxSize, 2, lessThan, sentinelObjectSupplier); + } + + /** + * Create a priority queue with a configurable heap arity that is pre-filled with sentinel objects + */ + protected PriorityQueue( + int maxSize, int arity, LessThan lessThan, Supplier sentinelObjectSupplier) { final int heapSize; + if (arity < 2) { + throw new IllegalArgumentException("arity mist be >= 2; got: " + arity); + } + if (0 == maxSize) { // We allocate 1 extra to avoid if statement in top() heapSize = 2; @@ -134,6 +155,7 @@ public PriorityQueue( final T[] h = (T[]) new Object[heapSize]; this.heap = h; this.maxSize = maxSize; + this.arity = arity; this.lessThan = lessThan; // If sentinel objects are supported, populate the queue with them @@ -168,7 +190,7 @@ public void addAll(Collection elements) { } } finally { // The loop goes down to 1 as heap is 1-based not 0-based. - for (int i = (size >>> 1); i >= 1; i--) { + for (int i = (size <= 1 ? 0 : parent(size)); i >= 1; i--) { downHeap(i); } } @@ -200,7 +222,7 @@ public void addAll(Stream elements) { }); } finally { // The loop goes down to 1 as heap is 1-based not 0-based. - for (int i = (size >>> 1); i >= 1; i--) { + for (int i = (size <= 1 ? 0 : parent(size)); i >= 1; i--) { downHeap(i); } } @@ -332,11 +354,11 @@ public T[] drainToArrayHighestFirst(IntFunction newArray) { protected boolean upHeap(int origPos) { int i = origPos; T node = heap[i]; // save bottom node - int j = i >>> 1; + int j = i > 1 ? parent(i) : 0; while (j > 0 && lessThan.lessThan(node, heap[j])) { heap[i] = heap[j]; // shift parents down i = j; - j = j >>> 1; + j = i > 1 ? parent(i) : 0; } heap[i] = node; // install saved node return i != origPos; @@ -344,23 +366,34 @@ protected boolean upHeap(int origPos) { protected void downHeap(int i) { T node = heap[i]; // save top node - int j = i << 1; // find smaller child - int k = j + 1; - if (k <= size && lessThan.lessThan(heap[k], heap[j])) { - j = k; - } - while (j <= size && lessThan.lessThan(heap[j], node)) { - heap[i] = heap[j]; // shift up child - i = j; - j = i << 1; - k = j + 1; - if (k <= size && lessThan.lessThan(heap[k], heap[j])) { - j = k; + int j = firstChild(i); + while (j <= size) { + final int lastChild = Math.min(j + arity - 1, size); + int bestChild = j; + for (int c = j + 1; c <= lastChild; ++c) { + if (lessThan.lessThan(heap[c], heap[bestChild])) { + bestChild = c; + } + } + if (lessThan.lessThan(heap[bestChild], node)) { + heap[i] = heap[bestChild]; // shift up child + i = bestChild; + j = firstChild(i); + } else { + break; } } heap[i] = node; // install saved node } + private int parent(int i) { + return ((i - 2) / arity) + 1; + } + + private int firstChild(int i) { + return arity * (i - 1) + 2; + } + /** * This method returns the internal heap array as Object[]. * diff --git a/lucene/core/src/test/org/apache/lucene/util/TestPriorityQueue.java b/lucene/core/src/test/org/apache/lucene/util/TestPriorityQueue.java index c543debb29ee..272f4e7ec323 100644 --- a/lucene/core/src/test/org/apache/lucene/util/TestPriorityQueue.java +++ b/lucene/core/src/test/org/apache/lucene/util/TestPriorityQueue.java @@ -26,6 +26,7 @@ import java.util.HashMap; import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.NoSuchElementException; import java.util.Random; import org.apache.lucene.tests.util.LuceneTestCase; @@ -50,6 +51,12 @@ protected final void checkValidity() { } } + private static class TernaryIntegerQueue extends PriorityQueue { + public TernaryIntegerQueue(int count) { + super(count, 3, (a, b) -> a < b); + } + } + public void testZeroSizedQueue() { PriorityQueue pq = new IntegerQueue(0); assertEquals((Object) 1, pq.insertWithOverflow(1)); @@ -98,6 +105,11 @@ public void testComparatorPQ() throws Exception { testPQ(PriorityQueue.usingComparator(size, Integer::compareTo), size, random()); } + public void testTernaryPriorityQueue() { + int size = atLeast(10000); + testPQ(new TernaryIntegerQueue(size), size, random()); + } + public static void testPQ(PriorityQueue pq, int count, Random gen) { int sum = 0, sum2 = 0; @@ -196,20 +208,16 @@ public void testAddAllWithStreamNotFitIntoQueue() { } private boolean assertHeap(PriorityQueue pq) { + return assertHeap(pq, 2); + } + + private boolean assertHeap(PriorityQueue pq, int arity) { Object[] heapArray = pq.getHeapArray(); - // The loop goes down to 1 as heap is 1-based not 0-based. - for (int i = (heapArray.length >>> 1); i >= 1; i--) { - int left = i << 1; - int right = left + 1; - if (right < heapArray.length) { - if ((Integer) heapArray[i] > (Integer) heapArray[right]) { - return false; - } - if ((Integer) heapArray[i] > (Integer) heapArray[left]) { - return false; - } - } else if (left < heapArray.length) { - if ((Integer) heapArray[i] > (Integer) heapArray[left]) { + for (int i = 1; i <= pq.size(); i++) { + int firstChild = arity * (i - 1) + 2; + int lastChild = Math.min(firstChild + arity - 1, pq.size()); + for (int child = firstChild; child <= lastChild; child++) { + if ((Integer) heapArray[i] > (Integer) heapArray[child]) { return false; } } @@ -307,6 +315,37 @@ public void testRandomAdditionsAgainstJavaPq() { pq.checkValidity(); } + public void testRandomAdditionsAgainstJavaPriorityTernaryHeap() { + int maxElement = RandomNumbers.randomIntBetween(random(), 1, 500); + int size = maxElement / 2 + 1; + + var reference = new java.util.PriorityQueue(); + var pq = new TernaryIntegerQueue(size); + + Random localRandom = nonAssertingRandom(random()); + + Map ints = new HashMap<>(); + for (int i = 0, iters = size * 2; i < iters; i++) { + Integer element = ints.computeIfAbsent(localRandom.nextInt(maxElement), k -> k); + + var dropped = pq.insertWithOverflow(element); + + reference.add(element); + Integer droppedReference; + if (reference.size() > size) { + droppedReference = reference.remove(); + } else { + droppedReference = null; + } + + assertEquals("insertWithOverflow() difference.", dropped, droppedReference); + assertEquals("insertWithOverflow() size difference?", reference.size(), pq.size()); + assertEquals("top() difference?", reference.peek(), pq.top()); + } + + assertTrue(assertHeap(pq, 3)); + } + public void testIteratorEmpty() { IntegerQueue queue = new IntegerQueue(3);