diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 172a8ece3e43..16e8b26bb25f 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -115,6 +115,8 @@ Improvements
* GITHUB#15225: Improve package documentation for org.apache.lucene.util. (Syed Mohammad Saad)
+* GITHUB#15936: Introduce BlockGroupingCollectorManager to parallelize search when using BlockGroupingCollector. (Binlong Gao)
+
Optimizations
---------------------
* GITHUB#15681, GITHUB#15833: Replace pre-sized array or empty array with lambda expression to call Collection#toArray. (Zhou Hui)
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java
index 1f6a473f0e65..b31b3b5e3034 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java
@@ -60,11 +60,6 @@
*
* @lucene.experimental
*/
-
-// TODO: TopGroups.merge() won't work with TopGroups returned by this collector, because
-// each block will be on a different shard. Add a specialized merge() static method
-// to this collector?
-
public class BlockGroupingCollector extends SimpleCollector {
private int[] pendingSubDocs;
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollectorManager.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollectorManager.java
new file mode 100644
index 000000000000..73dc69e00f36
--- /dev/null
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollectorManager.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search.grouping;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import org.apache.lucene.search.CollectorManager;
+import org.apache.lucene.search.Sort;
+import org.apache.lucene.search.Weight;
+
+/**
+ * A {@link CollectorManager} for {@link BlockGroupingCollector} that merges results from multiple
+ * collectors into a single {@link TopGroups}. This is intended for use with concurrent search,
+ * where each segment is searched by a separate {@link BlockGroupingCollector}.
+ *
+ * <p>Documents must be indexed as blocks using {@link
+ * org.apache.lucene.index.IndexWriter#addDocuments IndexWriter.addDocuments()} or {@link
+ * org.apache.lucene.index.IndexWriter#updateDocuments IndexWriter.updateDocuments()}.
+ *
+ * <p>See {@link BlockGroupingCollector} for more details.
+ *
+ * @lucene.experimental
+ */
+public class BlockGroupingCollectorManager
+    implements CollectorManager<BlockGroupingCollector, TopGroups<?>> {
+
+  // Arguments forwarded to every BlockGroupingCollector created by newCollector().
+  private final Sort groupSort;
+  private final int topNGroups;
+  private final boolean needsScores;
+  private final Weight lastDocPerGroup;
+
+  // Arguments forwarded to BlockGroupingCollector#getTopGroups when reducing.
+  private final Sort withinGroupSort;
+  private final int groupOffset;
+  private final int withinGroupOffset;
+  private final int maxDocsPerGroup;
+
+  /**
+   * Creates a manager whose collectors and final reduction share the given settings.
+   *
+   * @param groupSort the {@link Sort} used to order the groups; passed to each collector and to
+   *     the merge step
+   * @param topNGroups number of top groups each collector keeps
+   * @param needsScores passed to each {@link BlockGroupingCollector}
+   * @param lastDocPerGroup passed to each {@link BlockGroupingCollector}
+   * @param withinGroupSort passed to {@link BlockGroupingCollector#getTopGroups} to order the
+   *     documents inside each group
+   * @param groupOffset which group to start from in the final result
+   * @param withinGroupOffset passed to {@link BlockGroupingCollector#getTopGroups}
+   * @param maxDocsPerGroup passed to {@link BlockGroupingCollector#getTopGroups}
+   */
+  public BlockGroupingCollectorManager(
+      Sort groupSort,
+      int topNGroups,
+      boolean needsScores,
+      Weight lastDocPerGroup,
+      Sort withinGroupSort,
+      int groupOffset,
+      int withinGroupOffset,
+      int maxDocsPerGroup) {
+    this.groupSort = groupSort;
+    this.topNGroups = topNGroups;
+    this.needsScores = needsScores;
+    this.lastDocPerGroup = lastDocPerGroup;
+    this.withinGroupSort = withinGroupSort;
+    this.groupOffset = groupOffset;
+    this.withinGroupOffset = withinGroupOffset;
+    this.maxDocsPerGroup = maxDocsPerGroup;
+  }
+
+  /**
+   * Creates a fresh collector. The manager keeps no reference to it: the collectors to combine
+   * are handed back through {@link #reduce(Collection)}.
+   */
+  @Override
+  public BlockGroupingCollector newCollector() throws IOException {
+    return new BlockGroupingCollector(groupSort, topNGroups, needsScores, lastDocPerGroup);
+  }
+
+  /**
+   * Combines the groups gathered by the given collectors.
+   *
+   * @param collectors the collectors that took part in the search
+   * @return the top groups starting at {@code groupOffset}; may be {@code null}, in particular
+   *     when there are no collectors or none of them produced a group
+   */
+  @Override
+  public TopGroups<?> reduce(Collection<BlockGroupingCollector> collectors) throws IOException {
+    if (collectors.isEmpty()) {
+      return null;
+    }
+
+    if (collectors.size() == 1) {
+      // Nothing to merge: the only collector applies the group offset itself.
+      return collectors
+          .iterator()
+          .next()
+          .getTopGroups(withinGroupSort, groupOffset, withinGroupOffset, maxDocsPerGroup);
+    }
+
+    // Each collector only tracks its best topNGroups groups, so the merged ranking is only
+    // reliable for the first topNGroups positions. After skipping groupOffset of them, at most
+    // topNGroups - groupOffset groups remain, which is also the most a single collector returns.
+    final int groupLimit = topNGroups - groupOffset;
+    if (groupLimit <= 0) {
+      return null;
+    }
+
+    // Ask every collector for all of its groups (offset 0); the offset is applied once, on the
+    // merged ranking.
+    List<TopGroups<?>> shardGroupsList = new ArrayList<>(collectors.size());
+    for (BlockGroupingCollector collector : collectors) {
+      TopGroups<?> topGroups =
+          collector.getTopGroups(withinGroupSort, 0, withinGroupOffset, maxDocsPerGroup);
+      if (topGroups != null) {
+        shardGroupsList.add(topGroups);
+      }
+    }
+
+    if (shardGroupsList.isEmpty()) {
+      return null;
+    }
+
+    return TopGroups.mergeBlockGroups(shardGroupsList, groupSort, groupOffset, groupLimit);
+  }
+}
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java
index 91bcbf56da84..bd0893dbfc38 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java
@@ -196,15 +196,18 @@ protected TopGroups> groupByDocBlock(
final Query endDocsQuery = searcher.rewrite(this.groupEndDocs);
final Weight groupEndDocs =
searcher.createWeight(endDocsQuery, ScoreMode.COMPLETE_NO_SCORES, 1);
- BlockGroupingCollector c =
- new BlockGroupingCollector(
+ BlockGroupingCollectorManager bcm =
+ new BlockGroupingCollectorManager(
groupSort,
topN,
groupSort.needsScores() || sortWithinGroup.needsScores(),
- groupEndDocs);
- searcher.search(query, c);
- int topNInsideGroup = groupDocsOffset + groupDocsLimit;
- return c.getTopGroups(sortWithinGroup, groupOffset, groupDocsOffset, topNInsideGroup);
+ groupEndDocs,
+ sortWithinGroup,
+ groupOffset,
+ groupDocsOffset,
+ groupDocsOffset + groupDocsLimit);
+
+ return searcher.search(query, bcm);
}
/**
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java
index 0a1c607ab78b..c87da5ce2af5 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java
@@ -16,6 +16,13 @@
*/
package org.apache.lucene.search.grouping;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.NavigableSet;
+import java.util.TreeSet;
+import org.apache.lucene.search.FieldComparator;
+import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
@@ -280,4 +287,123 @@ public static TopGroups merge(
totalMaxScore);
}
}
+
+ /**
+ * Entry of the merge queue used by mergeBlockGroups: the group sort values of the group at
+ * position groupIndex in the groups array of the shard at position shardIndex.
+ */
+ private record MergedBlockGroup(Object[] topValues, int shardIndex, int groupIndex) {}
+
+  /**
+   * Orders {@link MergedBlockGroup}s by their group sort values, applying each sort field's
+   * reverse flag, and falls back to the shard index when all sort values compare equal.
+   */
+  private static class GroupComparator implements Comparator<MergedBlockGroup> {
+    // One comparator per sort field; only compareValues() is called on them.
+    @SuppressWarnings("rawtypes")
+    private final FieldComparator[] comparators;
+
+    // -1 for reversed sort fields, 1 otherwise; multiplied into the per-field result.
+    private final int[] reversed;
+
+    /** Builds the per-field comparators and reverse multipliers for {@code groupSort}. */
+    @SuppressWarnings({"rawtypes"})
+    public GroupComparator(Sort groupSort) {
+      final SortField[] sortFields = groupSort.getSort();
+      comparators = new FieldComparator[sortFields.length];
+      reversed = new int[sortFields.length];
+      for (int compIDX = 0; compIDX < sortFields.length; compIDX++) {
+        final SortField sortField = sortFields[compIDX];
+        comparators[compIDX] = sortField.getComparator(1, Pruning.NONE);
+        reversed[compIDX] = sortField.getReverse() ? -1 : 1;
+      }
+    }
+
+    @Override
+    @SuppressWarnings({"unchecked"})
+    public int compare(MergedBlockGroup group, MergedBlockGroup other) {
+      if (group == other) {
+        return 0;
+      }
+      final Object[] groupValues = group.topValues();
+      final Object[] otherValues = other.topValues();
+      for (int compIDX = 0; compIDX < comparators.length; compIDX++) {
+        final int c =
+            reversed[compIDX]
+                * comparators[compIDX].compareValues(groupValues[compIDX], otherValues[compIDX]);
+        if (c != 0) {
+          return c;
+        }
+      }
+
+      // All sort values tie. Distinct entries are expected to come from different shards, so the
+      // shard index yields a strict order and a sorted set never treats two groups as duplicates.
+      assert group.shardIndex() != other.shardIndex();
+      return Integer.compare(group.shardIndex(), other.shardIndex());
+    }
+  }
+
+ /**
+ * Merge TopGroups that are partitioned into blocks per shard. This method assumes that within
+ * each shard, the groups are sorted according to the groupSort.
+ *
+ * @param shardGroups list of TopGroups, one per shard.
+ * @param groupSort The {@link Sort} used to sort the groups. The top sorted document within each
+ * group, according to groupSort, determines how that group sorts against other groups. This
+ * must be non-null, i.e., if you want to groupSort by relevance use Sort.RELEVANCE.
+ * @param groupOffset Which group to start from.
+ * @param topNGroups How many top groups to keep.
+ * @return TopGroups instance or null if there are no groups.
+ */
+ public static TopGroups> mergeBlockGroups(
+ List> shardGroups, Sort groupSort, int groupOffset, int topNGroups) {
+ if (shardGroups.isEmpty()) {
+ return null;
+ }
+
+ int totalGroupCount = 0;
+ int totalHitCount = 0;
+ int totalGroupedHitCount = 0;
+ for (TopGroups> sg : shardGroups) {
+ totalGroupCount += sg.totalGroupCount;
+ totalHitCount += sg.totalHitCount;
+ }
+
+ // k-way merge
+ GroupComparator groupComp = new GroupComparator(groupSort);
+ NavigableSet queue = new TreeSet<>(groupComp);
+
+ // init queue
+ for (int idx = 0; idx < shardGroups.size(); idx++) {
+ GroupDocs> firstGroupDocs = shardGroups.get(idx).groups[0];
+ queue.add(new MergedBlockGroup(firstGroupDocs.groupSortValues(), idx, 0));
+ }
+
+ float maxScore = shardGroups.get(queue.first().shardIndex).groups[0].maxScore();
+
+ final List> groupDocsList = new ArrayList<>();
+ int count = 0;
+ while (!queue.isEmpty()) {
+ final MergedBlockGroup mergedBlockGroup = queue.pollFirst();
+ TopGroups> shardGroup = shardGroups.get(mergedBlockGroup.shardIndex);
+
+ int currentGroupIndex = mergedBlockGroup.groupIndex;
+ GroupDocs> currentGroupDocs = shardGroup.groups[currentGroupIndex];
+ if (count++ >= groupOffset) {
+ groupDocsList.add(currentGroupDocs);
+ totalGroupedHitCount += currentGroupDocs.totalHits().value();
+ if (groupDocsList.size() == topNGroups) {
+ break;
+ }
+ }
+
+ int nextGroupIndex = currentGroupIndex + 1;
+ if (nextGroupIndex < shardGroup.groups.length) {
+ GroupDocs> nextGroupDocs = shardGroup.groups[nextGroupIndex];
+ queue.add(
+ new MergedBlockGroup(
+ nextGroupDocs.groupSortValues(), mergedBlockGroup.shardIndex, nextGroupIndex));
+ }
+ }
+
+ @SuppressWarnings({"unchecked"})
+ GroupDocs