Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion .claude/skills/translate-from-shared-core/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ These are structurally very similar between both codebases:
| `src/vm/transitions/journal.rs` | Methods across `ReplayingState.java` and `ProcessingState.java` |
| `src/vm/transitions/async_results.rs` | Methods in `ReplayingState.java`, `ProcessingState.java`, `AsyncResultsState.java` |
| `src/vm/transitions/terminal.rs` | `hitError()`/`hitSuspended()` methods on `State` interface |
| `src/vm/errors.rs` | `ProtocolException.java` |
| `src/vm/errors.rs` | `ProtocolException.java` (factory methods, not separate error classes) |
| `src/error.rs` (CommandMetadata, NotificationMetadata) | `CommandMetadata.java` (record); notification metadata is built as strings inline |
| `src/service_protocol/` | `MessageDecoder.java`, `MessageEncoder.java`, `MessageType.java`, `ServiceProtocol.java` |

### Command processing patterns
Expand Down Expand Up @@ -165,6 +166,39 @@ Java test inputs are built with `ProtoUtils` helpers (`startMessage()`, `inputCm

**Key implication**: When a Rust commit adds a new VM-level test, in Java you typically need to add a handler-level test in the appropriate test suite, not a direct state machine test.

### Test translation details

When translating Rust VM tests to Java:

1. **Identify the right test suite**: Match the Rust test module to the Java abstract test suite:
- `src/tests/failures.rs` (journal_mismatch) → `StateMachineFailuresTestSuite`
- `src/tests/async_result.rs` → `AsyncResultTestSuite`
- `src/tests/run.rs` → `SideEffectTestSuite`
- `src/tests/state.rs` → `StateTestSuite` / `EagerStateTestSuite`

2. **Add abstract method + test definition**: Add the abstract handler method to the suite, then add test definitions using `withInput(...)` and assertion patterns like `assertingOutput(containsOnly(errorMessage(...)))` or `expectingOutput(...)`.

3. **Implement in both Java and Kotlin**: The suite is extended in both:
- `sdk-core/src/test/java/dev/restate/sdk/core/javaapi/<TestName>.java`
- `sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/<TestName>.kt`

4. **Kotlin API differences**:
- Must `import dev.restate.sdk.kotlin.*` for reified extension functions (`runAsync`, `runBlock`, etc.)
- `ctx.runAsync<String>(name) { ... }` (reified, not `ctx.runAsync(name, String.class, ...)`)
- `ctx.awakeable(TestSerdes.STRING)` (needs a serde, not `String::class.java`)
- `ctx.timer(0.milliseconds)` (uses Kotlin Duration)
- Handler factories: `testDefinitionForService<Unit, String?>("Name") { ctx, _: Unit -> ... }`

5. **Cancel signal is always included**: The `HandlerContextImpl` automatically appends `CANCEL_HANDLE` (handle=1, mapping to `SignalId(1)`) to every `doProgress` call. This matches Rust's `CoreVM.do_progress` which appends `cancel_signal_handle`. So in test assertions, the cancel signal notification ID will always be part of the awaited notifications.

6. **ProtoUtils helpers**: Use `startMessage(n)`, `inputCmd()`, `runCmd(completionId, name)`, `suspensionMessage(completionIds...)`, etc. For messages without helpers (e.g., `SleepCommandMessage`, `SleepCompletionNotificationMessage`), build them directly with the protobuf builders.

### Shared utilities

- `Util.awakeableIdStr(invocationId, signalId)` — computes the awakeable ID string from invocation ID and signal ID. Used in both `StateMachineImpl` (for creating awakeables) and `ReplayingState` (for error messages).
- `StateMachineImpl.CANCEL_SIGNAL_ID` — the signal ID for the built-in cancel signal (value: 1). Package-private, available via static import.
- Java 17 target — do NOT use switch pattern matching (`case Type t ->`) in Java source; use `instanceof` chains instead.

## Step 4: Apply the translation

1. **Read the affected Java files first** before making changes
Expand Down
27 changes: 27 additions & 0 deletions sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import dev.restate.sdk.common.TerminalException;
import dev.restate.sdk.core.generated.protocol.Protocol;
import dev.restate.sdk.core.statemachine.NotificationId;
import java.util.List;
import java.util.Map;

public class ProtocolException extends RuntimeException {

Expand Down Expand Up @@ -133,6 +135,31 @@ public static ProtocolException unauthorized(Throwable e) {
return new ProtocolException("Unauthorized", UNAUTHORIZED_CODE, e);
}

public static ProtocolException uncompletedDoProgressDuringReplay(
List<NotificationId> sortedNotificationIds,
Map<NotificationId, String> notificationDescriptions) {
var sb = new StringBuilder();
sb.append(
"Found a mismatch between the code paths taken during the previous execution and the paths taken during this execution.\n");
sb.append(
"'Awaiting a future' could not be replayed. This usually means the code was mutated adding an 'await' without registering a new service revision.\n");
sb.append("Notifications awaited on this await point:");
for (var notificationId : sortedNotificationIds) {
sb.append("\n - ");
String description = notificationDescriptions.get(notificationId);
if (description != null) {
sb.append(description);
} else if (notificationId instanceof NotificationId.CompletionId completionId) {
sb.append("completion id ").append(completionId.id());
} else if (notificationId instanceof NotificationId.SignalId signalId) {
sb.append("signal [").append(signalId.id()).append("]");
} else if (notificationId instanceof NotificationId.SignalName signalName) {
sb.append("signal '").append(signalName.name()).append("'");
}
}
return new ProtocolException(sb.toString(), JOURNAL_MISMATCH_CODE);
}

public static ProtocolException unsupportedFeature(
String featureName,
Protocol.ServiceProtocolVersion requiredVersion,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.sdk.core.statemachine;

import static dev.restate.sdk.core.statemachine.StateMachineImpl.CANCEL_SIGNAL_ID;
import static dev.restate.sdk.core.statemachine.Util.byteStringToSlice;

import com.google.protobuf.ByteString;
Expand All @@ -19,13 +20,45 @@
import dev.restate.sdk.core.statemachine.StateMachine.DoProgressResponse;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

final class ReplayingState implements State {

private static final Logger LOG = LogManager.getLogger(ReplayingState.class);

/**
* Comparator for notification IDs in error messages. Orders: completions first (by id), then
* named signals (by name), then signal IDs (by id, with cancel signal last).
*/
private static final Comparator<NotificationId> NOTIFICATION_ID_COMPARATOR_FOR_JOURNAL_MISMATCH =
Comparator.<NotificationId>comparingInt(
id -> {
if (id instanceof NotificationId.CompletionId) return 0;
if (id instanceof NotificationId.SignalName) return 1;
return 2;
})
.thenComparing(
(a, b) -> {
if (a instanceof NotificationId.CompletionId ac
&& b instanceof NotificationId.CompletionId bc) {
return Integer.compare(ac.id(), bc.id());
}
if (a instanceof NotificationId.SignalName an
&& b instanceof NotificationId.SignalName bn) {
return an.name().compareTo(bn.name());
}
if (a instanceof NotificationId.SignalId as_
&& b instanceof NotificationId.SignalId bs) {
boolean aIsCancel = as_.id() == CANCEL_SIGNAL_ID;
boolean bIsCancel = bs.id() == CANCEL_SIGNAL_ID;
if (aIsCancel != bIsCancel) return aIsCancel ? 1 : -1;
return Integer.compare(as_.id(), bs.id());
}
return 0;
});

private final Deque<MessageLite> commandsToProcess;
private final AsyncResultsState asyncResultsState;
private final RunState runState;
Expand Down Expand Up @@ -68,12 +101,65 @@ public DoProgressResponse doProgress(List<Integer> awaitingOn, StateContext stat
return DoProgressResponse.AnyCompleted.INSTANCE;
}

if (stateContext.isInputClosed()) {
this.hitSuspended(notificationIds, stateContext);
ExceptionUtils.sneakyThrow(AbortedExecutionException.INSTANCE);
// This assertion proves the user mutated the code, adding an await point.
//
// During replay, we transition to processing AFTER replaying all COMMANDS.
// If we reach this point, none of the previous checks succeeded, meaning we don't have
// enough notifications to complete this await point. But if this await cannot be completed
// during replay, then no progress should have been made afterward, meaning there should be
// no more commands to replay. However, we ARE still replaying, which means there ARE commands
// to replay after this await point.
//
// This contradiction proves the code was mutated: an await must have been added after
// the journal was originally created.

// Prepare error metadata to make it easier to debug
Map<NotificationId, String> knownNotificationMetadata = new HashMap<>();
CommandRelationship relatedCommand = null;

// Collect run info
for (int handle : awaitingOn) {
RunState.Run runInfo = runState.getRunInfo(handle);
if (runInfo != null) {
var notifId = asyncResultsState.mustResolveNotificationHandle(handle);
knownNotificationMetadata.put(
notifId,
MessageType.RunCommandMessage.name()
+ " '"
+ runInfo.commandName()
+ "' (command index "
+ runInfo.commandIndex()
+ ")");
relatedCommand =
new CommandRelationship.Specific(
runInfo.commandIndex(), CommandType.RUN, runInfo.commandName());
}
}

// For awakeables and cancellation, add descriptions
for (var notifId : notificationIds) {
if (notifId instanceof NotificationId.SignalId signalId) {
if (signalId.id() == CANCEL_SIGNAL_ID) {
knownNotificationMetadata.put(notifId, "Cancellation");
} else if (signalId.id() > 16) {
knownNotificationMetadata.put(
notifId,
"Awakeable " + Util.awakeableIdStr(stateContext.getStartInfo().id(), signalId.id()));
}
}
}

return DoProgressResponse.ReadFromInput.INSTANCE;
this.hitError(
ProtocolException.uncompletedDoProgressDuringReplay(
notificationIds.stream()
.sorted(NOTIFICATION_ID_COMPARATOR_FOR_JOURNAL_MISMATCH)
.collect(Collectors.toList()),
knownNotificationMetadata),
relatedCommand,
null,
stateContext);
ExceptionUtils.sneakyThrow(AbortedExecutionException.INSTANCE);
return null; // unreachable
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ public void insertRunToExecute(int handle, int commandIndex, String commandName)
return null;
}

public @Nullable Run getRunInfo(int handle) {
return runs.get(handle);
}

public boolean anyExecuting(Collection<Integer> anyHandle) {
return anyHandle.stream()
.anyMatch(h -> runs.containsKey(h) && runs.get(h).state == RunStateInner.Executing);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import dev.restate.sdk.core.ProtocolException;
import dev.restate.sdk.core.generated.protocol.Protocol;
import dev.restate.sdk.endpoint.HeadersAccessor;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.time.Instant;
import java.util.*;
Expand All @@ -34,8 +33,7 @@
class StateMachineImpl implements StateMachine {

private static final Logger LOG = LogManager.getLogger(StateMachineImpl.class);
private static final String AWAKEABLE_IDENTIFIER_PREFIX = "sign_1";
private static final int CANCEL_SIGNAL_ID = 1;
static final int CANCEL_SIGNAL_ID = 1;

// Callbacks
private final CompletableFuture<Void> waitForReadyFuture = new CompletableFuture<>();
Expand Down Expand Up @@ -385,15 +383,7 @@ public Awakeable awakeable() {
.createSignalHandle(new NotificationId.SignalId(signalId), this.stateContext);

// Encode awakeable id
String awakeableId =
AWAKEABLE_IDENTIFIER_PREFIX
+ Base64.getUrlEncoder()
.encodeToString(
this.stateContext
.getStartInfo()
.id()
.concat(ByteString.copyFrom(ByteBuffer.allocate(4).putInt(signalId).flip()))
.toByteArray());
String awakeableId = Util.awakeableIdStr(this.stateContext.getStartInfo().id(), signalId);

return new Awakeable(awakeableId, signalHandle);
}
Expand Down
12 changes: 12 additions & 0 deletions sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import dev.restate.sdk.core.generated.protocol.Protocol;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.Base64;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -78,6 +79,17 @@ static Duration durationMin(Duration a, Duration b) {
return (a.compareTo(b) <= 0) ? a : b;
}

private static final String AWAKEABLE_IDENTIFIER_PREFIX = "sign_1";

static String awakeableIdStr(ByteString invocationId, int signalId) {
return AWAKEABLE_IDENTIFIER_PREFIX
+ Base64.getUrlEncoder()
.encodeToString(
invocationId
.concat(ByteString.copyFrom(ByteBuffer.allocate(4).putInt(signalId).flip()))
.toByteArray());
}

/**
* Returns a string representation of a command message.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder;
import static dev.restate.sdk.core.statemachine.ProtoUtils.*;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.InstanceOfAssertFactories.STRING;

import dev.restate.sdk.core.generated.protocol.Protocol;
import dev.restate.serde.Serde;
Expand All @@ -25,6 +26,12 @@ public abstract class StateMachineFailuresTestSuite implements TestDefinitions.T

protected abstract TestInvocationBuilder sideEffectFailure(Serde<Integer> serde);

protected abstract TestInvocationBuilder awaitRunAfterProgressWasMade();

protected abstract TestInvocationBuilder awaitSleepAfterProgressWasMade();

protected abstract TestInvocationBuilder awaitAwakeableAfterProgressWasMade();

private static final Serde<Integer> FAILING_SERIALIZATION_INTEGER_TYPE_TAG =
Serde.using(
i -> {
Expand Down Expand Up @@ -91,6 +98,72 @@ public Stream<TestDefinitions.TestDefinition> definitions() {
.assertingOutput(
containsOnly(
errorDescriptionStartingWith(IllegalStateException.class.getCanonicalName())))
.named("Serde deserialization error"));
.named("Serde deserialization error"),
// --- Uncompleted doProgress during replay (bad await) tests
this.awaitRunAfterProgressWasMade()
.withInput(
startMessage(4),
inputCmd(),
runCmd(1, "my-side-effect"),
Protocol.SleepCommandMessage.newBuilder().setResultCompletionId(2).build(),
Protocol.SleepCompletionNotificationMessage.newBuilder()
.setCompletionId(2)
.setVoid(Protocol.Void.getDefaultInstance())
.build())
.assertingOutput(
containsOnly(
errorMessage(
errorMessage ->
assertThat(errorMessage)
.returns(
ProtocolException.JOURNAL_MISMATCH_CODE,
Protocol.ErrorMessage::getCode)
.extracting(Protocol.ErrorMessage::getMessage, STRING)
.contains("could not be replayed")
.contains("await"))))
.named("Add await on run after progress was made"),
this.awaitSleepAfterProgressWasMade()
.withInput(
startMessage(4),
inputCmd(),
Protocol.SleepCommandMessage.newBuilder().setResultCompletionId(1).build(),
Protocol.SleepCommandMessage.newBuilder().setResultCompletionId(2).build(),
Protocol.SleepCompletionNotificationMessage.newBuilder()
.setCompletionId(2)
.setVoid(Protocol.Void.getDefaultInstance())
.build())
.assertingOutput(
containsOnly(
errorMessage(
errorMessage ->
assertThat(errorMessage)
.returns(
ProtocolException.JOURNAL_MISMATCH_CODE,
Protocol.ErrorMessage::getCode)
.extracting(Protocol.ErrorMessage::getMessage, STRING)
.contains("could not be replayed")
.contains("await"))))
.named("Add await on sleep after progress was made"),
this.awaitAwakeableAfterProgressWasMade()
.withInput(
startMessage(3),
inputCmd(),
Protocol.SleepCommandMessage.newBuilder().setResultCompletionId(2).build(),
Protocol.SleepCompletionNotificationMessage.newBuilder()
.setCompletionId(2)
.setVoid(Protocol.Void.getDefaultInstance())
.build())
.assertingOutput(
containsOnly(
errorMessage(
errorMessage ->
assertThat(errorMessage)
.returns(
ProtocolException.JOURNAL_MISMATCH_CODE,
Protocol.ErrorMessage::getCode)
.extracting(Protocol.ErrorMessage::getMessage, STRING)
.contains("could not be replayed")
.contains("await"))))
.named("Add await on awakeable after progress was made"));
}
}
Loading
Loading