Skip to content

Commit

Permalink
Check protocol version is supported (#199)
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper authored Jan 11, 2024
1 parent d0ad9f1 commit ffaa961
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ public void onNext(InvocationFlow.InvocationInput invocationInput) {
MessageLite msg = invocationInput.message();
LOG.trace("Received input message {} {}", msg.getClass(), msg);
if (this.invocationState == InvocationState.WAITING_START) {
MessageHeader.checkProtocolVersion(invocationInput.header());
this.onStart(msg);
} else if (msg instanceof Protocol.CompletionMessage) {
// We check the instance rather than the state, because the user code might still be
Expand Down
23 changes: 19 additions & 4 deletions sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

public class MessageHeader {

static final short SUPPORTED_PROTOCOL_VERSION = 0;

static final short VERSION_MASK = 0x03FF;
static final short DONE_FLAG = 0x0001;
static final int REQUIRES_ACK_FLAG = 0x8000;

Expand Down Expand Up @@ -43,10 +46,6 @@ public long encode() {
return res;
}

public MessageHeader copyWithFlags(short flag) {
return new MessageHeader(type, flag, length);
}

public static MessageHeader parse(long encoded) throws ProtocolException {
var ty_code = (short) (encoded >> 48);
var flags = (short) (encoded >> 32);
Expand Down Expand Up @@ -127,4 +126,20 @@ public static MessageHeader fromMessage(MessageLite msg) {
}
throw new IllegalStateException();
}

public static void checkProtocolVersion(MessageHeader header) {
if (header.type != MessageType.StartMessage) {
throw new IllegalStateException("Expected StartMessage, got " + header.type);
}

short version = (short) (header.flags & VERSION_MASK);
if (version != SUPPORTED_PROTOCOL_VERSION) {
throw new IllegalStateException(
"Unsupported protocol version "
+ version
+ ", only version "
+ SUPPORTED_PROTOCOL_VERSION
+ " is supported");
}
}
}
13 changes: 13 additions & 0 deletions sdk-core/src/test/java/dev/restate/sdk/core/MessageHeaderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package dev.restate.sdk.core;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import org.junit.jupiter.api.Test;

Expand All @@ -24,4 +25,16 @@ void requiresAckFlag() {
.encode())
.isEqualTo(0x0C01_8001_0000_0002L);
}

@Test
void checkProtocolVersion() {
int unknownVersion = Integer.MAX_VALUE & MessageHeader.VERSION_MASK;
assertThatThrownBy(
() ->
MessageHeader.checkProtocolVersion(
new MessageHeader(MessageType.StartMessage, unknownVersion, 0)))
.hasMessage(
"Unsupported protocol version %d, only version %d is supported",
unknownVersion, MessageHeader.SUPPORTED_PROTOCOL_VERSION);
}
}

0 comments on commit ffaa961

Please sign in to comment.