diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 172a8ece3e43..16e8b26bb25f 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -115,6 +115,8 @@ Improvements
 
 * GITHUB#15225: Improve package documentation for org.apache.lucene.util. (Syed Mohammad Saad)
 
+* GITHUB#15936: Introduce BlockGroupingCollectorManager to parallelize search when using BlockGroupingCollector. (Binlong Gao)
+
 Optimizations
 ---------------------
 * GITHUB#15681, GITHUB#15833: Replace pre-sized array or empty array with lambda expression to call Collection#toArray. (Zhou Hui)
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java
index 1f6a473f0e65..b31b3b5e3034 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java
@@ -60,11 +60,6 @@
  *
  * @lucene.experimental
  */
-
-// TODO: TopGroups.merge() won't work with TopGroups returned by this collector, because
-// each block will be on a different shard. Add a specialized merge() static method
-// to this collector?
-
 public class BlockGroupingCollector extends SimpleCollector {
 
   private int[] pendingSubDocs;
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollectorManager.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollectorManager.java
new file mode 100644
index 000000000000..73dc69e00f36
--- /dev/null
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollectorManager.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search.grouping;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import org.apache.lucene.search.CollectorManager;
+import org.apache.lucene.search.Sort;
+import org.apache.lucene.search.Weight;
+
+/**
+ * A {@link CollectorManager} for {@link BlockGroupingCollector} that merges results from multiple
+ * collectors into a single {@link TopGroups}. This is intended for use with concurrent search,
+ * where each slice of the index is searched by a separate {@link BlockGroupingCollector}.
+ *
+ * <p>Documents must be indexed as blocks using {@link
+ * org.apache.lucene.index.IndexWriter#addDocuments IndexWriter.addDocuments()} or {@link
+ * org.apache.lucene.index.IndexWriter#updateDocuments IndexWriter.updateDocuments()}.
+ *
+ * <p>See {@link BlockGroupingCollector} for more details.
+ *
+ * @lucene.experimental
+ */
+public class BlockGroupingCollectorManager
+    implements CollectorManager<BlockGroupingCollector, TopGroups<?>> {
+
+  private final Sort groupSort;
+  private final int topNGroups;
+  private final boolean needsScores;
+  private final Weight lastDocPerGroup;
+
+  private final Sort withinGroupSort;
+  private final int groupOffset;
+  private final int withinGroupOffset;
+  private final int maxDocsPerGroup;
+
+  /**
+   * Creates a new manager.
+   *
+   * @param groupSort sort used to order groups; passed to every {@link BlockGroupingCollector}
+   * @param topNGroups number of top groups each collector keeps, including the {@code
+   *     groupOffset} groups that are skipped in the final result
+   * @param needsScores passed to every {@link BlockGroupingCollector}
+   * @param lastDocPerGroup passed to every {@link BlockGroupingCollector}
+   * @param withinGroupSort sort passed to {@link BlockGroupingCollector#getTopGroups}
+   * @param groupOffset number of leading groups to skip in the final result
+   * @param withinGroupOffset offset passed to {@link BlockGroupingCollector#getTopGroups}
+   * @param maxDocsPerGroup per-group document limit passed to {@link
+   *     BlockGroupingCollector#getTopGroups}
+   */
+  public BlockGroupingCollectorManager(
+      Sort groupSort,
+      int topNGroups,
+      boolean needsScores,
+      Weight lastDocPerGroup,
+      Sort withinGroupSort,
+      int groupOffset,
+      int withinGroupOffset,
+      int maxDocsPerGroup) {
+    this.groupSort = groupSort;
+    this.topNGroups = topNGroups;
+    this.needsScores = needsScores;
+    this.lastDocPerGroup = lastDocPerGroup;
+    this.withinGroupSort = withinGroupSort;
+    this.groupOffset = groupOffset;
+    this.withinGroupOffset = withinGroupOffset;
+    this.maxDocsPerGroup = maxDocsPerGroup;
+  }
+
+  /** Creates a new {@link BlockGroupingCollector} configured with this manager's settings. */
+  @Override
+  public BlockGroupingCollector newCollector() throws IOException {
+    return new BlockGroupingCollector(groupSort, topNGroups, needsScores, lastDocPerGroup);
+  }
+
+  /**
+   * Merges the per-collector results into one {@link TopGroups}.
+   *
+   * @return the merged groups, or {@code null} when there is nothing to merge or {@code
+   *     groupOffset} is not smaller than {@code topNGroups}
+   */
+  @Override
+  public TopGroups<?> reduce(Collection<BlockGroupingCollector> collectors) throws IOException {
+    if (collectors.isEmpty() || groupOffset >= topNGroups) {
+      return null;
+    }
+
+    if (collectors.size() == 1) {
+      return collectors
+          .iterator()
+          .next()
+          .getTopGroups(withinGroupSort, groupOffset, withinGroupOffset, maxDocsPerGroup);
+    }
+
+    // Merge results from multiple collectors. Each collector reports from offset 0 because the
+    // group offset can only be applied once the global order is known.
+    List<TopGroups<?>> shardGroupsList = new ArrayList<>();
+    for (BlockGroupingCollector collector : collectors) {
+      TopGroups<?> topGroups =
+          collector.getTopGroups(withinGroupSort, 0, withinGroupOffset, maxDocsPerGroup);
+      if (topGroups != null) {
+        shardGroupsList.add(topGroups);
+      }
+    }
+
+    if (shardGroupsList.isEmpty()) {
+      return null;
+    }
+
+    // topNGroups includes the skipped groups, so only (topNGroups - groupOffset) groups may be
+    // returned; handing topNGroups itself to the merge would return groupOffset groups too many.
+    int mergedLimit = topNGroups - groupOffset;
+    return TopGroups.mergeBlockGroups(shardGroupsList, groupSort, groupOffset, mergedLimit);
+  }
+}
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java
index 91bcbf56da84..bd0893dbfc38 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java
@@ -196,15 +196,18 @@ protected TopGroups<?> groupByDocBlock(
     final Query endDocsQuery = searcher.rewrite(this.groupEndDocs);
     final Weight groupEndDocs =
         searcher.createWeight(endDocsQuery, ScoreMode.COMPLETE_NO_SCORES, 1);
-    BlockGroupingCollector c =
-        new BlockGroupingCollector(
+    BlockGroupingCollectorManager bcm =
+        new BlockGroupingCollectorManager(
             groupSort,
             topN,
             groupSort.needsScores() || sortWithinGroup.needsScores(),
-            groupEndDocs);
-    searcher.search(query, c);
-    int topNInsideGroup = groupDocsOffset + groupDocsLimit;
-    return c.getTopGroups(sortWithinGroup, groupOffset, groupDocsOffset, topNInsideGroup);
+            groupEndDocs,
+            sortWithinGroup,
+            groupOffset,
+            groupDocsOffset,
+            groupDocsOffset + groupDocsLimit);
+
+    return searcher.search(query, bcm);
   }
 
   /**
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java
index 0a1c607ab78b..c87da5ce2af5 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java
@@ -16,6 +16,13 @@
  */
 package org.apache.lucene.search.grouping;
 
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.NavigableSet;
+import java.util.TreeSet;
+import org.apache.lucene.search.FieldComparator;
+import org.apache.lucene.search.Pruning;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.Sort;
 import
org.apache.lucene.search.SortField;
@@ -280,4 +287,149 @@ public static <T> TopGroups<T> merge(
           totalMaxScore);
     }
   }
+
+  /** Head group of one shard during the merge; {@code topValues} are its group sort values. */
+  private record MergedBlockGroup(Object[] topValues, int shardIndex, int groupIndex) {}
+
+  /**
+   * Orders {@link MergedBlockGroup}s by their group sort values, honoring each sort field's
+   * reverse flag, and breaks ties by shard index.
+   */
+  private static class GroupComparator implements Comparator<MergedBlockGroup> {
+    @SuppressWarnings("rawtypes")
+    public final FieldComparator[] comparators;
+
+    public final int[] reversed;
+
+    @SuppressWarnings({"rawtypes"})
+    public GroupComparator(Sort groupSort) {
+      final SortField[] sortFields = groupSort.getSort();
+      comparators = new FieldComparator[sortFields.length];
+      reversed = new int[sortFields.length];
+      for (int compIDX = 0; compIDX < sortFields.length; compIDX++) {
+        final SortField sortField = sortFields[compIDX];
+        // Only compareValues() is called on these comparators, so one slot is enough.
+        comparators[compIDX] = sortField.getComparator(1, Pruning.NONE);
+        reversed[compIDX] = sortField.getReverse() ? -1 : 1;
+      }
+    }
+
+    @Override
+    @SuppressWarnings({"unchecked"})
+    public int compare(MergedBlockGroup group, MergedBlockGroup other) {
+      if (group == other) {
+        return 0;
+      }
+      final Object[] groupValues = group.topValues;
+      final Object[] otherValues = other.topValues;
+      for (int compIDX = 0; compIDX < comparators.length; compIDX++) {
+        final int c =
+            reversed[compIDX]
+                * comparators[compIDX].compareValues(groupValues[compIDX], otherValues[compIDX]);
+        if (c != 0) {
+          return c;
+        }
+      }
+
+      // The merge queue holds at most one entry per shard, so the shard index settles any tie.
+      assert group.shardIndex != other.shardIndex;
+      return Integer.compare(group.shardIndex, other.shardIndex);
+    }
+  }
+
+  /**
+   * Merge TopGroups that are partitioned into blocks per shard, i.e. every group lives entirely on
+   * one shard. This method assumes that within each shard, the groups are sorted according to the
+   * groupSort.
+   *
+   * @param shardGroups list of TopGroups, one per shard. {@code null} entries and entries without
+   *     groups are skipped.
+   * @param groupSort The {@link Sort} used to sort the groups. The top sorted document within each
+   *     group according to groupSort, determines how that group sorts against other groups. This
+   *     must be non-null, ie, if you want to groupSort by relevance use Sort.RELEVANCE.
+   * @param groupOffset Which group to start from.
+   * @param topNGroups Maximum number of groups to return once {@code groupOffset} groups have been
+   *     skipped.
+   * @return TopGroups instance or null if there are no groups.
+   */
+  public static TopGroups<?> mergeBlockGroups(
+      List<TopGroups<?>> shardGroups, Sort groupSort, int groupOffset, int topNGroups) {
+    int totalHitCount = 0;
+    int totalGroupCount = 0;
+    boolean groupCountKnown = true;
+    // Shards that contributed at least one group; only these take part in the merge.
+    final List<TopGroups<?>> shards = new ArrayList<>(shardGroups.size());
+    for (TopGroups<?> sg : shardGroups) {
+      if (sg == null) {
+        continue;
+      }
+      totalHitCount += sg.totalHitCount;
+      if (sg.totalGroupCount == null) {
+        groupCountKnown = false;
+      } else {
+        totalGroupCount += sg.totalGroupCount;
+      }
+      if (sg.groups.length > 0) {
+        shards.add(sg);
+      }
+    }
+    if (shards.isEmpty()) {
+      return null;
+    }
+
+    // k-way merge: the queue holds the not-yet-consumed head group of every shard.
+    final NavigableSet<MergedBlockGroup> queue = new TreeSet<>(new GroupComparator(groupSort));
+    for (int idx = 0; idx < shards.size(); idx++) {
+      queue.add(new MergedBlockGroup(shards.get(idx).groups[0].groupSortValues(), idx, 0));
+    }
+
+    final List<GroupDocs<?>> groupDocsList = new ArrayList<>();
+    int totalGroupedHitCount = 0;
+    // Stays NaN unless a returned group carries a (non-NaN) max score.
+    float maxScore = Float.NaN;
+    int count = 0;
+    while (!queue.isEmpty() && groupDocsList.size() < topNGroups) {
+      final MergedBlockGroup head = queue.pollFirst();
+      final TopGroups<?> shardGroup = shards.get(head.shardIndex);
+      final GroupDocs<?> currentGroupDocs = shardGroup.groups[head.groupIndex];
+      if (count++ >= groupOffset) {
+        groupDocsList.add(currentGroupDocs);
+        totalGroupedHitCount += currentGroupDocs.totalHits().value();
+        // Track the best score over the returned groups rather than trusting the first group,
+        // which is only the best-scoring one when groupSort is by relevance.
+        final float groupMaxScore = currentGroupDocs.maxScore();
+        if (Float.isNaN(maxScore) || groupMaxScore > maxScore) {
+          maxScore = groupMaxScore;
+        }
+      }
+
+      // Replace the consumed head with the next group of the same shard, if any.
+      final int nextGroupIndex = head.groupIndex + 1;
+      if (nextGroupIndex < shardGroup.groups.length) {
+        queue.add(
+            new MergedBlockGroup(
+                shardGroup.groups[nextGroupIndex].groupSortValues(),
+                head.shardIndex,
+                nextGroupIndex));
+      }
+    }
+
+    if (groupDocsList.isEmpty()) {
+      // Every available group was skipped by groupOffset (or topNGroups is not positive).
+      return null;
+    }
+
+    @SuppressWarnings({"unchecked"})
+    GroupDocs<Object>[] groupDocs = (GroupDocs<Object>[]) groupDocsList.toArray(GroupDocs[]::new);
+
+    return new TopGroups<>(
+        new TopGroups<>(
+            shards.get(0).groupSort,
+            shards.get(0).withinGroupSort,
+            totalHitCount,
+            totalGroupedHitCount,
+            groupDocs,
+            maxScore),
+        groupCountKnown ? Integer.valueOf(totalGroupCount) : null);
+  }
 }
diff --git a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestBlockGrouping.java b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestBlockGrouping.java
index 18c02422f840..ad5561710228 100644
--- a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestBlockGrouping.java
+++ b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestBlockGrouping.java
@@ -77,6 +77,65 @@ public void testSimple() throws IOException {
     shard.close();
   }
 
+  public void testShardedBlockGrouping() throws IOException {
+    Shard shardControl = new Shard();
+    int shardCount = random().nextInt(3) + 2;
+    Shard[] shards = new Shard[shardCount];
+    final int bookCount = 5;
+    final int chapterCount = 2;
+    for (int shardIdx = 0; shardIdx < shardCount; shardIdx++) {
+      shards[shardIdx] = new Shard();
+      for (int bookIdx = 0; bookIdx < bookCount; bookIdx++) {
+        List<Document> block = new ArrayList<>();
+        String bookName = "book" + shardIdx + bookIdx;
+        for (int j = 0; j < chapterCount; j++) {
+          Document doc = new Document();
+          String chapterName = "chapter" + j;
+          String chapterText = randomText();
+          doc.add(new TextField("book", bookName, Field.Store.YES));
+          doc.add(new TextField("chapter", chapterName, Field.Store.YES));
+          doc.add(new TextField("text", chapterText, Field.Store.NO));
+          doc.add(new NumericDocValuesField("length", chapterText.length()));
+          doc.add(new SortedDocValuesField("book", new BytesRef(bookName)));
+          // The last chapter of every book marks the end of the block.
+          if (j == chapterCount - 1) {
+            doc.add(new TextField("blockEnd", "true", Field.Store.NO));
+          }
+          block.add(doc);
+        }
+        // Index the same block into its own shard and into the single-index control.
+        shards[shardIdx].writer.addDocuments(block);
+        shardControl.writer.addDocuments(block);
+      }
+    }
+
+    IndexSearcher shardControlIndexSearcher = shardControl.getIndexSearcher();
+
+    Query blockEndQuery = new TermQuery(new Term("blockEnd", "true"));
+    GroupingSearch grouper = new GroupingSearch(blockEndQuery);
+    grouper.setGroupDocsLimit(10);
+
+    Query topLevel = new TermQuery(new Term("text", "grandmother"));
+    TopGroups<?> singleShardTopGroups = grouper.search(shardControlIndexSearcher, topLevel, 0, 5);
+
+    List<TopGroups<?>> shardTopGroups = new ArrayList<>();
+    for (int shardIdx = 0; shardIdx < shardCount; shardIdx++) {
+      shardTopGroups.add(grouper.search(shards[shardIdx].getIndexSearcher(), topLevel, 0, 5));
+    }
+
+    TopGroups<?> mergedTopGroups = TopGroups.mergeBlockGroups(shardTopGroups, Sort.RELEVANCE, 0, 5);
+    assertNotNull(mergedTopGroups);
+
+    assertEquals(singleShardTopGroups.totalHitCount, mergedTopGroups.totalHitCount);
+    assertEquals(singleShardTopGroups.totalGroupCount, mergedTopGroups.totalGroupCount);
+    assertEquals(singleShardTopGroups.groups.length, mergedTopGroups.groups.length);
+
+    shardControl.close();
+    for (int shardIdx = 0; shardIdx < shardCount; shardIdx++) {
+      shards[shardIdx].close();
+    }
+  }
+
   public void testTopLevelSort() throws IOException {
 
     Shard shard = new Shard();