diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt index f5b91182..529d7bd9 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt @@ -16,6 +16,8 @@ import dev.restate.sdk.common.HandlerRequest import dev.restate.sdk.common.StateKey import dev.restate.sdk.common.TerminalException import dev.restate.sdk.endpoint.definition.HandlerContext +import dev.restate.sdk.kotlin.internal.InsideRunElement +import dev.restate.sdk.kotlin.internal.InsideRunElement.Key.checkNotInsideRun import dev.restate.serde.Serde import dev.restate.serde.SerdeFactory import dev.restate.serde.TypeTag @@ -31,6 +33,7 @@ internal constructor( internal val handlerContext: HandlerContext, internal val contextSerdeFactory: SerdeFactory, ) : WorkflowContext { + override fun key(): String { return this.handlerContext.objectKey() } @@ -39,75 +42,89 @@ internal constructor( return this.handlerContext.request() } - override suspend fun get(key: StateKey): T? = - resolveSerde(key.serdeInfo()) - .let { serde -> - SingleDurableFutureImpl(handlerContext.get(key.name()).await()).simpleMap { - it.getOrNull()?.let { serde.deserialize(it) } - } + override suspend fun get(key: StateKey): T? { + checkNotInsideRun() + return resolveSerde(key.serdeInfo()) + .let { serde -> + SingleDurableFutureImpl(handlerContext.get(key.name()).await()).simpleMap { + it.getOrNull()?.let { serde.deserialize(it) } } - .await() + } + .await() + } - override suspend fun stateKeys(): Collection = - SingleDurableFutureImpl(handlerContext.getKeys().await()).await() + override suspend fun stateKeys(): Collection { + checkNotInsideRun() + return SingleDurableFutureImpl(handlerContext.getKeys().await()).await() + } override suspend fun set(key: StateKey, value: T) { + checkNotInsideRun() handlerContext.set(key.name(), resolveAndSerialize(key.serdeInfo(), value)).await() } override suspend fun clear(key: StateKey<*>) { + checkNotInsideRun() handlerContext.clear(key.name()).await() } override suspend fun clearAll() { + checkNotInsideRun() handlerContext.clearAll().await() } - override suspend fun timer(duration: Duration, name: String?): DurableFuture = - SingleDurableFutureImpl(handlerContext.timer(duration.toJavaDuration(), name).await()).map {} + override suspend fun timer(duration: Duration, name: String?): DurableFuture { + checkNotInsideRun() + return SingleDurableFutureImpl(handlerContext.timer(duration.toJavaDuration(), name).await()) + .map {} + } override suspend fun call( request: Request - ): CallDurableFuture = - resolveSerde(request.getResponseTypeTag()).let { responseSerde -> - val callHandle = - handlerContext - .call( - request.getTarget(), - resolveAndSerialize(request.getRequestTypeTag(), request.getRequest()), - request.getIdempotencyKey(), - request.getHeaders()?.entries, - ) - .await() - - val callAsyncResult = - callHandle.callAsyncResult.map { - CompletableFuture.completedFuture(responseSerde.deserialize(it)) - } + ): CallDurableFuture { + checkNotInsideRun() + return resolveSerde(request.getResponseTypeTag()).let { responseSerde -> + val callHandle = + handlerContext + .call( + request.getTarget(), + resolveAndSerialize(request.getRequestTypeTag(), request.getRequest()), + request.getIdempotencyKey(), + request.getHeaders()?.entries, + ) + .await() - return@let CallDurableFutureImpl(callAsyncResult, callHandle.invocationIdAsyncResult) - } + val callAsyncResult = + callHandle.callAsyncResult.map { + CompletableFuture.completedFuture(responseSerde.deserialize(it)) + } + + return@let CallDurableFutureImpl(callAsyncResult, callHandle.invocationIdAsyncResult) + } + } override suspend fun send( request: Request, delay: Duration?, - ): InvocationHandle = - resolveSerde(request.getResponseTypeTag()).let { responseSerde -> - val invocationIdAsyncResult = - handlerContext - .send( - request.getTarget(), - resolveAndSerialize(request.getRequestTypeTag(), request.getRequest()), - request.getIdempotencyKey(), - request.getHeaders()?.entries, - delay?.toJavaDuration(), - ) - .await() + ): InvocationHandle { + checkNotInsideRun() + return resolveSerde(request.getResponseTypeTag()).let { responseSerde -> + val invocationIdAsyncResult = + handlerContext + .send( + request.getTarget(), + resolveAndSerialize(request.getRequestTypeTag(), request.getRequest()), + request.getIdempotencyKey(), + request.getHeaders()?.entries, + delay?.toJavaDuration(), + ) + .await() - object : BaseInvocationHandle(handlerContext, responseSerde) { - override suspend fun invocationId(): String = invocationIdAsyncResult.poll().await() - } + object : BaseInvocationHandle(handlerContext, responseSerde) { + override suspend fun invocationId(): String = invocationIdAsyncResult.poll().await() } + } + } override fun invocationHandle( invocationId: String, @@ -125,6 +142,7 @@ internal constructor( retryPolicy: RetryPolicy?, block: suspend () -> T, ): DurableFuture { + checkNotInsideRun() val serde: Serde = resolveSerde(typeTag) val coroutineCtx = currentCoroutineContext() val javaRetryPolicy = @@ -138,7 +156,10 @@ internal constructor( .setMaxDuration(it.maxDuration?.toJavaDuration()) } - val scope = CoroutineScope(coroutineCtx + CoroutineName("restate-run-$name")) + val scope = + CoroutineScope( + coroutineCtx + CoroutineName("restate-run-$name") + InsideRunElement.INSTANCE + ) val asyncResult = handlerContext @@ -159,6 +180,7 @@ internal constructor( } override suspend fun awakeable(typeTag: TypeTag): Awakeable { + checkNotInsideRun() val serde: Serde = resolveSerde(typeTag) val awk = handlerContext.awakeable().await() return AwakeableImpl(awk.asyncResult, serde, awk.id) @@ -184,15 +206,19 @@ internal constructor( DurablePromise { val serde: Serde = resolveSerde(key.serdeInfo()) - override suspend fun future(): DurableFuture = - SingleDurableFutureImpl(handlerContext.promise(key.name()).await()).simpleMap { - serde.deserialize(it) - } + override suspend fun future(): DurableFuture { + checkNotInsideRun() + return SingleDurableFutureImpl(handlerContext.promise(key.name()).await()).simpleMap { + serde.deserialize(it) + } + } - override suspend fun peek(): Output = - SingleDurableFutureImpl(handlerContext.peekPromise(key.name()).await()) - .simpleMap { it.map { serde.deserialize(it) } } - .await() + override suspend fun peek(): Output { + checkNotInsideRun() + return SingleDurableFutureImpl(handlerContext.peekPromise(key.name()).await()) + .simpleMap { it.map { serde.deserialize(it) } } + .await() + } } inner class DurablePromiseHandleImpl(private val key: DurablePromiseKey) : @@ -200,6 +226,7 @@ internal constructor( val serde: Serde = resolveSerde(key.serdeInfo()) override suspend fun resolve(payload: T) { + checkNotInsideRun() SingleDurableFutureImpl( handlerContext .resolvePromise( @@ -212,6 +239,7 @@ internal constructor( } override suspend fun reject(reason: String) { + checkNotInsideRun() SingleDurableFutureImpl( handlerContext.rejectPromise(key.name(), TerminalException(reason)).await() ) diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/futures.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/futures.kt index 917f930c..2d05e48c 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/futures.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/futures.kt @@ -14,6 +14,7 @@ import dev.restate.sdk.common.TerminalException import dev.restate.sdk.common.TimeoutException import dev.restate.sdk.endpoint.definition.AsyncResult import dev.restate.sdk.endpoint.definition.HandlerContext +import dev.restate.sdk.kotlin.internal.InsideRunElement.Key.checkNotInsideRun import dev.restate.serde.Serde import dev.restate.serde.TypeTag import java.util.concurrent.CompletableFuture @@ -32,6 +33,7 @@ internal abstract class BaseDurableFutureImpl : DurableFuture { get() = SelectClauseImpl(this) override suspend fun await(): T { + checkNotInsideRun() return asyncResult().poll().await() } @@ -193,20 +195,25 @@ internal constructor( private val responseSerde: Serde, ) : InvocationHandle { override suspend fun cancel() { + checkNotInsideRun() val ignored = handlerContext.cancelInvocation(invocationId()).await() } - override suspend fun attach(): DurableFuture = - SingleDurableFutureImpl( - handlerContext.attachInvocation(invocationId()).await().map { - CompletableFuture.completedFuture(responseSerde.deserialize(it)) - } - ) + override suspend fun attach(): DurableFuture { + checkNotInsideRun() + return SingleDurableFutureImpl( + handlerContext.attachInvocation(invocationId()).await().map { + CompletableFuture.completedFuture(responseSerde.deserialize(it)) + } + ) + } - override suspend fun output(): Output = - SingleDurableFutureImpl(handlerContext.getInvocationOutput(invocationId()).await()) - .simpleMap { it.map { responseSerde.deserialize(it) } } - .await() + override suspend fun output(): Output { + checkNotInsideRun() + return SingleDurableFutureImpl(handlerContext.getInvocationOutput(invocationId()).await()) + .simpleMap { it.map { responseSerde.deserialize(it) } } + .await() + } } internal class AwakeableImpl @@ -218,13 +225,14 @@ internal constructor(asyncResult: AsyncResult, serde: Serde, override internal class AwakeableHandleImpl(val contextImpl: ContextImpl, val id: String) : AwakeableHandle { override suspend fun resolve(typeTag: TypeTag, payload: T) { + checkNotInsideRun() contextImpl.handlerContext .resolveAwakeable(id, contextImpl.resolveAndSerialize(typeTag, payload)) .await() } override suspend fun reject(reason: String) { - return + checkNotInsideRun() contextImpl.handlerContext.rejectAwakeable(id, TerminalException(reason)).await() } } diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/InsideRunElement.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/InsideRunElement.kt new file mode 100644 index 00000000..9ea63d9f --- /dev/null +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/InsideRunElement.kt @@ -0,0 +1,33 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin.internal + +import kotlin.coroutines.AbstractCoroutineContextElement +import kotlin.coroutines.CoroutineContext +import kotlinx.coroutines.currentCoroutineContext + +/** + * Coroutine context element that marks the current coroutine as executing inside a `ctx.run()` + * block. Context methods check for this element and throw [IllegalStateException] if present. + */ +internal class InsideRunElement private constructor() : AbstractCoroutineContextElement(Key) { + companion object Key : CoroutineContext.Key { + val INSTANCE = InsideRunElement() + + suspend fun checkNotInsideRun() { + if (currentCoroutineContext()[Key] != null) { + throw IllegalStateException( + "Cannot invoke context method inside ctx.run(). " + + "The run closure is meant for non-deterministic operations (e.g., HTTP calls, database reads). " + + "You MUST use context methods outside of ctx.run(), check the documentation: https://docs.restate.dev/develop/java/durable-steps#run" + ) + } + } + } +} diff --git a/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java b/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java index 6cf71824..408795fa 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java +++ b/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java @@ -31,6 +31,8 @@ class ContextImpl implements ObjectContext, WorkflowContext { + private static final ThreadLocal INSIDE_RUN = new ThreadLocal<>(); + private final HandlerContext handlerContext; private final Executor serviceExecutor; private final SerdeFactory serdeFactory; @@ -41,6 +43,15 @@ class ContextImpl implements ObjectContext, WorkflowContext { this.serdeFactory = serdeFactory; } + static void checkNotInsideRun() { + if (Boolean.TRUE.equals(INSIDE_RUN.get())) { + throw new IllegalStateException( + "Cannot invoke context method inside ctx.run(). " + + "The run closure is meant for non-deterministic operations (e.g., HTTP calls, database reads). " + + "You MUST use context methods outside of ctx.run(), check the documentation: https://docs.restate.dev/develop/java/durable-steps#run"); + } + } + @Override public String key() { return handlerContext.objectKey(); @@ -53,6 +64,7 @@ public HandlerRequest request() { @Override public Optional get(StateKey key) { + checkNotInsideRun(); return DurableFuture.fromAsyncResult( Util.awaitCompletableFuture(handlerContext.get(key.name())), serviceExecutor) .mapWithoutExecutor(opt -> opt.map(serdeFactory.create(key.serdeInfo())::deserialize)) @@ -61,22 +73,26 @@ public Optional get(StateKey key) { @Override public Collection stateKeys() { + checkNotInsideRun(); return Util.awaitCompletableFuture( Util.awaitCompletableFuture(handlerContext.getKeys()).poll()); } @Override public void clear(StateKey key) { + checkNotInsideRun(); Util.awaitCompletableFuture(handlerContext.clear(key.name())); } @Override public void clearAll() { + checkNotInsideRun(); Util.awaitCompletableFuture(handlerContext.clearAll()); } @Override public void set(StateKey key, @NonNull T value) { + checkNotInsideRun(); Util.awaitCompletableFuture( handlerContext.set( key.name(), @@ -86,12 +102,14 @@ public void set(StateKey key, @NonNull T value) { @Override public DurableFuture timer(String name, Duration duration) { + checkNotInsideRun(); return DurableFuture.fromAsyncResult( Util.awaitCompletableFuture(handlerContext.timer(duration, name)), serviceExecutor); } @Override public CallDurableFuture call(Request request) { + checkNotInsideRun(); Slice input = Util.executeOrFail( handlerContext, @@ -121,6 +139,7 @@ public CallDurableFuture call(Request request) { @Override public InvocationHandle send( Request request, @Nullable Duration delay) { + checkNotInsideRun(); Slice input = Util.executeOrFail( handlerContext, @@ -152,6 +171,7 @@ public String invocationId() { @Override public InvocationHandle invocationHandle(String invocationId, TypeTag responseTypeTag) { + checkNotInsideRun(); return new BaseInvocationHandle<>( Util.executeOrFail(handlerContext, () -> serdeFactory.create(responseTypeTag))) { @Override @@ -170,11 +190,13 @@ abstract class BaseInvocationHandle implements InvocationHandle { @Override public void cancel() { + checkNotInsideRun(); Util.awaitCompletableFuture(handlerContext.cancelInvocation(invocationId())); } @Override public DurableFuture attach() { + checkNotInsideRun(); return DurableFuture.fromAsyncResult( Util.awaitCompletableFuture(handlerContext.attachInvocation(invocationId())) .map(s -> CompletableFuture.completedFuture(responseSerde.deserialize(s))), @@ -183,6 +205,7 @@ public DurableFuture attach() { @Override public Output getOutput() { + checkNotInsideRun(); return DurableFuture.fromAsyncResult( Util.awaitCompletableFuture(handlerContext.getInvocationOutput(invocationId())) .map(o -> CompletableFuture.completedFuture(o.map(responseSerde::deserialize))), @@ -194,6 +217,7 @@ public Output getOutput() { @Override public DurableFuture runAsync( String name, TypeTag typeTag, RetryPolicy retryPolicy, ThrowingSupplier action) { + checkNotInsideRun(); Serde serde = serdeFactory.create(typeTag); return DurableFuture.fromAsyncResult( Util.awaitCompletableFuture( @@ -203,12 +227,15 @@ public DurableFuture runAsync( serviceExecutor.execute( () -> { Slice result; + INSIDE_RUN.set(Boolean.TRUE); try { result = serde.serialize(action.get()); } catch (Throwable e) { + INSIDE_RUN.remove(); runCompleter.proposeFailure(e, retryPolicy); return; } + INSIDE_RUN.remove(); runCompleter.proposeSuccess(result); }))), serviceExecutor) @@ -217,6 +244,7 @@ public DurableFuture runAsync( @Override public Awakeable awakeable(TypeTag typeTag) throws TerminalException { + checkNotInsideRun(); Serde serde = serdeFactory.create(typeTag); // Retrieve the awakeable HandlerContext.Awakeable awakeable = Util.awaitCompletableFuture(handlerContext.awakeable()); @@ -228,6 +256,7 @@ public AwakeableHandle awakeableHandle(String id) { return new AwakeableHandle() { @Override public void resolve(TypeTag serde, @NonNull T payload) { + checkNotInsideRun(); Util.awaitCompletableFuture( handlerContext.resolveAwakeable( id, @@ -237,6 +266,7 @@ public void resolve(TypeTag serde, @NonNull T payload) { @Override public void reject(String reason) { + checkNotInsideRun(); Util.awaitCompletableFuture( handlerContext.rejectAwakeable(id, new TerminalException(reason))); } @@ -253,6 +283,7 @@ public DurablePromise promise(DurablePromiseKey key) { return new DurablePromise<>() { @Override public DurableFuture future() { + checkNotInsideRun(); AsyncResult result = Util.awaitCompletableFuture(handlerContext.promise(key.name())); return DurableFuture.fromAsyncResult(result, serviceExecutor) .mapWithoutExecutor(serdeFactory.create(key.serdeInfo())::deserialize); @@ -260,6 +291,7 @@ public DurableFuture future() { @Override public Output peek() { + checkNotInsideRun(); return Util.awaitCompletableFuture( Util.awaitCompletableFuture(handlerContext.peekPromise(key.name())).poll()) .map(serdeFactory.create(key.serdeInfo())::deserialize); @@ -272,6 +304,7 @@ public DurablePromiseHandle promiseHandle(DurablePromiseKey key) { return new DurablePromiseHandle<>() { @Override public void resolve(T payload) throws IllegalStateException { + checkNotInsideRun(); Util.awaitCompletableFuture( Util.awaitCompletableFuture( handlerContext.resolvePromise( @@ -285,6 +318,7 @@ public void resolve(T payload) throws IllegalStateException { @Override public void reject(String reason) throws IllegalStateException { + checkNotInsideRun(); Util.awaitCompletableFuture( Util.awaitCompletableFuture( handlerContext.rejectPromise(key.name(), new TerminalException(reason))) diff --git a/sdk-api/src/main/java/dev/restate/sdk/DurableFuture.java b/sdk-api/src/main/java/dev/restate/sdk/DurableFuture.java index 6daaa050..382640bc 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/DurableFuture.java +++ b/sdk-api/src/main/java/dev/restate/sdk/DurableFuture.java @@ -52,6 +52,7 @@ public abstract class DurableFuture { * @throws TerminalException if this future was completed with a failure */ public final T await() throws TerminalException { + ContextImpl.checkNotInsideRun(); return Util.awaitCompletableFuture(asyncResult().poll()); } diff --git a/sdk-api/src/main/java/dev/restate/sdk/RestateRandom.java b/sdk-api/src/main/java/dev/restate/sdk/RestateRandom.java index fffcdda3..fe567e03 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/RestateRandom.java +++ b/sdk-api/src/main/java/dev/restate/sdk/RestateRandom.java @@ -53,6 +53,7 @@ public UUID nextUUID() { @Override protected int next(int bits) { + ContextImpl.checkNotInsideRun(); return super.next(bits); } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java b/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java index 51a08617..5da7edc3 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java @@ -114,6 +114,7 @@ public static ProtocolException closedWhileWaitingEntries() { PROTOCOL_VIOLATION_CODE); } + @Deprecated static ProtocolException invalidSideEffectCall() { return new ProtocolException( "A syscall was invoked from within a side effect closure.", diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java index 8851fb4b..581d0bea 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java @@ -45,6 +45,10 @@ protected abstract TestInvocationBuilder awaitAllSideEffectWithSecondFailing( protected abstract TestInvocationBuilder failingSideEffectWithRetryPolicy( String reason, RetryPolicy retryPolicy); + protected abstract TestInvocationBuilder sideEffectGuard(); + + protected abstract TestInvocationBuilder sideEffectGuardAwait(); + protected abstract TestInvocationBuilder instantNow(); protected abstract void assertIsInstant(ByteString bytes); @@ -327,6 +331,49 @@ public Stream definitions() { Protocol.ProposeRunCompletionMessage::getResultCompletionId) .extracting(Protocol.ProposeRunCompletionMessage::getValue) .satisfies(this::assertIsInstant), - msg -> assertThat(msg).isEqualTo(suspensionMessage(1))))); + msg -> assertThat(msg).isEqualTo(suspensionMessage(1)))), + this.sideEffectGuard() + .withInput(startMessage(1), inputCmd()) + .assertingOutput( + msgs -> + assertThat(msgs) + .satisfiesExactly( + msg -> assertThat(msg).isEqualTo(runCmd(1)), + errorMessage( + errorMessage -> + assertThat(errorMessage) + .returns( + TerminalException.INTERNAL_SERVER_ERROR_CODE, + Protocol.ErrorMessage::getCode) + .returns(1, Protocol.ErrorMessage::getRelatedCommandIndex) + .returns( + (int) MessageType.RunCommandMessage.encode(), + Protocol.ErrorMessage::getRelatedCommandType) + .extracting(Protocol.ErrorMessage::getMessage, STRING) + .contains( + "Cannot invoke context method inside ctx.run()")))) + .named("Side effect guard prevents context usage inside run"), + this.sideEffectGuardAwait() + .withInput(startMessage(1), inputCmd()) + .assertingOutput( + msgs -> + assertThat(msgs) + .satisfiesExactly( + msg -> assertThat(msg).isInstanceOf(Protocol.SleepCommandMessage.class), + msg -> assertThat(msg).isEqualTo(runCmd(2)), + errorMessage( + errorMessage -> + assertThat(errorMessage) + .returns( + TerminalException.INTERNAL_SERVER_ERROR_CODE, + Protocol.ErrorMessage::getCode) + .returns(2, Protocol.ErrorMessage::getRelatedCommandIndex) + .returns( + (int) MessageType.RunCommandMessage.encode(), + Protocol.ErrorMessage::getRelatedCommandType) + .extracting(Protocol.ErrorMessage::getMessage, STRING) + .contains( + "Cannot invoke context method inside ctx.run()")))) + .named("Side effect guard prevents awaiting durable future inside run")); } } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SideEffectTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SideEffectTest.java index a73ab385..c06a5cbd 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SideEffectTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SideEffectTest.java @@ -166,6 +166,31 @@ protected TestInvocationBuilder failingSideEffectWithRetryPolicy( }); } + @Override + protected TestInvocationBuilder sideEffectGuard() { + return testDefinitionForService( + "SideEffectGuard", + Serde.VOID, + TestSerdes.STRING, + (ctx, unused) -> { + ctx.run(() -> ctx.sleep(java.time.Duration.ofMillis(100))); + return null; + }); + } + + @Override + protected TestInvocationBuilder sideEffectGuardAwait() { + return testDefinitionForService( + "SideEffectGuardAwait", + Serde.VOID, + TestSerdes.STRING, + (ctx, unused) -> { + DurableFuture timer = ctx.timer("my-sleep", java.time.Duration.ofMillis(100)); + ctx.run(() -> timer.await()); + return null; + }); + } + @Override protected TestInvocationBuilder instantNow() { return testDefinitionForService( diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/SideEffectTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/SideEffectTest.kt index 1d448cee..e0ead976 100644 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/SideEffectTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/SideEffectTest.kt @@ -27,6 +27,7 @@ import dev.restate.serde.kotlinx.typeTag import java.util.* import kotlin.coroutines.coroutineContext import kotlin.time.Clock +import kotlin.time.Duration.Companion.milliseconds import kotlin.time.ExperimentalTime import kotlin.time.Instant import kotlin.time.toJavaInstant @@ -136,6 +137,19 @@ class SideEffectTest : SideEffectTestSuite() { } } + override fun sideEffectGuard() = + testDefinitionForService("SideEffectGuard") { ctx, _: Unit -> + ctx.runBlock { ctx.sleep(100.milliseconds) } + "" + } + + override fun sideEffectGuardAwait() = + testDefinitionForService("SideEffectGuardAwait") { ctx, _: Unit -> + val timer = ctx.timer(100.milliseconds) + ctx.runBlock { timer.await() } + "" + } + @OptIn(ExperimentalTime::class) override fun instantNow() = testDefinitionForService("InstantNow") { ctx, _: Unit -> Clock.Restate.now() }