File tree Expand file tree Collapse file tree 3 files changed +33
-4
lines changed
main/java/dev/restate/sdk/core
test/java/dev/restate/sdk/core Expand file tree Collapse file tree 3 files changed +33
-4
lines changed Original file line number Diff line number Diff line change @@ -134,6 +134,7 @@ public void onNext(InvocationFlow.InvocationInput invocationInput) {
134
134
MessageLite msg = invocationInput .message ();
135
135
LOG .trace ("Received input message {} {}" , msg .getClass (), msg );
136
136
if (this .invocationState == InvocationState .WAITING_START ) {
137
+ MessageHeader .checkProtocolVersion (invocationInput .header ());
137
138
this .onStart (msg );
138
139
} else if (msg instanceof Protocol .CompletionMessage ) {
139
140
// We check the instance rather than the state, because the user code might still be
Original file line number Diff line number Diff line change 14
14
15
15
public class MessageHeader {
16
16
17
+ static final short SUPPORTED_PROTOCOL_VERSION = 0 ;
18
+
19
+ static final short VERSION_MASK = 0x03FF ;
17
20
static final short DONE_FLAG = 0x0001 ;
18
21
static final int REQUIRES_ACK_FLAG = 0x8000 ;
19
22
@@ -43,10 +46,6 @@ public long encode() {
43
46
return res ;
44
47
}
45
48
46
- public MessageHeader copyWithFlags (short flag ) {
47
- return new MessageHeader (type , flag , length );
48
- }
49
-
50
49
public static MessageHeader parse (long encoded ) throws ProtocolException {
51
50
var ty_code = (short ) (encoded >> 48 );
52
51
var flags = (short ) (encoded >> 32 );
@@ -127,4 +126,20 @@ public static MessageHeader fromMessage(MessageLite msg) {
127
126
}
128
127
throw new IllegalStateException ();
129
128
}
129
+
130
+ public static void checkProtocolVersion (MessageHeader header ) {
131
+ if (header .type != MessageType .StartMessage ) {
132
+ throw new IllegalStateException ("Expected StartMessage, got " + header .type );
133
+ }
134
+
135
+ short version = (short ) (header .flags & VERSION_MASK );
136
+ if (version != SUPPORTED_PROTOCOL_VERSION ) {
137
+ throw new IllegalStateException (
138
+ "Unsupported protocol version "
139
+ + version
140
+ + ", only version "
141
+ + SUPPORTED_PROTOCOL_VERSION
142
+ + " is supported" );
143
+ }
144
+ }
130
145
}
Original file line number Diff line number Diff line change 9
9
package dev .restate .sdk .core ;
10
10
11
11
import static org .assertj .core .api .Assertions .assertThat ;
12
+ import static org .assertj .core .api .Assertions .assertThatThrownBy ;
12
13
13
14
import org .junit .jupiter .api .Test ;
14
15
@@ -24,4 +25,16 @@ void requiresAckFlag() {
24
25
.encode ())
25
26
.isEqualTo (0x0C01_8001_0000_0002L );
26
27
}
28
+
29
+ @ Test
30
+ void checkProtocolVersion () {
31
+ int unknownVersion = Integer .MAX_VALUE & MessageHeader .VERSION_MASK ;
32
+ assertThatThrownBy (
33
+ () ->
34
+ MessageHeader .checkProtocolVersion (
35
+ new MessageHeader (MessageType .StartMessage , unknownVersion , 0 )))
36
+ .hasMessage (
37
+ "Unsupported protocol version %d, only version %d is supported" ,
38
+ unknownVersion , MessageHeader .SUPPORTED_PROTOCOL_VERSION );
39
+ }
27
40
}
You can’t perform that action at this time.
0 commit comments