Skip to content

Commit 599d95f

Browse files
committed
Add afterMerge() lifecycle hook to KnnVectorsWriter and PerFieldKnnVectorsFormat
Signed-off-by: Andrew Klepchick <aklepchi@amazon.com>
1 parent ac19f0a commit 599d95f

4 files changed

Lines changed: 281 additions & 8 deletions

File tree

lucene/CHANGES.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ http://s.apache.org/luceneversions
88
API Changes
99
---------------------
1010

11+
* GITHUB#15935: Add protected afterMerge() lifecycle hook to KnnVectorsWriter and
12+
PerFieldKnnVectorsFormat. Subclasses can override this to release merge-time resources such as
13+
thread pools. The method is called in a finally block within merge(), guaranteeing invocation
14+
even if mergeOneField throws an exception. (MrFlap)
15+
1116
* GITHUB#15763: Deprecate Operations.complement() method. This operation can be slow and is not
1217
recommended for production use. It will be removed in Lucene 12. (Saurabh Singh)
1318

lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,17 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
8282
/** Called once at the end before close */
8383
public abstract void finish() throws IOException;
8484

85+
/**
86+
* Called after all vector fields have been merged but before {@link #finish()} and {@link
87+
* #close()}. Subclasses can override this to release merge-time resources such as thread pools.
88+
* The default implementation is a no-op.
89+
*
90+
* <p>This method is guaranteed to be called even if {@link #mergeOneField} throws an exception.
91+
*
92+
* @throws IOException if an I/O error occurs during cleanup
93+
*/
94+
protected void afterMerge() throws IOException {}
95+
8596
/**
8697
* Merges the segment vectors for all fields. This default implementation delegates to {@link
8798
* #mergeOneField}, passing a {@link KnnVectorsReader} that combines the vector values and ignores
@@ -96,18 +107,22 @@ public final void merge(MergeState mergeState) throws IOException {
96107
}
97108
}
98109

99-
for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) {
100-
if (fieldInfo.hasVectorValues()) {
101-
if (mergeState.infoStream.isEnabled("VV")) {
102-
mergeState.infoStream.message("VV", "merging " + mergeState.segmentInfo);
103-
}
110+
try {
111+
for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) {
112+
if (fieldInfo.hasVectorValues()) {
113+
if (mergeState.infoStream.isEnabled("VV")) {
114+
mergeState.infoStream.message("VV", "merging " + mergeState.segmentInfo);
115+
}
104116

105-
mergeOneField(fieldInfo, mergeState);
117+
mergeOneField(fieldInfo, mergeState);
106118

107-
if (mergeState.infoStream.isEnabled("VV")) {
108-
mergeState.infoStream.message("VV", "merge done " + mergeState.segmentInfo);
119+
if (mergeState.infoStream.isEnabled("VV")) {
120+
mergeState.infoStream.message("VV", "merge done " + mergeState.segmentInfo);
121+
}
109122
}
110123
}
124+
} finally {
125+
afterMerge();
111126
}
112127
finish();
113128
}

lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,15 @@ public int getMaxDimensions(String fieldName) {
100100
*/
101101
public abstract KnnVectorsFormat getKnnVectorsFormatForField(String field);
102102

103+
/**
104+
* Called after all vector fields have been merged for a segment merge operation. Subclasses can
105+
* override this to release merge-time resources such as thread pools. The default implementation
106+
* is a no-op.
107+
*
108+
* @throws IOException if an I/O error occurs during cleanup
109+
*/
110+
protected void afterMerge() throws IOException {}
111+
103112
private class FieldsWriter extends KnnVectorsWriter {
104113
private final Map<KnnVectorsFormat, WriterAndSuffix> formats;
105114
private final Map<String, Integer> suffixes = new HashMap<>();
@@ -135,6 +144,11 @@ public void finish() throws IOException {
135144
}
136145
}
137146

147+
@Override
148+
protected void afterMerge() throws IOException {
149+
PerFieldKnnVectorsFormat.this.afterMerge();
150+
}
151+
138152
@Override
139153
public void close() throws IOException {
140154
IOUtils.close(formats.values());
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.lucene.codecs.perfield;
18+
19+
import java.io.IOException;
20+
import java.util.concurrent.atomic.AtomicInteger;
21+
import org.apache.lucene.codecs.Codec;
22+
import org.apache.lucene.codecs.FilterCodec;
23+
import org.apache.lucene.codecs.KnnVectorsFormat;
24+
import org.apache.lucene.document.Document;
25+
import org.apache.lucene.document.KnnFloatVectorField;
26+
import org.apache.lucene.index.IndexWriter;
27+
import org.apache.lucene.index.IndexWriterConfig;
28+
import org.apache.lucene.index.NoMergePolicy;
29+
import org.apache.lucene.store.Directory;
30+
import org.apache.lucene.tests.analysis.MockAnalyzer;
31+
import org.apache.lucene.tests.util.LuceneTestCase;
32+
import org.apache.lucene.tests.util.TestUtil;
33+
34+
/** Tests for the afterMerge() lifecycle hook on PerFieldKnnVectorsFormat. */
35+
public class TestPerFieldKnnVectorsFormatAfterMerge extends LuceneTestCase {
36+
37+
/** Writes numSegments single-doc segments with a vector field, using NoMergePolicy. */
38+
private void writeSegments(Directory dir, KnnVectorsFormat format, int numSegments)
39+
throws IOException {
40+
IndexWriterConfig iwc = new IndexWriterConfig(new MockAnalyzer(random()));
41+
iwc.setCodec(codecWithFormat(format));
42+
iwc.setMergePolicy(NoMergePolicy.INSTANCE);
43+
try (IndexWriter iw = new IndexWriter(dir, iwc)) {
44+
for (int i = 0; i < numSegments; i++) {
45+
Document doc = new Document();
46+
doc.add(new KnnFloatVectorField("field", new float[] {i, i + 1, i + 2}));
47+
iw.addDocument(doc);
48+
iw.commit();
49+
}
50+
}
51+
}
52+
53+
private static FilterCodec codecWithFormat(KnnVectorsFormat format) {
54+
Codec defaultCodec = TestUtil.getDefaultCodec();
55+
return new FilterCodec(defaultCodec.getName(), defaultCodec) {
56+
@Override
57+
public KnnVectorsFormat knnVectorsFormat() {
58+
return format;
59+
}
60+
};
61+
}
62+
63+
/** afterMerge() on the format must be called exactly once when a merge completes. */
64+
public void testAfterMergeCalledOnMerge() throws IOException {
65+
AtomicInteger afterMergeCount = new AtomicInteger();
66+
67+
PerFieldKnnVectorsFormat format =
68+
new PerFieldKnnVectorsFormat() {
69+
@Override
70+
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
71+
return TestUtil.getDefaultKnnVectorsFormat();
72+
}
73+
74+
@Override
75+
protected void afterMerge() {
76+
afterMergeCount.incrementAndGet();
77+
}
78+
};
79+
80+
try (Directory dir = newDirectory()) {
81+
writeSegments(dir, format, 3);
82+
assertEquals("afterMerge should not be called during flush", 0, afterMergeCount.get());
83+
84+
// Force merge triggers afterMerge
85+
IndexWriterConfig mergeConfig = new IndexWriterConfig(new MockAnalyzer(random()));
86+
mergeConfig.setCodec(codecWithFormat(format));
87+
try (IndexWriter iw = new IndexWriter(dir, mergeConfig)) {
88+
iw.forceMerge(1);
89+
}
90+
91+
assertEquals("afterMerge should be called exactly once per merge", 1, afterMergeCount.get());
92+
}
93+
}
94+
95+
/** afterMerge() must not be called during a normal flush (no merge). */
96+
public void testAfterMergeNotCalledOnFlush() throws IOException {
97+
AtomicInteger afterMergeCount = new AtomicInteger();
98+
99+
PerFieldKnnVectorsFormat format =
100+
new PerFieldKnnVectorsFormat() {
101+
@Override
102+
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
103+
return TestUtil.getDefaultKnnVectorsFormat();
104+
}
105+
106+
@Override
107+
protected void afterMerge() {
108+
afterMergeCount.incrementAndGet();
109+
}
110+
};
111+
112+
try (Directory dir = newDirectory()) {
113+
writeSegments(dir, format, 3);
114+
assertEquals("afterMerge must not be called on flush-only writes", 0, afterMergeCount.get());
115+
}
116+
}
117+
118+
/**
119+
* afterMerge() is called in a finally block, so it must fire even when mergeOneField throws. We
120+
* verify this by using a format that wraps the delegate writer to throw during merge, then
121+
* checking that afterMerge() was still invoked.
122+
*/
123+
public void testAfterMergeCalledEvenOnException() throws IOException {
124+
AtomicInteger afterMergeCount = new AtomicInteger();
125+
126+
PerFieldKnnVectorsFormat format =
127+
new PerFieldKnnVectorsFormat() {
128+
@Override
129+
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
130+
return new ThrowingOnMergeKnnVectorsFormat(TestUtil.getDefaultKnnVectorsFormat());
131+
}
132+
133+
@Override
134+
protected void afterMerge() {
135+
afterMergeCount.incrementAndGet();
136+
}
137+
};
138+
139+
try (Directory dir = newDirectory()) {
140+
// Write segments using a normal format so data is valid on disk
141+
PerFieldKnnVectorsFormat normalFormat =
142+
new PerFieldKnnVectorsFormat() {
143+
@Override
144+
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
145+
return TestUtil.getDefaultKnnVectorsFormat();
146+
}
147+
};
148+
writeSegments(dir, normalFormat, 3);
149+
150+
// Use SerialMergeScheduler so the merge runs on the calling thread and the
151+
// exception propagates directly (instead of being caught by ConcurrentMergeScheduler)
152+
IndexWriterConfig mergeConfig = new IndexWriterConfig(new MockAnalyzer(random()));
153+
mergeConfig.setCodec(codecWithFormat(format));
154+
mergeConfig.setMergeScheduler(new org.apache.lucene.index.SerialMergeScheduler());
155+
IndexWriter iw = new IndexWriter(dir, mergeConfig);
156+
try {
157+
expectThrows(IOException.class, () -> iw.forceMerge(1));
158+
} finally {
159+
// IndexWriter may be in a tragic state after the merge failure, so rollback
160+
try {
161+
iw.rollback();
162+
} catch (
163+
@SuppressWarnings("unused")
164+
Exception ignored) {
165+
// expected — writer may already be closed or in a tragic state
166+
}
167+
}
168+
169+
assertEquals(
170+
"afterMerge must be called even when mergeOneField throws", 1, afterMergeCount.get());
171+
}
172+
}
173+
174+
/**
175+
* A KnnVectorsFormat wrapper that delegates everything normally except mergeOneField, which always
176+
* throws IOException.
177+
*/
178+
private static class ThrowingOnMergeKnnVectorsFormat extends KnnVectorsFormat {
179+
private final KnnVectorsFormat delegate;
180+
181+
ThrowingOnMergeKnnVectorsFormat(KnnVectorsFormat delegate) {
182+
super(delegate.getName());
183+
this.delegate = delegate;
184+
}
185+
186+
@Override
187+
public org.apache.lucene.codecs.KnnVectorsWriter fieldsWriter(
188+
org.apache.lucene.index.SegmentWriteState state) throws IOException {
189+
org.apache.lucene.codecs.KnnVectorsWriter delegateWriter = delegate.fieldsWriter(state);
190+
return new org.apache.lucene.codecs.KnnVectorsWriter() {
191+
@Override
192+
public org.apache.lucene.codecs.KnnFieldVectorsWriter<?> addField(
193+
org.apache.lucene.index.FieldInfo fieldInfo) throws IOException {
194+
return delegateWriter.addField(fieldInfo);
195+
}
196+
197+
@Override
198+
public void flush(int maxDoc, org.apache.lucene.index.Sorter.DocMap sortMap)
199+
throws IOException {
200+
delegateWriter.flush(maxDoc, sortMap);
201+
}
202+
203+
@Override
204+
public void mergeOneField(
205+
org.apache.lucene.index.FieldInfo fieldInfo,
206+
org.apache.lucene.index.MergeState mergeState)
207+
throws IOException {
208+
throw new IOException("simulated merge failure");
209+
}
210+
211+
@Override
212+
public void finish() throws IOException {
213+
delegateWriter.finish();
214+
}
215+
216+
@Override
217+
public void close() throws IOException {
218+
delegateWriter.close();
219+
}
220+
221+
@Override
222+
public long ramBytesUsed() {
223+
return delegateWriter.ramBytesUsed();
224+
}
225+
};
226+
}
227+
228+
@Override
229+
public org.apache.lucene.codecs.KnnVectorsReader fieldsReader(
230+
org.apache.lucene.index.SegmentReadState state) throws IOException {
231+
return delegate.fieldsReader(state);
232+
}
233+
234+
@Override
235+
public int getMaxDimensions(String fieldName) {
236+
return delegate.getMaxDimensions(fieldName);
237+
}
238+
}
239+
}

0 commit comments

Comments
 (0)