Skip to content

Commit 8637158

Browse files
Support recursive updates and compute interactions
Co-authored-by: martinfrancois <[email protected]>
1 parent 3fdd74c commit 8637158

File tree

1 file changed

+147
-61
lines changed

1 file changed

+147
-61
lines changed

junit-platform-engine/src/main/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStore.java

Lines changed: 147 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,14 @@
2222
import java.util.ArrayList;
2323
import java.util.Comparator;
2424
import java.util.List;
25+
import java.util.Objects;
2526
import java.util.Optional;
2627
import java.util.concurrent.ConcurrentHashMap;
2728
import java.util.concurrent.ConcurrentMap;
29+
import java.util.concurrent.ExecutionException;
30+
import java.util.concurrent.FutureTask;
2831
import java.util.concurrent.atomic.AtomicInteger;
32+
import java.util.function.BiFunction;
2933
import java.util.function.Function;
3034
import java.util.function.Supplier;
3135

@@ -134,8 +138,8 @@ public void close() {
134138
if (this.closeAction != null) {
135139
List<Throwable> failures = new ArrayList<>();
136140
this.storedValues.entrySet().stream() //
137-
.map(e -> e.getValue().evaluateSafely(e.getKey())) //
138-
.filter(it -> it != null && it.value != null) //
141+
.map(e -> EvaluatedValue.createSafely(e.getKey(), e.getValue())) //
142+
.filter(Objects::nonNull) //
139143
.sorted(EvaluatedValue.REVERSE_INSERT_ORDER) //
140144
.forEach(it -> {
141145
try {
@@ -213,11 +217,22 @@ public void close() {
213217
CompositeKey<N> compositeKey = new CompositeKey<>(namespace, key);
214218
StoredValue storedValue = getStoredValue(compositeKey);
215219
if (storedValue == null) {
216-
storedValue = this.storedValues.computeIfAbsent(compositeKey,
217-
__ -> newStoredValue(new MemoizingSupplier(() -> {
220+
var newStoredValue = this.storedValues.compute(compositeKey, //
221+
(__, oldStoredValue) -> {
222+
if (isPresent(oldStoredValue)) {
223+
return oldStoredValue;
224+
}
218225
rejectIfClosed();
219-
return defaultCreator.apply(key);
220-
})));
226+
return newStoredSuppliedNullableValue(new DeferredSupplier(() -> {
227+
rejectIfClosed();
228+
return defaultCreator.apply(key);
229+
}));
230+
});
231+
232+
if (newStoredValue instanceof StoredValue.DeferredValue value) {
233+
value.delegate().run();
234+
}
235+
return requireNonNull(newStoredValue.evaluate());
221236
}
222237
return storedValue.evaluate();
223238
}
@@ -247,15 +262,20 @@ public <K, V> Object computeIfAbsent(N namespace, K key, Function<? super K, ? e
247262
StoredValue newStoredValue = this.storedValues.compute(compositeKey, (__, oldStoredValue) -> {
248263
if (StoredValue.evaluateIfNotNull(oldStoredValue) == null) {
249264
rejectIfClosed();
250-
var computedValue = Preconditions.notNull(defaultCreator.apply(key),
251-
"defaultCreator must not return null");
252-
return newStoredValue(() -> {
265+
return newStoredSuppliedValue(new DeferredSupplier(() -> {
253266
rejectIfClosed();
254-
return computedValue;
255-
});
267+
return Preconditions.notNull(defaultCreator.apply(key), "defaultCreator must not return null");
268+
}));
256269
}
257270
return oldStoredValue;
258271
});
272+
273+
if (newStoredValue instanceof StoredValue.DeferredOptionalValue value) {
274+
var delegate = value.delegate();
275+
delegate.run();
276+
return requireNonNull(delegate.getOrThrow());
277+
}
278+
// put or getOrComputeIfAbsent won the race
259279
return requireNonNull(newStoredValue.evaluate());
260280
}
261281
return result;
@@ -328,7 +348,7 @@ public <K, V> V computeIfAbsent(N namespace, K key, Function<? super K, ? extend
328348
public @Nullable Object put(N namespace, Object key, @Nullable Object value)
329349
throws NamespacedHierarchicalStoreException {
330350
rejectIfClosed();
331-
StoredValue oldValue = this.storedValues.put(new CompositeKey<>(namespace, key), newStoredValue(() -> value));
351+
StoredValue oldValue = this.storedValues.put(new CompositeKey<>(namespace, key), newStoredValue(value));
332352
return StoredValue.evaluateIfNotNull(oldValue);
333353
}
334354

@@ -372,13 +392,24 @@ public <K, V> V computeIfAbsent(N namespace, K key, Function<? super K, ? extend
372392
return castToRequiredType(key, value, requiredType);
373393
}
374394

375-
private StoredValue newStoredValue(Supplier<@Nullable Object> value) {
376-
return new StoredValue(this.insertOrderSequence.getAndIncrement(), value);
395+
private StoredValue.Value newStoredValue(@Nullable Object value) {
396+
var sequenceNumber = insertOrderSequence.getAndIncrement();
397+
return new StoredValue.Value(sequenceNumber, value);
398+
}
399+
400+
private StoredValue.DeferredValue newStoredSuppliedNullableValue(DeferredSupplier supplier) {
401+
var sequenceNumber = insertOrderSequence.getAndIncrement();
402+
return new StoredValue.DeferredValue(sequenceNumber, supplier);
403+
}
404+
405+
private StoredValue.DeferredOptionalValue newStoredSuppliedValue(DeferredSupplier supplier) {
406+
var sequenceNumber = insertOrderSequence.getAndIncrement();
407+
return new StoredValue.DeferredOptionalValue(sequenceNumber, supplier);
377408
}
378409

379410
private @Nullable StoredValue getStoredValue(CompositeKey<N> compositeKey) {
380411
StoredValue storedValue = this.storedValues.get(compositeKey);
381-
if (storedValue != null) {
412+
if (isPresent(storedValue)) {
382413
return storedValue;
383414
}
384415
if (this.parentStore != null) {
@@ -387,6 +418,10 @@ private StoredValue newStoredValue(Supplier<@Nullable Object> value) {
387418
return null;
388419
}
389420

421+
private static boolean isPresent(@Nullable StoredValue value) {
422+
return value != null && value.isPresent();
423+
}
424+
390425
private <T> @Nullable T castToRequiredType(Object key, @Nullable Object value, Class<T> requiredType) {
391426
Preconditions.notNull(requiredType, "requiredType must not be null");
392427
if (value == null) {
@@ -425,89 +460,140 @@ private record CompositeKey<N>(N namespace, Object key) {
425460

426461
}
427462

428-
private record StoredValue(int order, Supplier<@Nullable Object> supplier) {
463+
private interface StoredValue {
429464

430-
private <N> @Nullable EvaluatedValue<N> evaluateSafely(CompositeKey<N> compositeKey) {
431-
try {
432-
return new EvaluatedValue<>(compositeKey, this.order, evaluate());
465+
int order();
466+
467+
@Nullable
468+
Object evaluate();
469+
470+
boolean isPresent();
471+
472+
static @Nullable Object evaluateIfNotNull(@Nullable StoredValue value) {
473+
return value != null ? value.evaluate() : null;
474+
}
475+
476+
record Value(int order, @Nullable Object value) implements StoredValue {
477+
478+
@Override
479+
public @Nullable Object evaluate() {
480+
return value;
433481
}
434-
catch (Throwable t) {
435-
UnrecoverableExceptions.rethrowIfUnrecoverable(t);
436-
return null;
482+
483+
@Override
484+
public boolean isPresent() {
485+
return true;
437486
}
438487
}
439488

440-
private @Nullable Object evaluate() {
441-
return this.supplier.get();
442-
}
489+
record DeferredValue(int order, DeferredSupplier delegate) implements StoredValue {
443490

444-
static @Nullable Object evaluateIfNotNull(@Nullable StoredValue value) {
445-
return value != null ? value.evaluate() : null;
491+
@Override
492+
public @Nullable Object evaluate() {
493+
return delegate.getOrThrow();
494+
}
495+
496+
@Override
497+
public boolean isPresent() {
498+
return true;
499+
}
446500
}
447501

502+
record DeferredOptionalValue(int order, DeferredSupplier delegate) implements StoredValue {
503+
504+
@Override
505+
public @Nullable Object evaluate() {
506+
return delegate.get();
507+
}
508+
509+
@Override
510+
public boolean isPresent() {
511+
return evaluate() != null;
512+
}
513+
}
448514
}
449515

450-
private record EvaluatedValue<N>(CompositeKey<N> compositeKey, int order, @Nullable Object value) {
516+
private record EvaluatedValue<N>(CompositeKey<N> compositeKey, int order, Object value) {
517+
518+
private static <N> @Nullable EvaluatedValue<N> createSafely(CompositeKey<N> compositeKey, StoredValue value) {
519+
try {
520+
var evaluatedValue = value.evaluate();
521+
if (evaluatedValue == null) {
522+
return null;
523+
}
524+
return new EvaluatedValue<>(compositeKey, value.order(), evaluatedValue);
525+
}
526+
catch (Throwable t) {
527+
UnrecoverableExceptions.rethrowIfUnrecoverable(t);
528+
return null;
529+
}
530+
}
451531

452532
private static final Comparator<EvaluatedValue<?>> REVERSE_INSERT_ORDER = comparing(
453533
(EvaluatedValue<?> it) -> it.order).reversed();
454534

455535
private void close(CloseAction<N> closeAction) throws Throwable {
456-
if (this.value != null) {
457-
closeAction.close(this.compositeKey.namespace, this.compositeKey.key, this.value);
458-
}
536+
closeAction.close(this.compositeKey.namespace, this.compositeKey.key, this.value);
459537
}
460538

461539
}
462540

463541
/**
464-
* Thread-safe {@link Supplier} that memoizes the result of calling its
465-
* delegate and ensures it is called at most once.
466-
*
467-
* <p>If the delegate throws an exception, it is stored and rethrown every
468-
* time {@link #get()} is called.
469-
*
470-
* @see StoredValue
542+
* Deferred computation that can be added to the store.
543+
* <p>
544+
* This allows values to be computed outside the
545+
* {@link ConcurrentHashMap#compute(Object, BiFunction)} calls and
546+
* prevents recursive updates.
471547
*/
472-
private static class MemoizingSupplier implements Supplier<@Nullable Object> {
548+
private static final class DeferredSupplier implements Supplier<@Nullable Object> {
473549

474-
private static final Object NO_VALUE_SET = new Object();
550+
private final FutureTask<@Nullable Object> task;
475551

476-
private final Supplier<@Nullable Object> delegate;
477-
478-
@Nullable
479-
private volatile Object value = NO_VALUE_SET;
552+
DeferredSupplier(Supplier<@Nullable Object> delegate) {
553+
this.task = new FutureTask<>(delegate::get);
554+
}
480555

481-
private MemoizingSupplier(Supplier<@Nullable Object> delegate) {
482-
this.delegate = delegate;
556+
void run() {
557+
this.task.run();
483558
}
484559

485560
@Override
486561
public @Nullable Object get() {
487-
if (this.value == NO_VALUE_SET) {
488-
computeValue();
562+
try {
563+
return this.task.get();
489564
}
490-
if (this.value instanceof Failure failure) {
491-
throw throwAsUncheckedException(failure.throwable);
565+
catch (InterruptedException e) {
566+
Thread.currentThread().interrupt();
567+
throw throwAsUncheckedException(e);
568+
}
569+
catch (ExecutionException e) {
570+
Throwable t = e.getCause();
571+
if (t == null) {
572+
t = e;
573+
}
574+
UnrecoverableExceptions.rethrowIfUnrecoverable(t);
575+
return null;
492576
}
493-
return this.value;
494577
}
495578

496-
private synchronized void computeValue() {
579+
@Nullable
580+
Object getOrThrow() {
497581
try {
498-
if (this.value == NO_VALUE_SET) {
499-
this.value = this.delegate.get();
500-
}
582+
return this.task.get();
501583
}
502-
catch (Throwable t) {
503-
this.value = new Failure(t);
584+
catch (InterruptedException e) {
585+
Thread.currentThread().interrupt();
586+
throw throwAsUncheckedException(e);
587+
}
588+
catch (ExecutionException e) {
589+
Throwable t = e.getCause();
590+
if (t == null) {
591+
t = e;
592+
}
504593
UnrecoverableExceptions.rethrowIfUnrecoverable(t);
594+
throw throwAsUncheckedException(t);
505595
}
506596
}
507-
508-
private record Failure(Throwable throwable) {
509-
}
510-
511597
}
512598

513599
/**

0 commit comments

Comments
 (0)