Skip to content

Commit ffaa961

Browse files
Check protocol version is supported (#199)
1 parent d0ad9f1 commit ffaa961

File tree

3 files changed

+33
-4
lines changed

3 files changed

+33
-4
lines changed

sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ public void onNext(InvocationFlow.InvocationInput invocationInput) {
134134
MessageLite msg = invocationInput.message();
135135
LOG.trace("Received input message {} {}", msg.getClass(), msg);
136136
if (this.invocationState == InvocationState.WAITING_START) {
137+
MessageHeader.checkProtocolVersion(invocationInput.header());
137138
this.onStart(msg);
138139
} else if (msg instanceof Protocol.CompletionMessage) {
139140
// We check the instance rather than the state, because the user code might still be

sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
public class MessageHeader {
1616

17+
static final short SUPPORTED_PROTOCOL_VERSION = 0;
18+
19+
static final short VERSION_MASK = 0x03FF;
1720
static final short DONE_FLAG = 0x0001;
1821
static final int REQUIRES_ACK_FLAG = 0x8000;
1922

@@ -43,10 +46,6 @@ public long encode() {
4346
return res;
4447
}
4548

46-
public MessageHeader copyWithFlags(short flag) {
47-
return new MessageHeader(type, flag, length);
48-
}
49-
5049
public static MessageHeader parse(long encoded) throws ProtocolException {
5150
var ty_code = (short) (encoded >> 48);
5251
var flags = (short) (encoded >> 32);
@@ -127,4 +126,20 @@ public static MessageHeader fromMessage(MessageLite msg) {
127126
}
128127
throw new IllegalStateException();
129128
}
129+
130+
public static void checkProtocolVersion(MessageHeader header) {
131+
if (header.type != MessageType.StartMessage) {
132+
throw new IllegalStateException("Expected StartMessage, got " + header.type);
133+
}
134+
135+
short version = (short) (header.flags & VERSION_MASK);
136+
if (version != SUPPORTED_PROTOCOL_VERSION) {
137+
throw new IllegalStateException(
138+
"Unsupported protocol version "
139+
+ version
140+
+ ", only version "
141+
+ SUPPORTED_PROTOCOL_VERSION
142+
+ " is supported");
143+
}
144+
}
130145
}

sdk-core/src/test/java/dev/restate/sdk/core/MessageHeaderTest.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
package dev.restate.sdk.core;
1010

1111
import static org.assertj.core.api.Assertions.assertThat;
12+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
1213

1314
import org.junit.jupiter.api.Test;
1415

@@ -24,4 +25,16 @@ void requiresAckFlag() {
2425
.encode())
2526
.isEqualTo(0x0C01_8001_0000_0002L);
2627
}
28+
29+
@Test
30+
void checkProtocolVersion() {
31+
int unknownVersion = Integer.MAX_VALUE & MessageHeader.VERSION_MASK;
32+
assertThatThrownBy(
33+
() ->
34+
MessageHeader.checkProtocolVersion(
35+
new MessageHeader(MessageType.StartMessage, unknownVersion, 0)))
36+
.hasMessage(
37+
"Unsupported protocol version %d, only version %d is supported",
38+
unknownVersion, MessageHeader.SUPPORTED_PROTOCOL_VERSION);
39+
}
2740
}

0 commit comments

Comments
 (0)