diff --git a/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java b/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java index 9212e3484c3..4eb48090da9 100644 --- a/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java +++ b/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java @@ -94,7 +94,6 @@ final class HttpServerHandler extends ChannelInboundHandlerAdapter implements Ht private static final Logger logger = LoggerFactory.getLogger(HttpServerHandler.class); - private static final CompletableFuture[] EMPTY_FUTURES = {}; private static final String ALLOWED_METHODS_STRING = HttpMethod.knownMethods().stream().map(HttpMethod::name).collect(Collectors.joining(",")); diff --git a/grpc/src/main/java/com/linecorp/armeria/internal/server/grpc/AbstractServerCall.java b/grpc/src/main/java/com/linecorp/armeria/internal/server/grpc/AbstractServerCall.java index 650c35b923f..65cba359b34 100644 --- a/grpc/src/main/java/com/linecorp/armeria/internal/server/grpc/AbstractServerCall.java +++ b/grpc/src/main/java/com/linecorp/armeria/internal/server/grpc/AbstractServerCall.java @@ -29,6 +29,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Splitter; import com.linecorp.armeria.common.HttpData; @@ -111,8 +112,9 @@ public abstract class AbstractServerCall extends ServerCall { private final String clientAcceptEncoding; private final boolean autoCompression; + @VisibleForTesting @Nullable - private final Executor blockingExecutor; + final Executor blockingExecutor; private final InternalGrpcExceptionHandler exceptionHandler; // Only set once. @@ -382,6 +384,11 @@ protected final void onRequestComplete() { } protected final void invokeOnReady() { + if (blockingExecutor != null && cancelled) { + // Do not call listener.onReady() if the call is cancelled after + // this task was scheduled to blockingTaskExecutor. + return; + } try { if (listener != null) { listener.onReady(); @@ -392,6 +399,11 @@ protected final void invokeOnReady() { } private void invokeOnMessage(I request, boolean halfClose) { + if (blockingExecutor != null && cancelled) { + // Do not call listener.onMessage() if the call is cancelled after + // this task was scheduled to blockingTaskExecutor. + return; + } try (SafeCloseable ignored = ctx.push()) { assert listener != null; listener.onMessage(request); @@ -404,6 +416,11 @@ private void invokeOnMessage(I request, boolean halfClose) { } protected final void invokeHalfClose() { + if (blockingExecutor != null && cancelled) { + // Do not call listener.onHalfClose() if the call is cancelled after + // this task was scheduled to blockingTaskExecutor. + return; + } try (SafeCloseable ignored = ctx.push()) { assert listener != null; listener.onHalfClose(); diff --git a/grpc/src/test/java/com/linecorp/armeria/internal/server/grpc/AbstractServerCallTest.java b/grpc/src/test/java/com/linecorp/armeria/internal/server/grpc/AbstractServerCallTest.java new file mode 100644 index 00000000000..4f692b4b104 --- /dev/null +++ b/grpc/src/test/java/com/linecorp/armeria/internal/server/grpc/AbstractServerCallTest.java @@ -0,0 +1,155 @@ +/* + * Copyright 2025 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.internal.server.grpc; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.awaitility.Awaitility.await; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import com.linecorp.armeria.client.grpc.GrpcClients; +import com.linecorp.armeria.common.FilteredHttpRequest; +import com.linecorp.armeria.common.HttpObject; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.grpc.GrpcService; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCall.Listener; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.ServerInterceptors; +import io.grpc.StatusRuntimeException; +import io.grpc.stub.StreamObserver; +import testing.grpc.Messages.StreamingInputCallRequest; +import testing.grpc.Messages.StreamingInputCallResponse; +import testing.grpc.TestServiceGrpc; +import testing.grpc.TestServiceGrpc.TestServiceStub; + +class AbstractServerCallTest { + + @RegisterExtension + static final ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) throws Exception { + final AtomicReference> serverCallCaptor = new AtomicReference<>(); + final GrpcService grpcService = + GrpcService.builder() + .useBlockingTaskExecutor(true) + .useClientTimeoutHeader(false) + .addService(ServerInterceptors.intercept( + new FooTestServiceImpl(), + new ServerInterceptor() { + + @Override + public Listener interceptCall( + ServerCall call, Metadata headers, + ServerCallHandler next) { + serverCallCaptor.set(call); + return next.startCall(call, headers); + } + })) + .build(); + sb.service(grpcService); + sb.decorator((delegate, ctx, req) -> { + final FilteredHttpRequest newReq = new FilteredHttpRequest(req) { + @Override + protected void beforeSubscribe(Subscriber subscriber, + Subscription subscription) { + // This is called right before + // blockingExecutor.execute(() -> invokeOnMessage(request, endOfStream)); + // in AbstractServerCall. + // https://github.com/line/armeria/blob/0960d091bfc7f350c17e68f57cc627de584b9705/grpc/src/main/java/com/linecorp/armeria/internal/server/grpc/AbstractServerCall.java#L363 + final ServerCall serverCall = serverCallCaptor.get(); + assertThat(serverCall).isInstanceOf(AbstractServerCall.class); + ((AbstractServerCall) serverCall).blockingExecutor.execute(() -> { + // invokeOnMessage is not called until the request is cancelled. + await().until(serverCall::isCancelled); + // Now, AbstractServerCall.invokeOnMessage() is called and it doesn't call + // listener.onMessage() because the request is cancelled. + }); + } + + @Override + protected HttpObject filter(HttpObject obj) { + return obj; + } + }; + ctx.updateRequest(newReq); + return delegate.serve(ctx, newReq); + }); + sb.requestTimeoutMillis(100); + } + }; + + private static final AtomicBoolean isOnNextCalled = new AtomicBoolean(); + + @Test + void onMessageIsNotCalledWhenRequestCancelled() throws InterruptedException { + final TestServiceStub testServiceStub = GrpcClients.newClient(server.httpUri(), TestServiceStub.class); + final CompletableFuture future = new CompletableFuture<>(); + final StreamObserver streamingInputCallRequestStreamObserver = + testServiceStub.streamingInputCall(new StreamObserver() { + @Override + public void onNext(StreamingInputCallResponse value) {} + + @Override + public void onError(Throwable t) { + future.completeExceptionally(t); + } + + @Override + public void onCompleted() { + } + }); + streamingInputCallRequestStreamObserver.onNext(StreamingInputCallRequest.newBuilder().build()); + assertThatThrownBy(future::get).hasCauseInstanceOf(StatusRuntimeException.class) + .hasMessageContaining("CANCELLED"); + // Sleep additional 1 second to make sure that the onNext() is not called. + Thread.sleep(1000); + assertThat(isOnNextCalled).isFalse(); + } + + private static class FooTestServiceImpl extends TestServiceGrpc.TestServiceImplBase { + + @Override + public StreamObserver streamingInputCall( + StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(StreamingInputCallRequest value) { + // If this method is called that means listener.onMessage() in AbstractServerCall is called. + isOnNextCalled.set(true); + } + + @Override + public void onError(Throwable t) {} + + @Override + public void onCompleted() {} + }; + } + } +} diff --git a/grpc/src/test/java/com/linecorp/armeria/internal/server/grpc/GrpcDocServiceJsonSchemaTest.java b/grpc/src/test/java/com/linecorp/armeria/internal/server/grpc/GrpcDocServiceJsonSchemaTest.java index 069560c1899..54aaed3b113 100644 --- a/grpc/src/test/java/com/linecorp/armeria/internal/server/grpc/GrpcDocServiceJsonSchemaTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/internal/server/grpc/GrpcDocServiceJsonSchemaTest.java @@ -28,7 +28,6 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; -import com.google.protobuf.Descriptors.ServiceDescriptor; import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.common.AggregatedHttpResponse; @@ -48,10 +47,6 @@ class GrpcDocServiceJsonSchemaTest { - private static final ServiceDescriptor TEST_SERVICE_DESCRIPTOR = - testing.grpc.Test.getDescriptor() - .findServiceByName("TestService"); - private static class TestService extends TestServiceImplBase { @Override public void unaryCallWithAllDifferentParameterTypes(