diff --git a/server/src/main/java/org/elasticsearch/search/fetch/subphase/MatchedQueriesPhase.java b/server/src/main/java/org/elasticsearch/search/fetch/subphase/MatchedQueriesPhase.java index bf9a5a1f944f9..165d6de41d457 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/subphase/MatchedQueriesPhase.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/subphase/MatchedQueriesPhase.java @@ -12,6 +12,10 @@ import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; +import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.util.Bits; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.search.fetch.FetchContext; @@ -46,7 +50,7 @@ public FetchSubPhaseProcessor getProcessor(FetchContext context) throws IOExcept } return new FetchSubPhaseProcessor() { - final Map matchingIterators = new HashMap<>(); + final Map matchingIterators = new HashMap<>(); @Override public void setNextReader(LeafReaderContext readerContext) throws IOException { @@ -54,23 +58,79 @@ public void setNextReader(LeafReaderContext readerContext) throws IOException { for (Map.Entry entry : weights.entrySet()) { ScorerSupplier ss = entry.getValue().scorerSupplier(readerContext); if (ss != null) { - Bits matchingBits = Lucene.asSequentialAccessBits(readerContext.reader().maxDoc(), ss); - matchingIterators.put(entry.getKey(), matchingBits); + Scorer scorer = ss.get(0L); + if (scorer != null) { + final TwoPhaseIterator twoPhase = scorer.twoPhaseIterator(); + final DocIdSetIterator iterator; + if (twoPhase == null) { + iterator = scorer.iterator(); + } else { + iterator = twoPhase.approximation(); + } + matchingIterators.put(entry.getKey(), new ScorerAndIterator(scorer, iterator, twoPhase)); + } } } } @Override - public void process(HitContext hitContext) { + public void process(HitContext hitContext) throws IOException{ List matches = new ArrayList<>(); int doc = hitContext.docId(); - for (Map.Entry iterator : matchingIterators.entrySet()) { - if (iterator.getValue().get(doc)) { - matches.add(iterator.getKey()); + for (Map.Entry entry : matchingIterators.entrySet()) { + ScorerAndIterator query = entry.getValue(); + if (query.approximation.docID() < doc) { + query.approximation.advance(doc); + } + if (query.approximation.docID() == doc && (query.twoPhase == null || query.twoPhase.matches())) { + matches.add(entry.getKey()); } } hitContext.hit().matchedQueries(matches.toArray(new String[0])); } }; } + public class ScorerAndIterator { + private final Scorer scorer; + private final DocIdSetIterator approximation; + private final TwoPhaseIterator twoPhase; + public ScorerAndIterator(Scorer scorer, DocIdSetIterator approximation, TwoPhaseIterator twoPhase) { + this.scorer = scorer; + this.approximation = approximation; + this.twoPhase = twoPhase; + } + public Scorer getScorer() { + return scorer; + } + public DocIdSetIterator getApproximation() { + return approximation; + } + public TwoPhaseIterator getTwoPhase() { + return twoPhase; + } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ScorerAndIterator that = (ScorerAndIterator) o; + if (!scorer.equals(that.scorer)) return false; + if (!approximation.equals(that.approximation)) return false; + return twoPhase.equals(that.twoPhase); + } + @Override + public int hashCode() { + int result = scorer.hashCode(); + result = 31 * result + approximation.hashCode(); + result = 31 * result + twoPhase.hashCode(); + return result; + } + @Override + public String toString() { + return "ScorerAndIterator{" + + "scorer=" + scorer + + ", approximation=" + approximation + + ", twoPhase=" + twoPhase + + '}'; + } + } }