diff --git a/documentation/modules/ROOT/partials/release-notes/release-notes-6.1.0-M2.adoc b/documentation/modules/ROOT/partials/release-notes/release-notes-6.1.0-M2.adoc index ba9b0fb0fa68..84025ff9904f 100644 --- a/documentation/modules/ROOT/partials/release-notes/release-notes-6.1.0-M2.adoc +++ b/documentation/modules/ROOT/partials/release-notes/release-notes-6.1.0-M2.adoc @@ -16,6 +16,8 @@ repository on GitHub. ==== Bug Fixes * Clarify `TestDescriptor` implementation requirements. +* Enable recursive updates when using `NamespacedHierarchicalStore.computeIfAbsent(N, K, Function)`. + This provides parity with the deprecated `NamespacedHierarchicalStore.getOrComputeIfAbsent(N, K, Function)` [[v6.1.0-M2-junit-platform-deprecations-and-breaking-changes]] ==== Deprecations and Breaking Changes @@ -34,7 +36,8 @@ repository on GitHub. [[v6.1.0-M2-junit-jupiter-bug-fixes]] ==== Bug Fixes -* ❓ +* Enable recursive updates when using `ExtensionContext.Store.computeIfAbsent(K, Function, Class)`. + This provides parity with the deprecated `ExtensionContext.Store.getOrComputeIfAbsent(K, Function, Class)` [[v6.1.0-M2-junit-jupiter-deprecations-and-breaking-changes]] ==== Deprecations and Breaking Changes diff --git a/junit-platform-engine/src/main/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStore.java b/junit-platform-engine/src/main/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStore.java index baeeaed7f122..8560c69ec00e 100644 --- a/junit-platform-engine/src/main/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStore.java +++ b/junit-platform-engine/src/main/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStore.java @@ -22,10 +22,14 @@ import java.util.ArrayList; import java.util.Comparator; import java.util.List; +import java.util.Objects; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.FutureTask; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; @@ -134,8 +138,8 @@ public void close() { if (this.closeAction != null) { List failures = new ArrayList<>(); this.storedValues.entrySet().stream() // - .map(e -> e.getValue().evaluateSafely(e.getKey())) // - .filter(it -> it != null && it.value != null) // + .map(e -> EvaluatedValue.createSafely(e.getKey(), e.getValue())) // + .filter(Objects::nonNull) // .sorted(EvaluatedValue.REVERSE_INSERT_ORDER) // .forEach(it -> { try { @@ -210,14 +214,29 @@ public void close() { public @Nullable Object getOrComputeIfAbsent(N namespace, K key, Function defaultCreator) { Preconditions.notNull(defaultCreator, "defaultCreator must not be null"); - CompositeKey compositeKey = new CompositeKey<>(namespace, key); - StoredValue storedValue = getStoredValue(compositeKey); - if (storedValue == null) { - storedValue = this.storedValues.computeIfAbsent(compositeKey, - __ -> newStoredValue(new MemoizingSupplier(() -> { - rejectIfClosed(); - return defaultCreator.apply(key); - }))); + var compositeKey = new CompositeKey<>(namespace, key); + var currentStoredValue = getStoredValue(compositeKey); + if (currentStoredValue != null) { + return currentStoredValue.evaluate(); + } + var candidateStoredValue = newStoredSuppliedNullableValue(() -> { + rejectIfClosed(); + return defaultCreator.apply(key); + }); + var storedValue = storedValues.compute(compositeKey, // + (__, oldStoredValue) -> { + // guard against race conditions, repeated from getStoredValue + // this filters out failures inserted by computeIfAbsent + if (StoredValue.isNonNullAndPresent(oldStoredValue)) { + return oldStoredValue; + } + rejectIfClosed(); + return candidateStoredValue; + }); + + // Only the caller that created the candidateStoredValue may run it + if (candidateStoredValue.equals(storedValue)) { + return candidateStoredValue.execute(); } return storedValue.evaluate(); } @@ -240,25 +259,49 @@ public void close() { @API(status = MAINTAINED, since = "6.0") public Object computeIfAbsent(N namespace, K key, Function defaultCreator) { Preconditions.notNull(defaultCreator, "defaultCreator must not be null"); - CompositeKey compositeKey = new CompositeKey<>(namespace, key); - StoredValue storedValue = getStoredValue(compositeKey); - var result = StoredValue.evaluateIfNotNull(storedValue); - if (result == null) { - StoredValue newStoredValue = this.storedValues.compute(compositeKey, (__, oldStoredValue) -> { - if (StoredValue.evaluateIfNotNull(oldStoredValue) == null) { - rejectIfClosed(); - var computedValue = Preconditions.notNull(defaultCreator.apply(key), - "defaultCreator must not return null"); - return newStoredValue(() -> { - rejectIfClosed(); - return computedValue; - }); - } + var compositeKey = new CompositeKey<>(namespace, key); + var currentStoredValue = getStoredValue(compositeKey); + var result = StoredValue.evaluateIfNotNull(currentStoredValue); + if (result != null) { + return result; + } + var candidateStoredValue = newStoredSuppliedValue(() -> { + rejectIfClosed(); + return Preconditions.notNull(defaultCreator.apply(key), "defaultCreator must not return null"); + }); + var storedValue = storedValues.compute(compositeKey, (__, oldStoredValue) -> { + // guard against race conditions + // computeIfAbsent replaces both null and absent values + if (StoredValue.evaluateIfNotNull(oldStoredValue) != null) { return oldStoredValue; - }); - return requireNonNull(newStoredValue.evaluate()); + } + rejectIfClosed(); + return candidateStoredValue; + }); + + // In a race condition either put, getOrComputeIfAbsent, or another + // computeIfAbsent call put a non-null value in the store + if (!candidateStoredValue.equals(storedValue)) { + return requireNonNull(storedValue.evaluate()); + } + // Only the caller that created the candidateStoredValue may run it + // and see the exception. + Object newResult = candidateStoredValue.execute(); + // DeferredOptionalValue is quite heavy, replace with lighter container + if (candidateStoredValue.isPresent()) { + storedValues.computeIfPresent(compositeKey, compareAndPut(storedValue, newStoredValue(newResult))); } - return result; + return newResult; + } + + private static BiFunction, StoredValue, StoredValue> compareAndPut(StoredValue expectedValue, + StoredValue newValue) { + return (__, storedValue) -> { + if (!expectedValue.equals(storedValue)) { + return storedValue; + } + return newValue; + }; } /** @@ -328,7 +371,7 @@ public V computeIfAbsent(N namespace, K key, Function(namespace, key), newStoredValue(() -> value)); + StoredValue oldValue = this.storedValues.put(new CompositeKey<>(namespace, key), newStoredValue(value)); return StoredValue.evaluateIfNotNull(oldValue); } @@ -372,13 +415,24 @@ public V computeIfAbsent(N namespace, K key, Function value) { - return new StoredValue(this.insertOrderSequence.getAndIncrement(), value); + private StoredValue.Value newStoredValue(@Nullable Object value) { + var sequenceNumber = insertOrderSequence.getAndIncrement(); + return new StoredValue.Value(sequenceNumber, value); + } + + private StoredValue.DeferredValue newStoredSuppliedNullableValue(Supplier<@Nullable Object> supplier) { + var sequenceNumber = insertOrderSequence.getAndIncrement(); + return new StoredValue.DeferredValue(sequenceNumber, supplier); + } + + private StoredValue.DeferredOptionalValue newStoredSuppliedValue(Supplier supplier) { + var sequenceNumber = insertOrderSequence.getAndIncrement(); + return new StoredValue.DeferredOptionalValue(sequenceNumber, supplier); } private @Nullable StoredValue getStoredValue(CompositeKey compositeKey) { StoredValue storedValue = this.storedValues.get(compositeKey); - if (storedValue != null) { + if (StoredValue.isNonNullAndPresent(storedValue)) { return storedValue; } if (this.parentStore != null) { @@ -425,89 +479,197 @@ private record CompositeKey(N namespace, Object key) { } - private record StoredValue(int order, Supplier<@Nullable Object> supplier) { + private interface StoredValue { - private @Nullable EvaluatedValue evaluateSafely(CompositeKey compositeKey) { - try { - return new EvaluatedValue<>(compositeKey, this.order, evaluate()); - } - catch (Throwable t) { - UnrecoverableExceptions.rethrowIfUnrecoverable(t); - return null; - } - } + int order(); - private @Nullable Object evaluate() { - return this.supplier.get(); - } + @Nullable + Object evaluate(); + + boolean isPresent(); static @Nullable Object evaluateIfNotNull(@Nullable StoredValue value) { return value != null ? value.evaluate() : null; } - } + static boolean isNonNullAndPresent(@Nullable StoredValue value) { + return value != null && value.isPresent(); + } - private record EvaluatedValue(CompositeKey compositeKey, int order, @Nullable Object value) { + /** + * May contain {@code null} or a value, never an exception. + */ + final class Value implements StoredValue { + private final int order; + private final @Nullable Object value; - private static final Comparator> REVERSE_INSERT_ORDER = comparing( - (EvaluatedValue it) -> it.order).reversed(); + Value(int order, @Nullable Object value) { + this.order = order; + this.value = value; + } - private void close(CloseAction closeAction) throws Throwable { - if (this.value != null) { - closeAction.close(this.compositeKey.namespace, this.compositeKey.key, this.value); + @Override + public @Nullable Object evaluate() { + return value; + } + + @Override + public boolean isPresent() { + return true; + } + + @Override + public int order() { + return order; } } - } + /** + * May eventually contain {@code null} or a value or an exception. + */ + final class DeferredValue implements StoredValue { + private final int order; + private final DeferredSupplier delegate; - /** - * Thread-safe {@link Supplier} that memoizes the result of calling its - * delegate and ensures it is called at most once. - * - *

If the delegate throws an exception, it is stored and rethrown every - * time {@link #get()} is called. - * - * @see StoredValue - */ - private static class MemoizingSupplier implements Supplier<@Nullable Object> { + DeferredValue(int order, Supplier<@Nullable Object> delegate) { + this.order = order; + this.delegate = new DeferredSupplier(delegate); + } - private static final Object NO_VALUE_SET = new Object(); + @Override + public @Nullable Object evaluate() { + return delegate.getOrThrow(); + } - private final Supplier<@Nullable Object> delegate; + @Override + public boolean isPresent() { + return true; + } - @Nullable - private volatile Object value = NO_VALUE_SET; + @Nullable + Object execute() { + delegate.run(); + return delegate.getOrThrow(); + } - private MemoizingSupplier(Supplier<@Nullable Object> delegate) { - this.delegate = delegate; + @Override + public int order() { + return order; + } } - @Override - public @Nullable Object get() { - if (this.value == NO_VALUE_SET) { - computeValue(); + /** + * May eventually contain a value or an exception, never {@code null}. + */ + final class DeferredOptionalValue implements StoredValue { + private final int order; + private final DeferredSupplier delegate; + + DeferredOptionalValue(int order, Supplier delegate) { + this.order = order; + this.delegate = new DeferredSupplier(delegate); + } + + @Override + public @Nullable Object evaluate() { + return delegate.get(); } - if (this.value instanceof Failure failure) { - throw throwAsUncheckedException(failure.throwable); + + @Override + public boolean isPresent() { + return evaluate() != null; + } + + Object execute() { + delegate.run(); + // Delegate does not produce null + return requireNonNull(delegate.getOrThrow()); + } + + @Override + public int order() { + return order; } - return this.value; } + } + + private record EvaluatedValue(CompositeKey compositeKey, int order, Object value) { - private synchronized void computeValue() { + private static @Nullable EvaluatedValue createSafely(CompositeKey compositeKey, StoredValue value) { try { - if (this.value == NO_VALUE_SET) { - this.value = this.delegate.get(); + var evaluatedValue = value.evaluate(); + if (evaluatedValue == null) { + return null; } + return new EvaluatedValue<>(compositeKey, value.order(), evaluatedValue); } catch (Throwable t) { - this.value = new Failure(t); UnrecoverableExceptions.rethrowIfUnrecoverable(t); + return null; } } - private record Failure(Throwable throwable) { + private static final Comparator> REVERSE_INSERT_ORDER = comparing( + (EvaluatedValue it) -> it.order).reversed(); + + private void close(CloseAction closeAction) throws Throwable { + closeAction.close(this.compositeKey.namespace, this.compositeKey.key, this.value); + } + + } + + /** + * Deferred computation that can be added to the store. + *

+ * This allows values to be computed outside the + * {@link ConcurrentHashMap#compute(Object, BiFunction)} calls and + * prevents recursive updates. + */ + static final class DeferredSupplier { + + private final FutureTask<@Nullable Object> task; + + DeferredSupplier(Supplier delegate) { + this.task = new FutureTask<>(delegate::get); + } + + void run() { + task.run(); } + @Nullable + Object get() { + try { + return task.get(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw throwAsUncheckedException(e); + } + catch (ExecutionException e) { + // non-null guaranteed by FutureTask + var cause = requireNonNull(e.getCause()); + UnrecoverableExceptions.rethrowIfUnrecoverable(cause); + return null; + } + } + + @Nullable + Object getOrThrow() { + try { + return task.get(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw throwAsUncheckedException(e); + } + catch (ExecutionException e) { + // non-null guaranteed by FutureTask + var cause = requireNonNull(e.getCause()); + UnrecoverableExceptions.rethrowIfUnrecoverable(cause); + throw throwAsUncheckedException(cause); + } + } } /** diff --git a/platform-tests/src/test/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStoreTests.java b/platform-tests/src/test/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStoreTests.java index a239db0cae2e..3e199bcbfa8c 100644 --- a/platform-tests/src/test/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStoreTests.java +++ b/platform-tests/src/test/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStoreTests.java @@ -16,6 +16,7 @@ import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.platform.commons.test.ConcurrencyTestingUtils.executeConcurrently; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.inOrder; @@ -25,6 +26,7 @@ import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; +import java.io.Serial; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; @@ -82,6 +84,13 @@ void valueIsComputedIfAbsent() { assertEquals(value, store.get(namespace, key)); } + @Test + void valueIsComputedIfNull() { + assertNull(store.put(namespace, key, null)); + assertEquals(value, store.computeIfAbsent(namespace, key, __ -> value)); + assertEquals(value, store.get(namespace, key)); + } + @SuppressWarnings("deprecation") @Test void valueIsNotComputedIfPresentLocally() { @@ -316,22 +325,42 @@ void computeIfAbsentWithTypeSafetyAndPrimitiveValueType() { @SuppressWarnings("deprecation") @Test void getOrComputeIfAbsentWithExceptionThrowingCreatorFunction() { - var e = assertThrows(RuntimeException.class, () -> store.getOrComputeIfAbsent(namespace, key, __ -> { - throw new RuntimeException("boom"); + var e = assertThrows(ComputeException.class, () -> store.getOrComputeIfAbsent(namespace, key, __ -> { + throw new ComputeException("boom"); })); - assertSame(e, assertThrows(RuntimeException.class, () -> store.get(namespace, key))); - assertSame(e, assertThrows(RuntimeException.class, () -> store.remove(namespace, key))); + assertSame(e, assertThrows(ComputeException.class, () -> store.get(namespace, key))); + assertSame(e, assertThrows(ComputeException.class, () -> store.remove(namespace, key))); } @Test void computeIfAbsentWithExceptionThrowingCreatorFunction() { - assertThrows(RuntimeException.class, () -> store.computeIfAbsent(namespace, key, __ -> { - throw new RuntimeException("boom"); + assertThrows(ComputeException.class, () -> store.computeIfAbsent(namespace, key, __ -> { + throw new ComputeException("boom"); })); assertNull(store.get(namespace, key)); assertNull(store.remove(namespace, key)); } + @SuppressWarnings("deprecation") + @Test + void getOrComputeIfAbsentDoesNotSeeComputeIfAbsentWithExceptionThrowingCreatorFunction() { + assertThrows(ComputeException.class, () -> store.computeIfAbsent(namespace, key, __ -> { + throw new ComputeException("boom"); + })); + assertNull(store.get(namespace, key)); + assertEquals(value, store.getOrComputeIfAbsent(namespace, key, __ -> value)); + } + + @SuppressWarnings("deprecation") + @Test + void computeIfAbsentSeesGetOrComputeIfAbsentWithExceptionThrowingCreatorFunction() { + assertThrows(ComputeException.class, () -> store.getOrComputeIfAbsent(namespace, key, __ -> { + throw new ComputeException("boom"); + })); + assertThrows(ComputeException.class, () -> store.get(namespace, key)); + assertThrows(ComputeException.class, () -> store.computeIfAbsent(namespace, key, __ -> value)); + } + @Test void removeWithTypeSafetyAndInvalidRequiredTypeThrowsException() { Integer key = 42; @@ -416,15 +445,56 @@ void simulateRaceConditionInComputeIfAbsent() throws Exception { assertEquals(1, counter.get()); assertThat(values).hasSize(threads).containsOnly(1); } + + @SuppressWarnings("deprecation") + @Test + void updateRecursivelyGetOrComputeIfAbsent() { + try (var localStore = new NamespacedHierarchicalStore<>(null)) { + var value = localStore.getOrComputeIfAbsent(namespace, new CollidingKey("a"), // + a -> requireNonNull(localStore.getOrComputeIfAbsent(namespace, new CollidingKey("b"), // + b -> "enigma"))); + assertEquals("enigma", value); + } + } + + @Test + void updateRecursivelyComputeIfAbsent() { + try (var localStore = new NamespacedHierarchicalStore<>(null)) { + var value = localStore.computeIfAbsent(namespace, new CollidingKey("a"), // + a -> localStore.computeIfAbsent(namespace, new CollidingKey("b"), // + b -> "enigma")); + assertEquals("enigma", value); + } + } + + private record CollidingKey(String value) { + + @Override + public int hashCode() { + return 42; + } + } } @Nested class InheritedValuesTests { + @SuppressWarnings("deprecation") @Test - void valueFromParentIsVisible() { + void presentValueFromParentIsPresent() { parentStore.put(namespace, key, value); assertEquals(value, store.get(namespace, key)); + assertEquals(value, store.getOrComputeIfAbsent(namespace, key, __ -> "enigma")); + assertEquals(value, store.computeIfAbsent(namespace, key, __ -> "enigma")); + } + + @SuppressWarnings("deprecation") + @Test + void absentValueFromParentIsOverriddenByComputeIfAbsent() { + parentStore.put(namespace, key, null); + assertNull(store.get(namespace, key)); + assertNull(store.getOrComputeIfAbsent(namespace, key, __ -> value)); + assertEquals(value, store.computeIfAbsent(namespace, key, __ -> value)); } @Test @@ -522,19 +592,19 @@ void doesNotCallCloseActionForNullValues() { @Test void doesNotCallCloseActionForValuesThatThrowExceptionsDuringCleanup() throws Throwable { store.put(namespace, "key1", "value1"); - assertThrows(RuntimeException.class, () -> store.computeIfAbsent(namespace, "key2", __ -> { - throw new RuntimeException("boom"); + assertThrows(ComputeException.class, () -> store.computeIfAbsent(namespace, "key2", __ -> { + throw new ComputeException("boom"); })); - assertThrows(RuntimeException.class, () -> store.getOrComputeIfAbsent(namespace, "key2", __ -> { - throw new RuntimeException("boom"); + assertThrows(ComputeException.class, () -> store.getOrComputeIfAbsent(namespace, "key3", __ -> { + throw new ComputeException("boom"); })); - store.put(namespace, "key3", "value3"); + store.put(namespace, "key4", "value4"); store.close(); assertClosed(); var inOrder = inOrder(closeAction); - inOrder.verify(closeAction).close(namespace, "key3", "value3"); + inOrder.verify(closeAction).close(namespace, "key4", "value4"); inOrder.verify(closeAction).close(namespace, "key1", "value1"); inOrder.verifyNoMoreInteractions(); } @@ -663,6 +733,68 @@ private void assertClosed() { } + @Nested + class DeferredSupplierTests { + + @Test + void getCanBeInterrupted() { + var supplier = new NamespacedHierarchicalStore.DeferredSupplier(() -> { + try { + Thread.sleep(1000); + } + catch (InterruptedException e) { + throw new ComputeException(e); + } + return value; + }); + Thread.currentThread().interrupt(); + assertThrows(InterruptedException.class, () -> { + supplier.get(); + }); + assertTrue(Thread.interrupted()); + } + + @Test + void getThrowsIfUnrecoverable() { + var supplier = new NamespacedHierarchicalStore.DeferredSupplier(() -> { + throw new OutOfMemoryError("boom"); + }); + supplier.run(); + assertThrows(OutOfMemoryError.class, () -> { + supplier.get(); + }); + } + + @Test + void getOrThrowThrowsIfUnrecoverable() { + var supplier = new NamespacedHierarchicalStore.DeferredSupplier(() -> { + throw new OutOfMemoryError("boom"); + }); + supplier.run(); + assertThrows(OutOfMemoryError.class, () -> { + supplier.getOrThrow(); + }); + } + + @Test + void getOrThrowCanBeInterrupted() { + var supplier = new NamespacedHierarchicalStore.DeferredSupplier(() -> { + try { + Thread.sleep(1000); + } + catch (InterruptedException e) { + throw new ComputeException(e); + } + return value; + }); + Thread.currentThread().interrupt(); + assertThrows(InterruptedException.class, () -> { + supplier.getOrThrow(); + }); + assertTrue(Thread.interrupted()); + } + } + private static Object createObject(String display) { return new Object() { @@ -672,4 +804,21 @@ public String toString() { } }; } + + /** + * To avoid confusion with other Runtime exceptions that can be thrown. + */ + private static final class ComputeException extends RuntimeException { + + @Serial + private static final long serialVersionUID = 1L; + + ComputeException(String msg) { + super(msg); + } + + ComputeException(InterruptedException e) { + super(e); + } + } }