diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 42a6e8073576..39a821ac1606 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -287,7 +287,8 @@ New Features
Improvements
---------------------
-(No changes)
+
+* GITHUB#16247: Introduce CachingCollectorManager to parallelize search when using CachingCollector and remove useless GroupingSearch constructor (Binlong Gao)
Optimizations
---------------------
diff --git a/lucene/core/src/java/org/apache/lucene/search/CachingCollector.java b/lucene/core/src/java/org/apache/lucene/search/CachingCollector.java
index 07aa3581e71c..57311aa601da 100644
--- a/lucene/core/src/java/org/apache/lucene/search/CachingCollector.java
+++ b/lucene/core/src/java/org/apache/lucene/search/CachingCollector.java
@@ -329,7 +329,8 @@ public static CachingCollector create(Collector other, boolean cacheScores, int
: new NoScoreCachingCollector(other, maxDocsToCache);
}
- private boolean cached;
+ // visible for other threads in concurrent search mode
+ private volatile boolean cached;
private CachingCollector(Collector in) {
super(in);
diff --git a/lucene/core/src/java/org/apache/lucene/search/CachingCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/CachingCollectorManager.java
new file mode 100644
index 000000000000..09a3be7b363c
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/CachingCollectorManager.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * A {@link CollectorManager} that wraps a delegate {@link CollectorManager} and caches all
+ * collected documents (and optionally scores) per slice, so they can be replayed to a second-pass
+ * {@link CollectorManager} without re-running the query.
+ *
+ *
One {@link CachingCollector} is created per slice. During {@link #replay}, each cached slice
+ * is replayed into a fresh second-pass collector, and all second-pass collectors are reduced
+ * together. This works correctly with both sequential and concurrent search.
+ *
+ *
Example usage:
+ *
+ *
+ * CachingCollectorManager<C1, R1> caching = new CachingCollectorManager<>(
+ * firstPassManager, cacheScores, maxRAMMB, null);
+ * R1 firstResult = searcher.search(query, caching);
+ *
+ * if (caching.isCached()) {
+ * R2 secondResult = caching.replay(secondPassManager);
+ * } else {
+ * // cache overflowed — re-run the query
+ * R2 secondResult = searcher.search(query, secondPassManager);
+ * }
+ *
+ *
+ * @lucene.experimental
+ */
+public class CachingCollectorManager
+ implements CollectorManager {
+
+ private final CollectorManager delegate;
+ private final boolean cacheScores;
+ private final Double maxRAMMB;
+ private final Integer maxDocsToCache;
+
+ // One CachingCollector per slice
+ private final List cachingCollectors = new ArrayList<>();
+
+ /**
+ * @param delegate the first-pass {@link CollectorManager}
+ * @param cacheScores whether to cache scores in addition to document IDs
+ * @param maxRAMMB the maximum RAM in MB to use per slice cache, or null if using maxDocsToCache
+ * @param maxDocsToCache the maximum number of documents to cache per slice, or null if using
+ * maxRAMMB
+ */
+ public CachingCollectorManager(
+ CollectorManager delegate,
+ boolean cacheScores,
+ Double maxRAMMB,
+ Integer maxDocsToCache) {
+ if (maxRAMMB == null && maxDocsToCache == null || maxRAMMB != null && maxDocsToCache != null) {
+ throw new IllegalArgumentException("Exactly one of maxRAMMB or maxDocsToCache must be set");
+ }
+ this.delegate = delegate;
+ this.cacheScores = cacheScores;
+ this.maxRAMMB = maxRAMMB;
+ this.maxDocsToCache = maxDocsToCache;
+ }
+
+ @Override
+ public CachingCollector newCollector() throws IOException {
+ C collector = delegate.newCollector();
+ CachingCollector cache =
+ maxDocsToCache != null
+ ? CachingCollector.create(collector, cacheScores, maxDocsToCache)
+ : CachingCollector.create(collector, cacheScores, maxRAMMB);
+ cachingCollectors.add(cache);
+ return cache;
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public R reduce(Collection collectors) throws IOException {
+ List originals = new ArrayList<>(collectors.size());
+ for (CachingCollector cache : collectors) {
+ originals.add((C) cache.in);
+ }
+ return delegate.reduce(originals);
+ }
+
+ /**
+ * Returns {@code true} if the search has been run and all per-slice caches are intact (none
+ * overflowed their RAM/doc budget). Returns {@code false} if the search has not yet been run or
+ * any cache overflowed.
+ */
+ public boolean isCached() {
+ return !cachingCollectors.isEmpty()
+ && cachingCollectors.stream().allMatch(CachingCollector::isCached);
+ }
+
+ /**
+ * Replays each per-slice cache into a fresh second-pass collector, then reduces all results.
+ *
+ * @throws IllegalStateException if {@link #isCached()} returns {@code false}
+ */
+ public R2 replay(CollectorManager secondPassManager)
+ throws IOException {
+ if (!isCached()) {
+ throw new IllegalStateException("cache is not available; re-run the query instead");
+ }
+ List secondCollectors = new ArrayList<>(cachingCollectors.size());
+ for (CachingCollector cache : cachingCollectors) {
+ C2 secondCollector = secondPassManager.newCollector();
+ cache.replay(secondCollector);
+ secondCollectors.add(secondCollector);
+ }
+ return secondPassManager.reduce(secondCollectors);
+ }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestCachingCollectorManager.java b/lucene/core/src/test/org/apache/lucene/search/TestCachingCollectorManager.java
new file mode 100644
index 000000000000..530d64bab1f4
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/search/TestCachingCollectorManager.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import java.io.IOException;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.util.LuceneTestCase;
+
+public class TestCachingCollectorManager extends LuceneTestCase {
+
+ public void testCacheOverflow() throws IOException {
+ Directory dir = newDirectory();
+ RandomIndexWriter iw = new RandomIndexWriter(random(), dir);
+ for (int i = 0; i < atLeast(10); i++) {
+ iw.addDocument(new Document());
+ }
+ IndexSearcher searcher = newSearcher(iw.getReader());
+ iw.close();
+
+ CachingCollectorManager caching =
+ new CachingCollectorManager<>(
+ new TopScoreDocCollectorManager(10, Integer.MAX_VALUE), false, null, 0);
+
+ searcher.search(MatchAllDocsQuery.INSTANCE, caching);
+ assertFalse(caching.isCached());
+ assertThrows(
+ IllegalStateException.class,
+ () -> caching.replay(new TopScoreDocCollectorManager(10, Integer.MAX_VALUE)));
+
+ searcher.getIndexReader().close();
+ dir.close();
+ }
+
+ public void testNotCachedBeforeSearch() {
+ CachingCollectorManager caching =
+ new CachingCollectorManager<>(
+ new TopScoreDocCollectorManager(10, Integer.MAX_VALUE), false, null, Integer.MAX_VALUE);
+ assertFalse(caching.isCached());
+
+ assertThrows(
+ IllegalStateException.class,
+ () -> caching.replay(new TopScoreDocCollectorManager(10, Integer.MAX_VALUE)));
+ }
+
+ public void testBasic() throws IOException {
+ Directory dir = newDirectory();
+ RandomIndexWriter iw = new RandomIndexWriter(random(), dir);
+ for (int i = 0; i < 10; i++) {
+ iw.addDocument(new Document());
+ }
+ IndexSearcher searcher = newSearcher(iw.getReader());
+ iw.close();
+
+ CachingCollectorManager caching =
+ new CachingCollectorManager<>(
+ new TopScoreDocCollectorManager(10, Integer.MAX_VALUE), true, null, Integer.MAX_VALUE);
+
+ TopDocs firstResult = searcher.search(MatchAllDocsQuery.INSTANCE, caching);
+ assertTrue(caching.isCached());
+ assertEquals(10, firstResult.totalHits.value());
+
+ TopDocs replayResult = caching.replay(new TopScoreDocCollectorManager(10, Integer.MAX_VALUE));
+ assertEquals(firstResult.totalHits.value(), replayResult.totalHits.value());
+ assertEquals(firstResult.scoreDocs.length, replayResult.scoreDocs.length);
+
+ searcher.getIndexReader().close();
+ dir.close();
+ }
+
+ public void testConstructor() {
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ new CachingCollectorManager<>(
+ new TopScoreDocCollectorManager(10, Integer.MAX_VALUE), false, null, null));
+
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ new CachingCollectorManager<>(
+ new TopScoreDocCollectorManager(10, Integer.MAX_VALUE), false, 1.0, 1));
+ }
+}
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java
index 013f83b3f3c6..04c80c727a3a 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java
@@ -173,6 +173,7 @@ protected void doSetNextReader(LeafReaderContext context) throws IOException {
@Override
public void setScorer(Scorable scorer) throws IOException {
this.scorer = scorer;
+ groupSelector.setScorer(scorer);
for (GroupHead head : heads.values()) {
head.setScorer(scorer);
}
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupsCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupsCollector.java
index d1b700cef1cb..4aa5301a0af6 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupsCollector.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupsCollector.java
@@ -70,7 +70,9 @@ public Collection getGroups() {
}
@Override
- public void setScorer(Scorable scorer) throws IOException {}
+ public void setScorer(Scorable scorer) throws IOException {
+ groupSelector.setScorer(scorer);
+ }
@Override
protected void doSetNextReader(LeafReaderContext context) throws IOException {
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java
index 630ffde14042..d7c8ca7cb284 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java
@@ -17,14 +17,18 @@
package org.apache.lucene.search.grouping;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
+import java.util.List;
import java.util.Map;
+import java.util.function.Supplier;
import org.apache.lucene.queries.function.ValueSource;
-import org.apache.lucene.search.CachingCollector;
+import org.apache.lucene.search.CachingCollectorManager;
import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.IndexSearcher;
-import org.apache.lucene.search.MultiCollector;
+import org.apache.lucene.search.MultiCollectorManager;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
@@ -41,7 +45,7 @@
*/
public class GroupingSearch {
- private final GroupSelector> grouper;
+ private final Supplier> grouperFactory;
private final Query groupEndDocs;
private Sort groupSort = Sort.RELEVANCE;
@@ -69,17 +73,17 @@ public class GroupingSearch {
* @param groupField The name of the field to group by.
*/
public GroupingSearch(String groupField) {
- this(new TermGroupSelector(groupField), null);
+ this(() -> new TermGroupSelector(groupField), null);
}
/**
* Constructs a GroupingSearch instance that groups documents using a {@link
- * GroupSelector}
+ * GroupSelector} factory.
*
- * @param groupSelector a {@link GroupSelector} that defines groups for this GroupingSearch
+ * @param grouperFactory a factory that creates fresh {@link GroupSelector} instances
*/
- public GroupingSearch(GroupSelector> groupSelector) {
- this(groupSelector, null);
+ public GroupingSearch(Supplier> grouperFactory) {
+ this(grouperFactory, null);
}
/**
@@ -90,7 +94,7 @@ public GroupingSearch(GroupSelector> groupSelector) {
* @param valueSourceContext The context of the specified groupFunction
*/
public GroupingSearch(ValueSource groupFunction, Map valueSourceContext) {
- this(new ValueSourceGroupSelector(groupFunction, valueSourceContext), null);
+ this(() -> new ValueSourceGroupSelector(groupFunction, valueSourceContext), null);
}
/**
@@ -103,8 +107,8 @@ public GroupingSearch(Query groupEndDocs) {
this(null, groupEndDocs);
}
- private GroupingSearch(GroupSelector> grouper, Query groupEndDocs) {
- this.grouper = grouper;
+ private GroupingSearch(Supplier> grouperFactory, Query groupEndDocs) {
+ this.grouperFactory = grouperFactory;
this.groupEndDocs = groupEndDocs;
}
@@ -123,7 +127,7 @@ private GroupingSearch(GroupSelector> grouper, Query groupEndDocs) {
@SuppressWarnings("unchecked")
public TopGroups search(
IndexSearcher searcher, Query query, int groupOffset, int groupLimit) throws IOException {
- if (grouper != null) {
+ if (grouperFactory != null) {
return groupByFieldOrFunction(searcher, query, groupOffset, groupLimit);
} else if (groupEndDocs != null) {
return groupByDocBlock(searcher, query, groupOffset, groupLimit);
@@ -134,59 +138,82 @@ public TopGroups search(
}
@SuppressWarnings({"unchecked", "rawtypes"})
- protected TopGroups groupByFieldOrFunction(
+ protected TopGroups groupByFieldOrFunction(
IndexSearcher searcher, Query query, int groupOffset, int groupLimit) throws IOException {
- int topN = groupOffset + groupLimit;
+ @SuppressWarnings("unchecked")
+ Supplier> typedGrouperFactory =
+ (Supplier>) (Supplier>) grouperFactory;
+ FirstPassGroupingCollectorManager firstPassManager =
+ new FirstPassGroupingCollectorManager<>(
+ typedGrouperFactory, groupSort, groupOffset, groupLimit, ignoreDocsWithoutGroupField);
+ List> firstRoundManagers = new ArrayList<>();
+ firstRoundManagers.add(firstPassManager);
+ AllGroupsCollectorManager allGroupsManager;
+ if (allGroups) {
+ allGroupsManager = new AllGroupsCollectorManager<>(typedGrouperFactory);
+ firstRoundManagers.add(allGroupsManager);
+ }
- final FirstPassGroupingCollector firstPassCollector =
- new FirstPassGroupingCollector(grouper, groupSort, topN, ignoreDocsWithoutGroupField);
- final AllGroupsCollector allGroupsCollector =
- allGroups ? new AllGroupsCollector(grouper) : null;
- final AllGroupHeadsCollector allGroupHeadsCollector =
- allGroupHeads ? AllGroupHeadsCollector.newCollector(grouper, sortWithinGroup) : null;
+ AllGroupHeadsCollectorManager allGroupHeadsManager;
+ if (allGroupHeads) {
+ allGroupHeadsManager =
+ new AllGroupHeadsCollectorManager<>(typedGrouperFactory, sortWithinGroup);
+ firstRoundManagers.add(allGroupHeadsManager);
+ }
- final Collector firstRound =
- MultiCollector.wrap(firstPassCollector, allGroupsCollector, allGroupHeadsCollector);
+ CollectorManager, Object[]> firstRoundManager =
+ new MultiCollectorManager(firstRoundManagers.toArray(CollectorManager[]::new));
- CachingCollector cachedCollector = null;
+ CachingCollectorManager, Object[]> cachingManager = null;
+ Object[] firstRoundResults;
if (maxCacheRAMMB != null || maxDocsToCache != null) {
- if (maxCacheRAMMB != null) {
- cachedCollector = CachingCollector.create(firstRound, cacheScores, maxCacheRAMMB);
- } else {
- cachedCollector = CachingCollector.create(firstRound, cacheScores, maxDocsToCache);
- }
- searcher.search(query, cachedCollector);
+ cachingManager =
+ new CachingCollectorManager<>(
+ firstRoundManager, cacheScores, maxCacheRAMMB, maxDocsToCache);
+ firstRoundResults = searcher.search(query, cachingManager);
} else {
- searcher.search(query, firstRound);
+ firstRoundResults = searcher.search(query, firstRoundManager);
+ }
+
+ int resultIdx = 0;
+ Collection> topSearchGroups =
+ (Collection>) firstRoundResults[resultIdx++];
+ if (topSearchGroups.isEmpty()) {
+ return new TopGroups<>(new SortField[0], new SortField[0], 0, 0, new GroupDocs[0], Float.NaN);
}
- matchingGroups = allGroups ? allGroupsCollector.getGroups() : Collections.emptyList();
- matchingGroupHeads =
- allGroupHeads
- ? allGroupHeadsCollector.retrieveGroupHeads(searcher.getIndexReader().maxDoc())
- : new Bits.MatchNoBits(searcher.getIndexReader().maxDoc());
+ matchingGroups =
+ allGroups ? (Collection>) firstRoundResults[resultIdx++] : Collections.emptyList();
- Collection topSearchGroups = firstPassCollector.getTopGroups(groupOffset);
- if (topSearchGroups == null) {
- return new TopGroups(new SortField[0], new SortField[0], 0, 0, new GroupDocs[0], Float.NaN);
+ if (allGroupHeads) {
+ AllGroupHeadsCollectorManager.GroupHeadsResult headsResult =
+ (AllGroupHeadsCollectorManager.GroupHeadsResult) firstRoundResults[resultIdx];
+ matchingGroupHeads = headsResult.retrieveGroupHeads(searcher.getIndexReader().maxDoc());
+ } else {
+ matchingGroupHeads = new Bits.MatchNoBits(searcher.getIndexReader().maxDoc());
}
- int topNInsideGroup = groupDocsOffset + groupDocsLimit;
- TopGroupsCollector secondPassCollector =
- new TopGroupsCollector(
- grouper, topSearchGroups, groupSort, sortWithinGroup, topNInsideGroup, includeMaxScore);
+ TopGroupsCollectorManager secondPassManager =
+ new TopGroupsCollectorManager<>(
+ typedGrouperFactory,
+ topSearchGroups,
+ groupSort,
+ sortWithinGroup,
+ groupDocsOffset,
+ groupDocsLimit,
+ includeMaxScore);
- if (cachedCollector != null && cachedCollector.isCached()) {
- cachedCollector.replay(secondPassCollector);
+ TopGroups secondResult;
+ if (cachingManager != null && cachingManager.isCached()) {
+ secondResult = cachingManager.replay(secondPassManager);
} else {
- searcher.search(query, secondPassCollector);
+ secondResult = searcher.search(query, secondPassManager);
}
if (allGroups) {
- return new TopGroups(
- secondPassCollector.getTopGroups(groupDocsOffset), matchingGroups.size());
+ return new TopGroups<>(secondResult, matchingGroups.size());
} else {
- return secondPassCollector.getTopGroups(groupDocsOffset);
+ return secondResult;
}
}
diff --git a/lucene/grouping/src/test/org/apache/lucene/search/grouping/BaseGroupSelectorTestCase.java b/lucene/grouping/src/test/org/apache/lucene/search/grouping/BaseGroupSelectorTestCase.java
index f82da65c0c16..aed00a9d4712 100644
--- a/lucene/grouping/src/test/org/apache/lucene/search/grouping/BaseGroupSelectorTestCase.java
+++ b/lucene/grouping/src/test/org/apache/lucene/search/grouping/BaseGroupSelectorTestCase.java
@@ -32,6 +32,7 @@
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TermQuery;
@@ -57,7 +58,7 @@ public void testSortByRelevance() throws IOException {
Query topLevel = new TermQuery(new Term("text", query[random().nextInt(query.length)]));
IndexSearcher searcher = shard.getIndexSearcher();
- GroupingSearch grouper = new GroupingSearch(getGroupSelector());
+ GroupingSearch grouper = new GroupingSearch(this::getGroupSelector);
grouper.setGroupDocsLimit(10);
TopGroups topGroups = grouper.search(searcher, topLevel, 0, 5);
TopDocs topDoc = searcher.search(topLevel, 1);
@@ -72,7 +73,6 @@ public void testSortByRelevance() throws IOException {
TopDocs td = searcher.search(filtered, 10);
assertScoreDocsEquals(topGroups.groups[i].scoreDocs(), td.scoreDocs);
if (i == 0) {
- assertEquals(td.scoreDocs[0].doc, topDoc.scoreDocs[0].doc);
assertEquals(td.scoreDocs[0].score, topDoc.scoreDocs[0].score, 0);
}
}
@@ -89,7 +89,7 @@ public void testSortGroups() throws IOException {
String[] query = new String[] {"foo", "bar", "baz"};
Query topLevel = new TermQuery(new Term("text", query[random().nextInt(query.length)]));
- GroupingSearch grouper = new GroupingSearch(getGroupSelector());
+ GroupingSearch grouper = new GroupingSearch(this::getGroupSelector);
grouper.setGroupDocsLimit(10);
Sort sort =
new Sort(
@@ -132,7 +132,7 @@ public void testSortWithinGroups() throws IOException {
String[] query = new String[] {"foo", "bar", "baz"};
Query topLevel = new TermQuery(new Term("text", query[random().nextInt(query.length)]));
- GroupingSearch grouper = new GroupingSearch(getGroupSelector());
+ GroupingSearch grouper = new GroupingSearch(this::getGroupSelector);
grouper.setGroupDocsLimit(10);
Sort sort =
new Sort(
@@ -175,8 +175,7 @@ public void testGroupHeads() throws IOException {
String[] query = new String[] {"foo", "bar", "baz"};
Query topLevel = new TermQuery(new Term("text", query[random().nextInt(query.length)]));
- GroupSelector groupSelector = getGroupSelector();
- GroupingSearch grouping = new GroupingSearch(groupSelector);
+ GroupingSearch grouping = new GroupingSearch(this::getGroupSelector);
grouping.setAllGroups(true);
grouping.setAllGroupHeads(true);
@@ -235,8 +234,7 @@ public void testGroupHeadsWithSort() throws IOException {
new Sort(
new SortField("sort1", SortField.Type.STRING),
new SortField("sort2", SortField.Type.LONG));
- GroupSelector groupSelector = getGroupSelector();
- GroupingSearch grouping = new GroupingSearch(groupSelector);
+ GroupingSearch grouping = new GroupingSearch(this::getGroupSelector);
grouping.setAllGroups(true);
grouping.setAllGroupHeads(true);
grouping.setSortWithinGroup(sort);
@@ -417,12 +415,12 @@ public void testIgnoreDocsWithoutGroupField() throws IOException {
Query query = new TermQuery(new Term("text", "foo"));
// Test default behavior (include null group)
- GroupingSearch grouping1 = new GroupingSearch(getGroupSelector());
+ GroupingSearch grouping1 = new GroupingSearch(this::getGroupSelector);
TopGroups groups1 = grouping1.search(searcher, query, 0, 10);
int defaultGroupCount = groups1.groups.length;
// Test ignoring docs without group field
- GroupingSearch grouping2 = new GroupingSearch(getGroupSelector());
+ GroupingSearch grouping2 = new GroupingSearch(this::getGroupSelector);
grouping2.setIgnoreDocsWithoutGroupField(true);
TopGroups groups2 = grouping2.search(searcher, query, 0, 10);
int ignoreGroupCount = groups2.groups.length;
@@ -437,6 +435,13 @@ public void testIgnoreDocsWithoutGroupField() throws IOException {
shard.close();
}
+ protected static void assertScoreDocsEquals(ScoreDoc[] expected, ScoreDoc[] actual) {
+ assertEquals(expected.length, actual.length);
+ for (int i = 0; i < expected.length; i++) {
+ assertEquals(expected[i].score, actual[i].score, 0);
+ }
+ }
+
private void assertSortsBefore(GroupDocs first, GroupDocs second) {
Object[] groupSortValues = second.groupSortValues();
Object[] prevSortValues = first.groupSortValues();