Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ Improvements

* GITHUB#15225: Improve package documentation for org.apache.lucene.util. (Syed Mohammad Saad)

* GITHUB#15936: Introduce BlockGroupingCollectorManager to parallelize search when using BlockGroupingCollector. (Binlong Gao)

Optimizations
---------------------
* GITHUB#15681, GITHUB#15833: Replace pre-sized array or empty array with lambda expression to call Collection#toArray. (Zhou Hui)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,6 @@
*
* @lucene.experimental
*/

// TODO: TopGroups.merge() won't work with TopGroups returned by this collector, because
// each block will be on a different shard. Add a specialized merge() static method
// to this collector?

public class BlockGroupingCollector extends SimpleCollector {

private int[] pendingSubDocs;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.grouping;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.Weight;

/**
* A {@link CollectorManager} for {@link BlockGroupingCollector} that merges results from multiple
* collectors into a single {@link TopGroups}. This is intended for use with concurrent search,
* where each segment is searched by a separate {@link BlockGroupingCollector}.
*
* <p>Documents must be indexed as blocks using {@link
* org.apache.lucene.index.IndexWriter#addDocuments IndexWriter.addDocuments()} or {@link
* org.apache.lucene.index.IndexWriter#updateDocuments IndexWriter.updateDocuments()}.
*
* <p>See {@link BlockGroupingCollector} for more details.
*
* @lucene.experimental
*/
public class BlockGroupingCollectorManager
    implements CollectorManager<BlockGroupingCollector, TopGroups<?>> {

  // Parameters forwarded to each per-slice BlockGroupingCollector:
  private final Sort groupSort;
  private final int topNGroups;
  private final boolean needsScores;
  private final Weight lastDocPerGroup;

  // Parameters used only at reduce() time, when extracting TopGroups from each collector:
  private final Sort withinGroupSort;
  private final int groupOffset;
  private final int withinGroupOffset;
  private final int maxDocsPerGroup;

  /**
   * Creates a manager that runs one {@link BlockGroupingCollector} per search slice and merges
   * their results.
   *
   * @param groupSort the {@link Sort} order for the groups
   * @param topNGroups how many top groups to keep
   * @param needsScores true if the resulting TopGroups should keep scores
   * @param lastDocPerGroup a {@link Weight} that marks the last document in each block
   * @param withinGroupSort the {@link Sort} order for documents within each group
   * @param groupOffset which group to start from in the merged result
   * @param withinGroupOffset which document to start from within each group
   * @param maxDocsPerGroup how many top documents to keep within each group
   */
  public BlockGroupingCollectorManager(
      Sort groupSort,
      int topNGroups,
      boolean needsScores,
      Weight lastDocPerGroup,
      Sort withinGroupSort,
      int groupOffset,
      int withinGroupOffset,
      int maxDocsPerGroup) {
    this.groupSort = groupSort;
    this.topNGroups = topNGroups;
    this.needsScores = needsScores;
    this.lastDocPerGroup = lastDocPerGroup;
    this.withinGroupSort = withinGroupSort;
    this.groupOffset = groupOffset;
    this.withinGroupOffset = withinGroupOffset;
    this.maxDocsPerGroup = maxDocsPerGroup;
  }

  @Override
  public BlockGroupingCollector newCollector() throws IOException {
    // No shared state is mutated here: newCollector() may be called concurrently, and reduce()
    // receives the collectors it must merge as an argument.
    return new BlockGroupingCollector(groupSort, topNGroups, needsScores, lastDocPerGroup);
  }

  @Override
  public TopGroups<?> reduce(Collection<BlockGroupingCollector> collectors) throws IOException {
    if (collectors.isEmpty()) {
      return null;
    }

    if (collectors.size() == 1) {
      // Single collector: apply the group offset directly, no merge needed.
      return collectors
          .iterator()
          .next()
          .getTopGroups(withinGroupSort, groupOffset, withinGroupOffset, maxDocsPerGroup);
    }

    // Multiple collectors: extract each collector's top groups with a zero group offset (the
    // offset is applied once, globally, during the merge), then merge them.
    List<TopGroups<?>> shardGroupsList = new ArrayList<>(collectors.size());
    for (BlockGroupingCollector collector : collectors) {
      TopGroups<?> topGroups =
          collector.getTopGroups(withinGroupSort, 0, withinGroupOffset, maxDocsPerGroup);
      if (topGroups != null) {
        shardGroupsList.add(topGroups);
      }
    }

    if (shardGroupsList.isEmpty()) {
      return null;
    }

    return TopGroups.mergeBlockGroups(shardGroupsList, groupSort, groupOffset, topNGroups);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -196,15 +196,18 @@ protected TopGroups<?> groupByDocBlock(
final Query endDocsQuery = searcher.rewrite(this.groupEndDocs);
final Weight groupEndDocs =
searcher.createWeight(endDocsQuery, ScoreMode.COMPLETE_NO_SCORES, 1);
BlockGroupingCollector c =
new BlockGroupingCollector(
BlockGroupingCollectorManager bcm =
new BlockGroupingCollectorManager(
groupSort,
topN,
groupSort.needsScores() || sortWithinGroup.needsScores(),
groupEndDocs);
searcher.search(query, c);
int topNInsideGroup = groupDocsOffset + groupDocsLimit;
return c.getTopGroups(sortWithinGroup, groupOffset, groupDocsOffset, topNInsideGroup);
groupEndDocs,
sortWithinGroup,
groupOffset,
groupDocsOffset,
groupDocsOffset + groupDocsLimit);

return searcher.search(query, bcm);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
*/
package org.apache.lucene.search.grouping;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.NavigableSet;
import java.util.TreeSet;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
Expand Down Expand Up @@ -280,4 +287,123 @@ public static <T> TopGroups<T> merge(
totalMaxScore);
}
}

/**
 * Cursor into one shard's group list during the k-way merge: {@code topValues} holds the
 * group-sort values of the group at position {@code groupIndex} within shard {@code shardIndex}.
 */
private record MergedBlockGroup(Object[] topValues, int shardIndex, int groupIndex) {}

/**
 * Compares two {@link MergedBlockGroup} cursors by the group-sort values of their current groups,
 * breaking exact ties by shard index so the ordering is a total order (required by the TreeSet
 * driving the k-way merge, which would otherwise silently drop "equal" cursors).
 */
private static class GroupComparator implements Comparator<MergedBlockGroup> {
  /** One comparator per group-sort field; used only via compareValues(). */
  @SuppressWarnings("rawtypes")
  public final FieldComparator[] comparators;

  /** 1 for ascending fields, -1 for descending (reverse) fields. */
  public final int[] reversed;

  @SuppressWarnings({"rawtypes"})
  public GroupComparator(Sort groupSort) {
    final SortField[] sortFields = groupSort.getSort();
    comparators = new FieldComparator[sortFields.length];
    reversed = new int[sortFields.length];
    for (int compIDX = 0; compIDX < sortFields.length; compIDX++) {
      final SortField sortField = sortFields[compIDX];
      // Size 1: these comparators only compare extracted sort values, they never collect hits.
      comparators[compIDX] = sortField.getComparator(1, Pruning.NONE);
      reversed[compIDX] = sortField.getReverse() ? -1 : 1;
    }
  }

  @Override
  @SuppressWarnings({"unchecked"})
  public int compare(MergedBlockGroup group, MergedBlockGroup other) {
    if (group == other) {
      return 0;
    }
    final Object[] groupValues = group.topValues;
    final Object[] otherValues = other.topValues;
    for (int compIDX = 0; compIDX < comparators.length; compIDX++) {
      final int c =
          reversed[compIDX]
              * comparators[compIDX].compareValues(groupValues[compIDX], otherValues[compIDX]);
      if (c != 0) {
        return c;
      }
    }

    // The merge queue holds at most one cursor per shard at a time, so distinct cursors always
    // come from distinct shards; use the shard index as an overflow-safe tie-break.
    assert group.shardIndex != other.shardIndex;
    return Integer.compare(group.shardIndex, other.shardIndex);
  }
}

/**
 * Merge TopGroups that are partitioned into blocks per shard. This method assumes that within
 * each shard, the groups are sorted according to the groupSort.
 *
 * @param shardGroups list of TopGroups, one per shard.
 * @param groupSort The {@link Sort} used to sort the groups. The top sorted document within each
 *     group according to groupSort, determines how that group sorts against other groups. This
 *     must be non-null, ie, if you want to groupSort by relevance use Sort.RELEVANCE.
 * @param groupOffset Which group to start from.
 * @param topNGroups How many top groups to keep.
 * @return TopGroups instance or null if there are no groups.
 */
public static TopGroups<?> mergeBlockGroups(
    List<TopGroups<?>> shardGroups, Sort groupSort, int groupOffset, int topNGroups) {
  if (shardGroups.isEmpty()) {
    return null;
  }

  // Aggregate shard-level counts up front. NOTE(review): TopGroups.totalGroupCount is an
  // Integer; the += below NPEs on unboxing if any shard's count is null — confirm every input
  // comes from BlockGroupingCollector.getTopGroups(), which always sets it.
  int totalGroupCount = 0;
  int totalHitCount = 0;
  int totalGroupedHitCount = 0;
  for (TopGroups<?> sg : shardGroups) {
    totalGroupCount += sg.totalGroupCount;
    totalHitCount += sg.totalHitCount;
  }

  // k-way merge: keep one cursor per shard in a sorted set, always consuming the smallest.
  GroupComparator groupComp = new GroupComparator(groupSort);
  NavigableSet<MergedBlockGroup> queue = new TreeSet<>(groupComp);

  // init queue with each shard's first (best) group.
  // NOTE(review): assumes every shard TopGroups has a non-empty groups array — verify that a
  // non-null getTopGroups() result can never carry zero groups.
  for (int idx = 0; idx < shardGroups.size(); idx++) {
    GroupDocs<?> firstGroupDocs = shardGroups.get(idx).groups[0];
    queue.add(new MergedBlockGroup(firstGroupDocs.groupSortValues(), idx, 0));
  }

  // NOTE(review): this is the max score of the single overall-best group only, not the max over
  // all merged groups — confirm this matches the single-collector behavior callers expect.
  float maxScore = shardGroups.get(queue.first().shardIndex).groups[0].maxScore();

  final List<GroupDocs<?>> groupDocsList = new ArrayList<>();
  int count = 0;
  while (!queue.isEmpty()) {
    final MergedBlockGroup mergedBlockGroup = queue.pollFirst();
    TopGroups<?> shardGroup = shardGroups.get(mergedBlockGroup.shardIndex);

    int currentGroupIndex = mergedBlockGroup.groupIndex;
    GroupDocs<?> currentGroupDocs = shardGroup.groups[currentGroupIndex];
    // Skip the first groupOffset merged groups, then collect until topNGroups are kept.
    if (count++ >= groupOffset) {
      groupDocsList.add(currentGroupDocs);
      totalGroupedHitCount += currentGroupDocs.totalHits().value();
      if (groupDocsList.size() == topNGroups) {
        break;
      }
    }

    // Advance this shard's cursor to its next group, if any remain.
    int nextGroupIndex = currentGroupIndex + 1;
    if (nextGroupIndex < shardGroup.groups.length) {
      GroupDocs<?> nextGroupDocs = shardGroup.groups[nextGroupIndex];
      queue.add(
          new MergedBlockGroup(
              nextGroupDocs.groupSortValues(), mergedBlockGroup.shardIndex, nextGroupIndex));
    }
  }

  @SuppressWarnings({"unchecked"})
  GroupDocs<Object>[] groupDocs = (GroupDocs<Object>[]) groupDocsList.toArray(GroupDocs[]::new);

  // Wrap with the two-arg constructor so totalGroupCount is carried on the merged result.
  return new TopGroups<>(
      new TopGroups<>(
          shardGroups.getFirst().groupSort,
          shardGroups.getFirst().withinGroupSort,
          totalHitCount,
          totalGroupedHitCount,
          groupDocs,
          maxScore),
      totalGroupCount);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,64 @@ public void testSimple() throws IOException {
shard.close();
}

/**
 * Indexes identical doc blocks into a single "control" shard and into several separate shards,
 * then verifies that merging the per-shard TopGroups matches the single-shard result.
 */
public void testShardedBlockGrouping() throws IOException {
  Shard shardControl = new Shard();
  int shardCount = random().nextInt(3) + 2;
  Shard[] shards = new Shard[shardCount];
  for (int shardIdx = 0; shardIdx < shardCount; shardIdx++) {
    shards[shardIdx] = new Shard();
    for (int bookIdx = 0; bookIdx < 5; bookIdx++) {
      List<Document> block = new ArrayList<>();
      String bookName = "book" + shardIdx + bookIdx;
      int chapterCount = 2;
      for (int j = 0; j < chapterCount; j++) {
        Document doc = new Document();
        String chapterName = "chapter" + j;
        String chapterText = randomText();
        doc.add(new TextField("book", bookName, Field.Store.YES));
        doc.add(new TextField("chapter", chapterName, Field.Store.YES));
        doc.add(new TextField("text", chapterText, Field.Store.NO));
        doc.add(new NumericDocValuesField("length", chapterText.length()));
        doc.add(new SortedDocValuesField("book", new BytesRef(bookName)));
        // Mark the last chapter of the block so the block-end query finds group boundaries.
        if (j == chapterCount - 1) {
          doc.add(new TextField("blockEnd", "true", Field.Store.NO));
        }
        block.add(doc);
      }
      shards[shardIdx].writer.addDocuments(block);
      shardControl.writer.addDocuments(block);
    }
  }

  IndexSearcher shardControlIndexSearcher = shardControl.getIndexSearcher();

  Query blockEndQuery = new TermQuery(new Term("blockEnd", "true"));
  GroupingSearch grouper = new GroupingSearch(blockEndQuery);
  grouper.setGroupDocsLimit(10);

  Query topLevel = new TermQuery(new Term("text", "grandmother"));
  TopGroups<?> singleShardTopGroups = grouper.search(shardControlIndexSearcher, topLevel, 0, 5);

  List<TopGroups<?>> shardTopGroups = new ArrayList<>();
  for (int shardIdx = 0; shardIdx < shardCount; shardIdx++) {
    shardTopGroups.add(grouper.search(shards[shardIdx].getIndexSearcher(), topLevel, 0, 5));
  }

  TopGroups<?> mergedTopGroups = TopGroups.mergeBlockGroups(shardTopGroups, Sort.RELEVANCE, 0, 5);
  assertNotNull(mergedTopGroups);

  // The merged result must agree with the single-shard control on the aggregate statistics.
  assertEquals(singleShardTopGroups.totalHitCount, mergedTopGroups.totalHitCount);
  assertEquals(singleShardTopGroups.totalGroupCount, mergedTopGroups.totalGroupCount);
  assertEquals(singleShardTopGroups.groups.length, mergedTopGroups.groups.length);

  shardControl.close();
  for (int shardIdx = 0; shardIdx < shardCount; shardIdx++) {
    shards[shardIdx].close();
  }
}

public void testTopLevelSort() throws IOException {

Shard shard = new Shard();
Expand Down
Loading