diff --git a/.claude/skills/translate-from-shared-core/SKILL.md b/.claude/skills/translate-from-shared-core/SKILL.md index d5a6577e..43c47635 100644 --- a/.claude/skills/translate-from-shared-core/SKILL.md +++ b/.claude/skills/translate-from-shared-core/SKILL.md @@ -130,7 +130,8 @@ These are structurally very similar between both codebases: | `src/vm/transitions/journal.rs` | Methods across `ReplayingState.java` and `ProcessingState.java` | | `src/vm/transitions/async_results.rs` | Methods in `ReplayingState.java`, `ProcessingState.java`, `AsyncResultsState.java` | | `src/vm/transitions/terminal.rs` | `hitError()`/`hitSuspended()` methods on `State` interface | -| `src/vm/errors.rs` | `ProtocolException.java` | +| `src/vm/errors.rs` | `ProtocolException.java` (factory methods, not separate error classes) | +| `src/error.rs` (CommandMetadata, NotificationMetadata) | `CommandMetadata.java` (record); notification metadata is built as strings inline | | `src/service_protocol/` | `MessageDecoder.java`, `MessageEncoder.java`, `MessageType.java`, `ServiceProtocol.java` | ### Command processing patterns @@ -165,6 +166,39 @@ Java test inputs are built with `ProtoUtils` helpers (`startMessage()`, `inputCm **Key implication**: When a Rust commit adds a new VM-level test, in Java you typically need to add a handler-level test in the appropriate test suite, not a direct state machine test. +### Test translation details + +When translating Rust VM tests to Java: + +1. **Identify the right test suite**: Match the Rust test module to the Java abstract test suite: + - `src/tests/failures.rs` (journal_mismatch) → `StateMachineFailuresTestSuite` + - `src/tests/async_result.rs` → `AsyncResultTestSuite` + - `src/tests/run.rs` → `SideEffectTestSuite` + - `src/tests/state.rs` → `StateTestSuite` / `EagerStateTestSuite` + +2. 
**Add abstract method + test definition**: Add the abstract handler method to the suite, then add test definitions using `withInput(...)` and assertion patterns like `assertingOutput(containsOnly(errorMessage(...)))` or `expectingOutput(...)`. + +3. **Implement in both Java and Kotlin**: Each suite `<Name>TestSuite` is extended in both: + - `sdk-core/src/test/java/dev/restate/sdk/core/javaapi/<Name>Test.java` + - `sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/<Name>Test.kt` + +4. **Kotlin API differences**: + - Must `import dev.restate.sdk.kotlin.*` for reified extension functions (`runAsync`, `runBlock`, etc.) + - `ctx.runAsync(name) { ... }` (reified, not `ctx.runAsync(name, String.class, ...)`) + - `ctx.awakeable(TestSerdes.STRING)` (needs a serde, not `String::class.java`) + - `ctx.timer(0.milliseconds)` (uses Kotlin Duration) + - Handler factories: `testDefinitionForService("Name") { ctx, _: Unit -> ... }` + +5. **Cancel signal is always included**: The `HandlerContextImpl` automatically appends `CANCEL_HANDLE` (handle=1, mapping to `SignalId(1)`) to every `doProgress` call. This matches Rust's `CoreVM.do_progress` which appends `cancel_signal_handle`. So in test assertions, the cancel signal notification ID will always be part of the awaited notifications. + +6. **ProtoUtils helpers**: Use `startMessage(n)`, `inputCmd()`, `runCmd(completionId, name)`, `suspensionMessage(completionIds...)`, etc. For messages without helpers (e.g., `SleepCommandMessage`, `SleepCompletionNotificationMessage`), build them directly with the protobuf builders. + +### Shared utilities + +- `Util.awakeableIdStr(invocationId, signalId)` — computes the awakeable ID string from invocation ID and signal ID. Used in both `StateMachineImpl` (for creating awakeables) and `ReplayingState` (for error messages). +- `StateMachineImpl.CANCEL_SIGNAL_ID` — the signal ID for the built-in cancel signal (value: 1). Package-private, available via static import. 
+- Java 17 target — do NOT use switch pattern matching (`case Type t ->`) in Java source; use `instanceof` chains instead. + ## Step 4: Apply the translation 1. **Read the affected Java files first** before making changes diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java b/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java index 5da7edc3..572b187a 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java @@ -12,6 +12,8 @@ import dev.restate.sdk.common.TerminalException; import dev.restate.sdk.core.generated.protocol.Protocol; import dev.restate.sdk.core.statemachine.NotificationId; +import java.util.List; +import java.util.Map; public class ProtocolException extends RuntimeException { @@ -133,6 +135,31 @@ public static ProtocolException unauthorized(Throwable e) { return new ProtocolException("Unauthorized", UNAUTHORIZED_CODE, e); } + public static ProtocolException uncompletedDoProgressDuringReplay( + List sortedNotificationIds, + Map notificationDescriptions) { + var sb = new StringBuilder(); + sb.append( + "Found a mismatch between the code paths taken during the previous execution and the paths taken during this execution.\n"); + sb.append( + "'Awaiting a future' could not be replayed. 
This usually means the code was mutated adding an 'await' without registering a new service revision.\n"); + sb.append("Notifications awaited on this await point:"); + for (var notificationId : sortedNotificationIds) { + sb.append("\n - "); + String description = notificationDescriptions.get(notificationId); + if (description != null) { + sb.append(description); + } else if (notificationId instanceof NotificationId.CompletionId completionId) { + sb.append("completion id ").append(completionId.id()); + } else if (notificationId instanceof NotificationId.SignalId signalId) { + sb.append("signal [").append(signalId.id()).append("]"); + } else if (notificationId instanceof NotificationId.SignalName signalName) { + sb.append("signal '").append(signalName.name()).append("'"); + } + } + return new ProtocolException(sb.toString(), JOURNAL_MISMATCH_CODE); + } + public static ProtocolException unsupportedFeature( String featureName, Protocol.ServiceProtocolVersion requiredVersion, diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ReplayingState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ReplayingState.java index 4e3e6a54..c8897479 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ReplayingState.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ReplayingState.java @@ -8,6 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core.statemachine; +import static dev.restate.sdk.core.statemachine.StateMachineImpl.CANCEL_SIGNAL_ID; import static dev.restate.sdk.core.statemachine.Util.byteStringToSlice; import com.google.protobuf.ByteString; @@ -19,6 +20,7 @@ import dev.restate.sdk.core.statemachine.StateMachine.DoProgressResponse; import java.util.*; import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -26,6 +28,37 @@ final class ReplayingState 
implements State { private static final Logger LOG = LogManager.getLogger(ReplayingState.class); + /** + * Comparator for notification IDs in error messages. Orders: completions first (by id), then + * named signals (by name), then signal IDs (by id, with cancel signal last). + */ + private static final Comparator NOTIFICATION_ID_COMPARATOR_FOR_JOURNAL_MISMATCH = + Comparator.comparingInt( + id -> { + if (id instanceof NotificationId.CompletionId) return 0; + if (id instanceof NotificationId.SignalName) return 1; + return 2; + }) + .thenComparing( + (a, b) -> { + if (a instanceof NotificationId.CompletionId ac + && b instanceof NotificationId.CompletionId bc) { + return Integer.compare(ac.id(), bc.id()); + } + if (a instanceof NotificationId.SignalName an + && b instanceof NotificationId.SignalName bn) { + return an.name().compareTo(bn.name()); + } + if (a instanceof NotificationId.SignalId as_ + && b instanceof NotificationId.SignalId bs) { + boolean aIsCancel = as_.id() == CANCEL_SIGNAL_ID; + boolean bIsCancel = bs.id() == CANCEL_SIGNAL_ID; + if (aIsCancel != bIsCancel) return aIsCancel ? 1 : -1; + return Integer.compare(as_.id(), bs.id()); + } + return 0; + }); + private final Deque commandsToProcess; private final AsyncResultsState asyncResultsState; private final RunState runState; @@ -68,12 +101,65 @@ public DoProgressResponse doProgress(List awaitingOn, StateContext stat return DoProgressResponse.AnyCompleted.INSTANCE; } - if (stateContext.isInputClosed()) { - this.hitSuspended(notificationIds, stateContext); - ExceptionUtils.sneakyThrow(AbortedExecutionException.INSTANCE); + // This assertion proves the user mutated the code, adding an await point. + // + // During replay, we transition to processing AFTER replaying all COMMANDS. + // If we reach this point, none of the previous checks succeeded, meaning we don't have + // enough notifications to complete this await point. 
But if this await cannot be completed + // during replay, then no progress should have been made afterward, meaning there should be + // no more commands to replay. However, we ARE still replaying, which means there ARE commands + // to replay after this await point. + // + // This contradiction proves the code was mutated: an await must have been added after + // the journal was originally created. + + // Prepare error metadata to make it easier to debug + Map knownNotificationMetadata = new HashMap<>(); + CommandRelationship relatedCommand = null; + + // Collect run info + for (int handle : awaitingOn) { + RunState.Run runInfo = runState.getRunInfo(handle); + if (runInfo != null) { + var notifId = asyncResultsState.mustResolveNotificationHandle(handle); + knownNotificationMetadata.put( + notifId, + MessageType.RunCommandMessage.name() + + " '" + + runInfo.commandName() + + "' (command index " + + runInfo.commandIndex() + + ")"); + relatedCommand = + new CommandRelationship.Specific( + runInfo.commandIndex(), CommandType.RUN, runInfo.commandName()); + } + } + + // For awakeables and cancellation, add descriptions + for (var notifId : notificationIds) { + if (notifId instanceof NotificationId.SignalId signalId) { + if (signalId.id() == CANCEL_SIGNAL_ID) { + knownNotificationMetadata.put(notifId, "Cancellation"); + } else if (signalId.id() > 16) { + knownNotificationMetadata.put( + notifId, + "Awakeable " + Util.awakeableIdStr(stateContext.getStartInfo().id(), signalId.id())); + } + } } - return DoProgressResponse.ReadFromInput.INSTANCE; + this.hitError( + ProtocolException.uncompletedDoProgressDuringReplay( + notificationIds.stream() + .sorted(NOTIFICATION_ID_COMPARATOR_FOR_JOURNAL_MISMATCH) + .collect(Collectors.toList()), + knownNotificationMetadata), + relatedCommand, + null, + stateContext); + ExceptionUtils.sneakyThrow(AbortedExecutionException.INSTANCE); + return null; // unreachable } @Override diff --git 
a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/RunState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/RunState.java index b8c359cd..dedae9b6 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/RunState.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/RunState.java @@ -32,6 +32,10 @@ public void insertRunToExecute(int handle, int commandIndex, String commandName) return null; } + public @Nullable Run getRunInfo(int handle) { + return runs.get(handle); + } + public boolean anyExecuting(Collection anyHandle) { return anyHandle.stream() .anyMatch(h -> runs.containsKey(h) && runs.get(h).state == RunStateInner.Executing); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java index 354daa26..5d9a5ddf 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java @@ -19,7 +19,6 @@ import dev.restate.sdk.core.ProtocolException; import dev.restate.sdk.core.generated.protocol.Protocol; import dev.restate.sdk.endpoint.HeadersAccessor; -import java.nio.ByteBuffer; import java.time.Duration; import java.time.Instant; import java.util.*; @@ -34,8 +33,7 @@ class StateMachineImpl implements StateMachine { private static final Logger LOG = LogManager.getLogger(StateMachineImpl.class); - private static final String AWAKEABLE_IDENTIFIER_PREFIX = "sign_1"; - private static final int CANCEL_SIGNAL_ID = 1; + static final int CANCEL_SIGNAL_ID = 1; // Callbacks private final CompletableFuture waitForReadyFuture = new CompletableFuture<>(); @@ -385,15 +383,7 @@ public Awakeable awakeable() { .createSignalHandle(new NotificationId.SignalId(signalId), this.stateContext); // Encode awakeable id - String awakeableId = - AWAKEABLE_IDENTIFIER_PREFIX - + Base64.getUrlEncoder() - .encodeToString( - 
this.stateContext - .getStartInfo() - .id() - .concat(ByteString.copyFrom(ByteBuffer.allocate(4).putInt(signalId).flip())) - .toByteArray()); + String awakeableId = Util.awakeableIdStr(this.stateContext.getStartInfo().id(), signalId); return new Awakeable(awakeableId, signalHandle); } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Util.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Util.java index 6eab2499..9c14f325 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Util.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Util.java @@ -16,6 +16,7 @@ import dev.restate.sdk.core.generated.protocol.Protocol; import java.nio.ByteBuffer; import java.time.Duration; +import java.util.Base64; import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; @@ -78,6 +79,17 @@ static Duration durationMin(Duration a, Duration b) { return (a.compareTo(b) <= 0) ? a : b; } + private static final String AWAKEABLE_IDENTIFIER_PREFIX = "sign_1"; + + static String awakeableIdStr(ByteString invocationId, int signalId) { + return AWAKEABLE_IDENTIFIER_PREFIX + + Base64.getUrlEncoder() + .encodeToString( + invocationId + .concat(ByteString.copyFrom(ByteBuffer.allocate(4).putInt(signalId).flip())) + .toByteArray()); + } + /** * Returns a string representation of a command message. 
* diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/StateMachineFailuresTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/StateMachineFailuresTestSuite.java index c3580678..1bd24a83 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/StateMachineFailuresTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/StateMachineFailuresTestSuite.java @@ -12,6 +12,7 @@ import static dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; import static dev.restate.sdk.core.statemachine.ProtoUtils.*; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.STRING; import dev.restate.sdk.core.generated.protocol.Protocol; import dev.restate.serde.Serde; @@ -25,6 +26,12 @@ public abstract class StateMachineFailuresTestSuite implements TestDefinitions.T protected abstract TestInvocationBuilder sideEffectFailure(Serde serde); + protected abstract TestInvocationBuilder awaitRunAfterProgressWasMade(); + + protected abstract TestInvocationBuilder awaitSleepAfterProgressWasMade(); + + protected abstract TestInvocationBuilder awaitAwakeableAfterProgressWasMade(); + private static final Serde FAILING_SERIALIZATION_INTEGER_TYPE_TAG = Serde.using( i -> { @@ -91,6 +98,72 @@ public Stream definitions() { .assertingOutput( containsOnly( errorDescriptionStartingWith(IllegalStateException.class.getCanonicalName()))) - .named("Serde deserialization error")); + .named("Serde deserialization error"), + // --- Uncompleted doProgress during replay (bad await) tests + this.awaitRunAfterProgressWasMade() + .withInput( + startMessage(4), + inputCmd(), + runCmd(1, "my-side-effect"), + Protocol.SleepCommandMessage.newBuilder().setResultCompletionId(2).build(), + Protocol.SleepCompletionNotificationMessage.newBuilder() + .setCompletionId(2) + .setVoid(Protocol.Void.getDefaultInstance()) + .build()) + .assertingOutput( + containsOnly( + errorMessage( + errorMessage -> + assertThat(errorMessage) + .returns( + 
ProtocolException.JOURNAL_MISMATCH_CODE, + Protocol.ErrorMessage::getCode) + .extracting(Protocol.ErrorMessage::getMessage, STRING) + .contains("could not be replayed") + .contains("await")))) + .named("Add await on run after progress was made"), + this.awaitSleepAfterProgressWasMade() + .withInput( + startMessage(4), + inputCmd(), + Protocol.SleepCommandMessage.newBuilder().setResultCompletionId(1).build(), + Protocol.SleepCommandMessage.newBuilder().setResultCompletionId(2).build(), + Protocol.SleepCompletionNotificationMessage.newBuilder() + .setCompletionId(2) + .setVoid(Protocol.Void.getDefaultInstance()) + .build()) + .assertingOutput( + containsOnly( + errorMessage( + errorMessage -> + assertThat(errorMessage) + .returns( + ProtocolException.JOURNAL_MISMATCH_CODE, + Protocol.ErrorMessage::getCode) + .extracting(Protocol.ErrorMessage::getMessage, STRING) + .contains("could not be replayed") + .contains("await")))) + .named("Add await on sleep after progress was made"), + this.awaitAwakeableAfterProgressWasMade() + .withInput( + startMessage(3), + inputCmd(), + Protocol.SleepCommandMessage.newBuilder().setResultCompletionId(2).build(), + Protocol.SleepCompletionNotificationMessage.newBuilder() + .setCompletionId(2) + .setVoid(Protocol.Void.getDefaultInstance()) + .build()) + .assertingOutput( + containsOnly( + errorMessage( + errorMessage -> + assertThat(errorMessage) + .returns( + ProtocolException.JOURNAL_MISMATCH_CODE, + Protocol.ErrorMessage::getCode) + .extracting(Protocol.ErrorMessage::getMessage, STRING) + .contains("could not be replayed") + .contains("await")))) + .named("Add await on awakeable after progress was made")); } } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/StateMachineFailuresTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/StateMachineFailuresTest.java index df48edc9..c346b060 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/StateMachineFailuresTest.java +++ 
b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/StateMachineFailuresTest.java @@ -8,6 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core.javaapi; +import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForVirtualObject; import dev.restate.sdk.common.AbortedExecutionException; @@ -18,6 +19,7 @@ import dev.restate.sdk.core.TestSerdes; import dev.restate.serde.Serde; import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.util.concurrent.atomic.AtomicInteger; public class StateMachineFailuresTest extends StateMachineFailuresTestSuite { @@ -65,4 +67,43 @@ protected TestInvocationBuilder sideEffectFailure(Serde serde) { return "Francesco"; }); } + + @Override + protected TestInvocationBuilder awaitRunAfterProgressWasMade() { + return testDefinitionForService( + "AwaitRunAfterProgressWasMade", + Serde.VOID, + TestSerdes.STRING, + (ctx, unused) -> { + var runFuture = ctx.runAsync("my-side-effect", String.class, () -> "result"); + runFuture.await(); + return null; + }); + } + + @Override + protected TestInvocationBuilder awaitSleepAfterProgressWasMade() { + return testDefinitionForService( + "AwaitSleepAfterProgressWasMade", + Serde.VOID, + TestSerdes.STRING, + (ctx, unused) -> { + var sleepFuture = ctx.timer(Duration.ZERO); + sleepFuture.await(); + return null; + }); + } + + @Override + protected TestInvocationBuilder awaitAwakeableAfterProgressWasMade() { + return testDefinitionForService( + "AwaitAwakeableAfterProgressWasMade", + Serde.VOID, + TestSerdes.STRING, + (ctx, unused) -> { + var awakeable = ctx.awakeable(String.class); + awakeable.await(); + return null; + }); + } } diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/StateMachineFailuresTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/StateMachineFailuresTest.kt index 94c7da30..3172c5c1 100644 --- 
a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/StateMachineFailuresTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/StateMachineFailuresTest.kt @@ -13,11 +13,14 @@ import dev.restate.sdk.common.StateKey import dev.restate.sdk.common.TerminalException import dev.restate.sdk.core.StateMachineFailuresTestSuite import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder +import dev.restate.sdk.core.TestSerdes import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForVirtualObject +import dev.restate.sdk.kotlin.* import dev.restate.serde.Serde import java.nio.charset.StandardCharsets import java.util.concurrent.atomic.AtomicInteger +import kotlin.time.Duration.Companion.milliseconds import kotlinx.coroutines.CancellationException class StateMachineFailuresTest : StateMachineFailuresTestSuite() { @@ -56,4 +59,26 @@ class StateMachineFailuresTest : StateMachineFailuresTestSuite() { ctx.runBlock(serde) { 0 } "Francesco" } + + override fun awaitRunAfterProgressWasMade(): TestInvocationBuilder = + testDefinitionForService("AwaitRunAfterProgressWasMade") { ctx, _: Unit -> + val runFuture = ctx.runAsync("my-side-effect") { "result" } + runFuture.await() + null + } + + override fun awaitSleepAfterProgressWasMade(): TestInvocationBuilder = + testDefinitionForService("AwaitSleepAfterProgressWasMade") { ctx, _: Unit -> + val sleepFuture = ctx.timer(0.milliseconds) + sleepFuture.await() + null + } + + override fun awaitAwakeableAfterProgressWasMade(): TestInvocationBuilder = + testDefinitionForService("AwaitAwakeableAfterProgressWasMade") { ctx, _: Unit + -> + val awakeable = ctx.awakeable(TestSerdes.STRING) + awakeable.await() + null + } }