diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 98f596141eed..4ce3f9b651f4 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -88,8 +88,10 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcDispatcherClient; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillServer; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillStreamFactory; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.auth.VendoredCredentialsAdapter; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCache; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCachingRemoteStubFactory; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.FailoverChannel; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.IsolationChannel; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactoryFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactoryFactoryImpl; @@ -114,6 +116,7 @@ import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.util.construction.CoderTranslation; import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.auth.MoreCallCredentials; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheStats; @@ -381,7 +384,8 @@ private StreamingWorkerHarnessFactoryOutput createFanOutStreamingEngineWorkerHar MemoryMonitor memoryMonitor, GrpcDispatcherClient dispatcherClient) { WeightedSemaphore maxCommitByteSemaphore = Commits.maxCommitByteSemaphore(); - ChannelCache channelCache = createChannelCache(options, checkNotNull(configFetcher)); + ChannelCache channelCache = + createChannelCache(options, checkNotNull(configFetcher), dispatcherClient); @SuppressWarnings("methodref.receiver.bound") FanOutStreamingEngineWorkerHarness fanOutStreamingEngineWorkerHarness = FanOutStreamingEngineWorkerHarness.create( @@ -804,20 +808,31 @@ private static void validateWorkerOptions(DataflowWorkerHarnessOptions options) } private static ChannelCache createChannelCache( - DataflowWorkerHarnessOptions workerOptions, ComputationConfig.Fetcher configFetcher) { + DataflowWorkerHarnessOptions workerOptions, + ComputationConfig.Fetcher configFetcher, + GrpcDispatcherClient dispatcherClient) { ChannelCache channelCache = ChannelCache.create( (currentFlowControlSettings, serviceAddress) -> { - // IsolationChannel will create and manage separate RPC channels to the same - // serviceAddress. + // IsolationChannel wrapping FailoverChannel so that each active RPC gets its own + // FailoverChannel instance. FailoverChannel creates two channels (primary, + // fallback) per active RPC. return IsolationChannel.create( () -> - remoteChannel( - serviceAddress, - workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), - currentFlowControlSettings), + FailoverChannel.create( + remoteChannel( + serviceAddress, + workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), + currentFlowControlSettings), + remoteChannel( + dispatcherClient.getDispatcherEndpoints().iterator().next(), + workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), + currentFlowControlSettings), + MoreCallCredentials.from( + new VendoredCredentialsAdapter(workerOptions.getGcpCredential()))), currentFlowControlSettings.getOnReadyThresholdBytes()); }); + configFetcher .getGlobalConfigHandle() .registerConfigObserver( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java index 75c2b91af603..63ab5379bd49 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java @@ -410,15 +410,18 @@ private GlobalDataStreamSender getOrCreateGlobalDataSteam( } private WindmillStreamSender createAndStartWindmillStreamSender(Endpoint endpoint) { + GetWorkRequest.Builder getWorkRequestBuilder = + GetWorkRequest.newBuilder() + .setClientId(jobHeader.getClientId()) + .setJobId(jobHeader.getJobId()) + .setProjectId(jobHeader.getProjectId()) + .setWorkerId(jobHeader.getWorkerId()); + endpoint.workerToken().ifPresent(getWorkRequestBuilder::setBackendWorkerToken); + WindmillStreamSender windmillStreamSender = WindmillStreamSender.create( WindmillConnection.from(endpoint, this::createWindmillStub), - GetWorkRequest.newBuilder() - .setClientId(jobHeader.getClientId()) - .setJobId(jobHeader.getJobId()) - .setProjectId(jobHeader.getProjectId()) - .setWorkerId(jobHeader.getWorkerId()) - .build(), + getWorkRequestBuilder.build(), GetWorkBudget.noBudget(), streamFactory, workItemScheduler, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java index 82e66c4b0d74..0d8f75dd816a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java @@ -128,7 +128,7 @@ public CloudWindmillServiceV1Alpha1Stub getWindmillServiceStub() { : randomlySelectNextStub(windmillServiceStubs)); } - ImmutableSet getDispatcherEndpoints() { + public ImmutableSet getDispatcherEndpoints() { return dispatcherStubs.get().dispatcherEndpoints(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java new file mode 100644 index 000000000000..71119e45150a --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.LongSupplier; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallCredentials; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallOptions; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ConnectivityState; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ForwardingClientCall.SimpleForwardingClientCall; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Metadata; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.MethodDescriptor; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A {@link ManagedChannel} that wraps a primary and a fallback channel. + * + *

Routes requests to either primary or fallback channel based on two independent failover modes: + * + *

    + *
  • Connection Status Failover: If the primary channel is not ready for 10+ seconds + * (e.g., during network issues), routes to fallback channel. Switches back as soon as the + * primary channel becomes READY again. + *
  • RPC Failover: If primary channel RPC fails with transient errors ({@link + * Status.Code#UNAVAILABLE} or {@link Status.Code#UNKNOWN}), or with {@link + * Status.Code#DEADLINE_EXCEEDED} before receiving any response (indicating the connection was + * never established) and connection status is not READY, switches to fallback channel and + * waits for a 1-hour cooling period before retrying primary. + *
+ */ +@Internal +public final class FailoverChannel extends ManagedChannel { + private static final Logger LOG = LoggerFactory.getLogger(FailoverChannel.class); + private static final AtomicInteger CHANNEL_ID_COUNTER = new AtomicInteger(0); + // Time to wait before retrying the primary channel after an RPC-based fallback. + private static final long FALLBACK_COOLING_PERIOD_NANOS = TimeUnit.HOURS.toNanos(1); + private static final long PRIMARY_NOT_READY_WAIT_NANOS = TimeUnit.SECONDS.toNanos(10); + + private final ManagedChannel primary; + private final ManagedChannel fallback; + private final int channelId; + @Nullable private final CallCredentials fallbackCallCredentials; + private final LongSupplier nanoClock; + // Held only during registration to prevent duplicate listener registration. + private final AtomicBoolean stateChangeListenerRegistered = new AtomicBoolean(false); + // All mutable routing state is consolidated here to ensure related fields are updated atomically. + private final FailoverState state; + + private static final class FailoverState { + // Set when primary's connection state has been unavailable for too long. + @GuardedBy("this") + boolean useFallbackDueToState; + // Set when an RPC on primary fails with an error. + @GuardedBy("this") + boolean useFallbackDueToRPC; + // Timestamp when RPC-based fallback was triggered. Only meaningful when useFallbackDueToRPC + // is true. + @GuardedBy("this") + long lastRPCFallbackTimeNanos; + // Time when primary first became not-ready. -1 when primary is currently READY. + @GuardedBy("this") + long primaryNotReadySinceNanos = -1; + + private final int channelId; + + FailoverState(int channelId) { + this.channelId = channelId; + } + + /** + * Determines whether the next RPC should route to the fallback channel, updating internal state + * as needed. + */ + synchronized boolean computeUseFallback(long nowNanos) { + // Clear RPC-based fallback if the cooling period has elapsed. + if (useFallbackDueToRPC + && nowNanos - lastRPCFallbackTimeNanos >= FALLBACK_COOLING_PERIOD_NANOS) { + useFallbackDueToRPC = false; + LOG.info( + "[channel-{}] Primary channel cooling period elapsed; switching back from fallback.", + channelId); + } + // Check if primary has been not-ready long enough to switch to fallback. + // primaryNotReadySinceNanos is set by the state-change callback when primary is not ready. + if (!useFallbackDueToRPC + && !useFallbackDueToState + && primaryNotReadySinceNanos >= 0 + && nowNanos - primaryNotReadySinceNanos > PRIMARY_NOT_READY_WAIT_NANOS) { + useFallbackDueToState = true; + LOG.warn( + "[channel-{}] Primary connection unavailable. Switching to secondary connection.", + channelId); + } + return useFallbackDueToRPC || useFallbackDueToState; + } + + /** + * Starts the not-ready grace period timer. Called by the state-change callback when primary + * transitions to a non-ready state. Has no effect if already tracking or already on fallback. + */ + synchronized void markPrimaryNotReady(long nowNanos) { + if (!useFallbackDueToRPC && !useFallbackDueToState && primaryNotReadySinceNanos < 0) { + primaryNotReadySinceNanos = nowNanos; + } + } + + /** + * Transitions the fallback state. When toFallback is true (RPC failure) it enables RPC-based + * fallback if not already active and returns true so the caller can log the failure details. + * When toFallback is false (primary recovered) it clears all fallback flags and returns true if + * recovery actually changed state, so the caller can log it. + */ + synchronized boolean transitionFallback(boolean toFallback, long nowNanos) { + if (toFallback) { + if (!useFallbackDueToRPC) { + useFallbackDueToRPC = true; + lastRPCFallbackTimeNanos = nowNanos; + // Return true to indicate fallback state was changed and caller should log the event. + return true; + } + // Already in RPC-based fallback, no state change. + return false; + } + // Clear all fallback state as primary has recovered. + boolean wasOnFallback = useFallbackDueToState || useFallbackDueToRPC; + useFallbackDueToState = false; + useFallbackDueToRPC = false; + primaryNotReadySinceNanos = -1; + return wasOnFallback; + } + } + + private FailoverChannel( + ManagedChannel primary, + ManagedChannel fallback, + @Nullable CallCredentials fallbackCallCredentials, + LongSupplier nanoClock) { + this.primary = primary; + this.fallback = fallback; + this.channelId = CHANNEL_ID_COUNTER.getAndIncrement(); + this.state = new FailoverState(channelId); + this.fallbackCallCredentials = fallbackCallCredentials; + this.nanoClock = nanoClock; + // Register callback to monitor primary channel state changes + registerPrimaryStateChangeListener(); + } + + public static FailoverChannel create( + ManagedChannel primary, ManagedChannel fallback, CallCredentials fallbackCallCredentials) { + return new FailoverChannel(primary, fallback, fallbackCallCredentials, System::nanoTime); + } + + static FailoverChannel forTest( + ManagedChannel primary, + ManagedChannel fallback, + CallCredentials fallbackCallCredentials, + LongSupplier nanoClock) { + return new FailoverChannel(primary, fallback, fallbackCallCredentials, nanoClock); + } + + @Override + public String authority() { + return primary.authority(); + } + + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, CallOptions callOptions) { + // Read the clock before the synchronized call to avoid holding it under the state lock. + long nowNanos = nanoClock.getAsLong(); + boolean useFallback = state.computeUseFallback(nowNanos); + + if (useFallback) { + return new FailoverClientCall<>( + fallback.newCall(methodDescriptor, applyFallbackCredentials(callOptions)), + true, + methodDescriptor.getFullMethodName()); + } + + return new FailoverClientCall<>( + primary.newCall(methodDescriptor, callOptions), + false, + methodDescriptor.getFullMethodName()); + } + + @Override + public ManagedChannel shutdown() { + primary.shutdown(); + fallback.shutdown(); + return this; + } + + @Override + public ManagedChannel shutdownNow() { + primary.shutdownNow(); + fallback.shutdownNow(); + return this; + } + + @Override + public boolean isShutdown() { + return primary.isShutdown() && fallback.isShutdown(); + } + + @Override + public boolean isTerminated() { + return primary.isTerminated() && fallback.isTerminated(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + long endTimeNanos = nanoClock.getAsLong() + unit.toNanos(timeout); + boolean primaryTerminated = primary.awaitTermination(timeout, unit); + long remainingNanos = Math.max(0, endTimeNanos - nanoClock.getAsLong()); + return primaryTerminated && fallback.awaitTermination(remainingNanos, TimeUnit.NANOSECONDS); + } + + private boolean shouldFallbackBasedOnRPCStatus(Status status, boolean receivedResponse) { + // If a response was received, the connection was healthy and any error is an application-level + // issue, not a connectivity problem. Never failover in this case regardless of status. + if (receivedResponse) { + return false; + } + switch (status.getCode()) { + case UNAVAILABLE: + case UNKNOWN: + case DEADLINE_EXCEEDED: + return true; + default: + return false; + } + } + + private CallOptions applyFallbackCredentials(CallOptions callOptions) { + if (fallbackCallCredentials != null) { + return callOptions.withCallCredentials(fallbackCallCredentials); + } + return callOptions; + } + + private void notifyCallDone( + Status status, boolean isFallback, String methodName, boolean receivedResponse) { + if (!status.isOk() && !isFallback && shouldFallbackBasedOnRPCStatus(status, receivedResponse)) { + if (state.transitionFallback(true, nanoClock.getAsLong())) { + LOG.warn( + "[channel-{}] Primary connection failed for method: {}. Switching to secondary" + + " connection. Status: {}", + channelId, + methodName, + status.getCode()); + } + } else if (isFallback && !status.isOk()) { + LOG.warn( + "[channel-{}] Secondary connection failed for method: {}. Status: {}", + channelId, + methodName, + status.getCode()); + } + } + + private final class FailoverClientCall + extends SimpleForwardingClientCall { + private final boolean isFallback; + private final String methodName; + // Tracks whether any response message was received. Volatile ensures the write in onMessage + // is visible to the read in onClose even if they execute on different threads within gRPC's + // SerializingExecutor. + private volatile boolean receivedResponse = false; + + /** + * @param delegate the underlying ClientCall (either primary or fallback) + * @param isFallback true if {@code delegate} is a fallback channel call, false if it is a + * primary channel call. This flag is inspected by {@link #notifyCallDone} to determine + * whether a failure should trigger switching to the fallback channel (only primary failures + * do). + * @param methodName gRPC method name (for logging) + */ + FailoverClientCall(ClientCall delegate, boolean isFallback, String methodName) { + super(delegate); + this.isFallback = isFallback; + this.methodName = methodName; + } + + @Override + public void start(Listener responseListener, Metadata headers) { + super.start( + new SimpleForwardingClientCallListener(responseListener) { + @Override + public void onMessage(RespT message) { + receivedResponse = true; + super.onMessage(message); + } + + @Override + public void onClose(Status status, Metadata trailers) { + notifyCallDone(status, isFallback, methodName, receivedResponse); + super.onClose(status, trailers); + } + }, + headers); + } + } + + /** Registers callback for primary channel state changes. */ + private void registerPrimaryStateChangeListener() { + if (!stateChangeListenerRegistered.getAndSet(true)) { + try { + ConnectivityState currentState = primary.getState(false); + primary.notifyWhenStateChanged(currentState, this::onPrimaryStateChanged); + } catch (Exception e) { + LOG.warn( + "[channel-{}] Failed to register channel state monitor. Continuing with fallback detection.", + channelId, + e); + stateChangeListenerRegistered.set(false); + } + } + } + + /** Callback invoked when primary channel connectivity state changes. */ + private void onPrimaryStateChanged() { + if (isShutdown() || isTerminated()) { + return; + } + + ConnectivityState newState = primary.getState(false); + // IDLE means the channel was READY but has no active RPCs — treat as healthy. + if (newState == ConnectivityState.READY || newState == ConnectivityState.IDLE) { + if (state.transitionFallback(false, 0)) { + LOG.info( + "[channel-{}] Primary channel recovered; switching back from fallback.", channelId); + } + } else { + // Primary is not ready; start the grace period timer so computeUseFallback can + // switch to fallback once PRIMARY_NOT_READY_WAIT_NANOS elapses. + state.markPrimaryNotReady(nanoClock.getAsLong()); + } + + // Always re-register for next state change (unless shutdown). + if (!isShutdown() && !isTerminated()) { + stateChangeListenerRegistered.set(false); + registerPrimaryStateChangeListener(); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java index 94c8f4b75957..b5f244f77eb6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java @@ -132,7 +132,7 @@ private static WorkItemScheduler noOpProcessWorkItemFn() { getWorkStreamLatencies) -> {}; } - private static GetWorkRequest getWorkRequest(long items, long bytes) { + private static GetWorkRequest getWorkRequest(long items, long bytes, String backendWorkerToken) { return GetWorkRequest.newBuilder() .setJobId(JOB_ID) .setProjectId(PROJECT_ID) @@ -140,6 +140,7 @@ private static GetWorkRequest getWorkRequest(long items, long bytes) { .setClientId(JOB_HEADER.getClientId()) .setMaxItems(items) .setMaxBytes(bytes) + .setBackendWorkerToken(backendWorkerToken) .build(); } @@ -239,9 +240,22 @@ public void testStreamsStartCorrectly() throws InterruptedException { .distributeBudget( any(), eq(GetWorkBudget.builder().setItems(items).setBytes(bytes).build())); - verify(streamFactory, times(2)) + verify(streamFactory, times(1)) .createDirectGetWorkStream( - any(), eq(getWorkRequest(0, 0)), any(), any(), any(), eq(noOpProcessWorkItemFn())); + any(), + eq(getWorkRequest(0, 0, workerToken)), + any(), + any(), + any(), + eq(noOpProcessWorkItemFn())); + verify(streamFactory, times(1)) + .createDirectGetWorkStream( + any(), + eq(getWorkRequest(0, 0, workerToken2)), + any(), + any(), + any(), + eq(noOpProcessWorkItemFn())); verify(streamFactory, times(2)).createDirectGetDataStream(any()); verify(streamFactory, times(2)).createDirectCommitWorkStream(any()); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java new file mode 100644 index 000000000000..ee72a5d4a993 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java @@ -0,0 +1,494 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallCredentials; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallOptions; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall.Listener; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ConnectivityState; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Metadata; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.MethodDescriptor; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; + +@RunWith(JUnit4.class) +public class FailoverChannelTest { + + private MethodDescriptor methodDescriptor = + MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNARY) + .setFullMethodName(MethodDescriptor.generateFullMethodName("test", "test")) + .setRequestMarshaller(new NoopClientCall.NoopMarshaller()) + .setResponseMarshaller(new NoopClientCall.NoopMarshaller()) + .build(); + + private static FailoverChannel createForTest(ManagedChannel primary, ManagedChannel fallback) { + return FailoverChannel.forTest(primary, fallback, null, System::nanoTime); + } + + /** + * Starts a call on the primary channel, captures the injected listener, and fires onClose with + * the given status. Use this to trigger RPC-based failover in tests. + */ + private void triggerRPCFailure( + FailoverChannel channel, ClientCall underlying, Status status) + throws Exception { + Metadata metadata = new Metadata(); + channel + .newCall(methodDescriptor, CallOptions.DEFAULT) + .start(new NoopClientCall.NoopClientCallListener<>(), metadata); + ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(underlying).start(captor.capture(), same(metadata)); + captor.getValue().onClose(status, new Metadata()); + } + + @Test + public void testRPCFailureTriggersFallback() throws Exception { + // RPC failure with UNAVAILABLE should switch to fallback channel. + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + ClientCall underlyingCall = mock(ClientCall.class); + when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + + FailoverChannel failoverChannel = createForTest(mockChannel, mockFallbackChannel); + + triggerRPCFailure(failoverChannel, underlyingCall, Status.UNAVAILABLE); + + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockFallbackChannel, atLeastOnce()).newCall(any(), any()); + } + + @Test + public void testRPCFailureRecoveryAfterCoolingPeriod() throws Exception { + // After RPC failure, channel stays on fallback during cooling period, then returns to primary. + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + ClientCall underlyingCall = mock(ClientCall.class); + when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall, mock(ClientCall.class)); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + when(mockChannel.getState(false)).thenReturn(ConnectivityState.READY); + + AtomicLong time = new AtomicLong(0); + FailoverChannel failoverChannel = + FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + + triggerRPCFailure(failoverChannel, underlyingCall, Status.UNAVAILABLE); + + // Within cooling period, still on fallback + time.addAndGet(TimeUnit.MINUTES.toNanos(30)); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockFallbackChannel, atLeastOnce()).newCall(any(), any()); + + // After cooling period, recovers to primary + time.addAndGet(TimeUnit.MINUTES.toNanos(40)); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockChannel, atLeast(2)).newCall(any(), any()); + } + + @Test + public void testRPCFallbackClearedByConnectivityRecovery() throws Exception { + // Race condition: RPC failure observed just before connectivity callback fires READY. + // Once the channel goes through unhealthy→healthy, the cooling period must be cancelled + // and traffic must return to primary immediately (not wait 1 hour). + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + ClientCall underlyingCall = mock(ClientCall.class); + when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall, mock(ClientCall.class)); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + + when(mockChannel.getState(false)).thenReturn(ConnectivityState.IDLE, ConnectivityState.READY); + + AtomicReference stateChangeCallback = new AtomicReference<>(); + doAnswer( + invocation -> { + stateChangeCallback.set(invocation.getArgument(1)); + return null; + }) + .when(mockChannel) + .notifyWhenStateChanged(any(), any()); + + AtomicLong time = new AtomicLong(0); + FailoverChannel failoverChannel = + FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + + // RPC failure results in entering cooling period + triggerRPCFailure(failoverChannel, underlyingCall, Status.UNAVAILABLE); + + // Still within cooling period, routes to fallback + time.addAndGet(TimeUnit.MINUTES.toNanos(30)); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockFallbackChannel, atLeastOnce()).newCall(any(), any()); + + // Primary recovers and callback fires READY, clearing the cooling period + stateChangeCallback.get().run(); + + // Verify immediately routes back to primary + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockChannel, atLeast(2)).newCall(any(), any()); + } + + @Test + public void testFallbackWithCredentials() throws Exception { + // Fallback channel should receive custom credentials when provided. + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + ClientCall underlyingCall = mock(ClientCall.class); + CallCredentials mockCredentials = mock(CallCredentials.class); + when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + + FailoverChannel failoverChannel = + FailoverChannel.create(mockChannel, mockFallbackChannel, mockCredentials); + + triggerRPCFailure(failoverChannel, underlyingCall, Status.UNAVAILABLE); + + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + + ArgumentCaptor optionsCaptor = ArgumentCaptor.forClass(CallOptions.class); + verify(mockFallbackChannel).newCall(same(methodDescriptor), optionsCaptor.capture()); + assertEquals(mockCredentials, optionsCaptor.getValue().getCredentials()); + } + + @Test + public void testStateFallbackAfterPrimaryNotReady() { + // If the state-change callback signals primary is not ready for 10+ seconds, + // the next newCall() should route to fallback. + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + when(mockChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + // IDLE for constructor registration, TRANSIENT_FAILURE when callback fires, + // TRANSIENT_FAILURE for re-registration after the callback. + when(mockChannel.getState(false)) + .thenReturn( + ConnectivityState.IDLE, + ConnectivityState.TRANSIENT_FAILURE, + ConnectivityState.TRANSIENT_FAILURE); + + AtomicReference stateChangeCallback = new AtomicReference<>(); + doAnswer( + invocation -> { + stateChangeCallback.set(invocation.getArgument(1)); + return null; + }) + .when(mockChannel) + .notifyWhenStateChanged(any(), any()); + + AtomicLong time = new AtomicLong(0); + FailoverChannel failoverChannel = + FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + + // Callback fires: primary is TRANSIENT_FAILURE, starts the not-ready timer. + stateChangeCallback.get().run(); + + // Within 10 seconds: grace period not elapsed, routes to primary. + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockChannel).newCall(any(), any()); + + // After 10 seconds: routes to fallback. + time.addAndGet(TimeUnit.SECONDS.toNanos(11)); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockFallbackChannel).newCall(any(), any()); + } + + @Test + public void testIdleStateNotTreatedAsFallback() { + // IDLE is a normal healthy state (channel is not actively connected but will reconnect on + // demand). It must NOT start the not-ready timer or trigger state-based fallback, even after + // more than 10 seconds. + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + when(mockChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + // Primary stays IDLE the entire time (constructor registration + all state checks). + when(mockChannel.getState(false)).thenReturn(ConnectivityState.IDLE); + + AtomicLong time = new AtomicLong(0); + FailoverChannel failoverChannel = + FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + + // Advance well past the 10-second threshold while primary remains IDLE + time.addAndGet(TimeUnit.SECONDS.toNanos(30)); + + // IDLE must not trigger fallback — all calls still route to primary + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockChannel, atLeast(2)).newCall(any(), any()); + verify(mockFallbackChannel, never()).newCall(any(), any()); + } + + @Test + public void testStateBasedFallbackRecoveryViaCallback() { + // After state-based fallback, recovery to primary is immediate when callback fires with READY. + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + when(mockChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + // getState() calls in order: + // 1. constructor registerPrimaryStateChangeListener() → IDLE + // 2. onPrimaryStateChanged() fires (TRANSIENT_FAILURE) → TRANSIENT_FAILURE + // 3. re-registerPrimaryStateChangeListener() after 1st callback → TRANSIENT_FAILURE + // 4. onPrimaryStateChanged() fires (READY) → READY + // 5. re-registerPrimaryStateChangeListener() after 2nd callback → READY + when(mockChannel.getState(false)) + .thenReturn( + ConnectivityState.IDLE, + ConnectivityState.TRANSIENT_FAILURE, + ConnectivityState.TRANSIENT_FAILURE, + ConnectivityState.READY, + ConnectivityState.READY); + + AtomicReference stateChangeCallback = new AtomicReference<>(); + doAnswer( + invocation -> { + stateChangeCallback.set(invocation.getArgument(1)); + return null; + }) + .when(mockChannel) + .notifyWhenStateChanged(any(), any()); + + AtomicLong time = new AtomicLong(0); + FailoverChannel failoverChannel = + FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + + // First callback fires: primary is TRANSIENT_FAILURE, starts the not-ready timer at t=0. + stateChangeCallback.get().run(); + + // Within grace period: routes to primary. + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockChannel).newCall(any(), any()); + + // After 10 seconds: state-based fallback kicks in. + time.addAndGet(TimeUnit.SECONDS.toNanos(11)); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockFallbackChannel).newCall(any(), any()); + + // Second callback fires: primary is now READY, clears all fallback state. + stateChangeCallback.get().run(); + + // Next call recovers to primary immediately. + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockChannel, atLeast(2)).newCall(any(), any()); + } + + // --- DEADLINE_EXCEEDED tests --- + + @Test + public void testDeadlineExceededWithoutResponseTriggersFallback() throws Exception { + // DEADLINE_EXCEEDED with no response = connection never established. Should failover. + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + ClientCall underlyingCall = mock(ClientCall.class); + when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + + FailoverChannel failoverChannel = createForTest(mockChannel, mockFallbackChannel); + + ClientCall call1 = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + Metadata metadata1 = new Metadata(); + call1.start(new NoopClientCall.NoopClientCallListener<>(), metadata1); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(underlyingCall).start(captor.capture(), same(metadata1)); + // Close with DEADLINE_EXCEEDED and no prior onMessage, should trigger failover + captor.getValue().onClose(Status.DEADLINE_EXCEEDED, new Metadata()); + + ClientCall call2 = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + call2.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + verify(mockFallbackChannel, atLeastOnce()).newCall(any(), any()); + } + + @Test + public void testDeadlineExceededWithResponseDoesNotTriggerFallback() throws Exception { + // DEADLINE_EXCEEDED after receiving a response, should NOT + // failover. The connection was healthy since at least one response was delivered. + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + ClientCall underlyingCall = mock(ClientCall.class); + when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall, mock(ClientCall.class)); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + + FailoverChannel failoverChannel = createForTest(mockChannel, mockFallbackChannel); + + ClientCall call1 = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + Metadata metadata1 = new Metadata(); + call1.start(new NoopClientCall.NoopClientCallListener<>(), metadata1); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(underlyingCall).start(captor.capture(), same(metadata1)); + // Simulate receiving a response before the timeout + captor.getValue().onMessage(new Object()); + // Close with DEADLINE_EXCEEDED after a response, should NOT trigger failover + captor.getValue().onClose(Status.DEADLINE_EXCEEDED, new Metadata()); + + // Next call should still route to primary + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockChannel, atLeast(2)).newCall(any(), any()); + verify(mockFallbackChannel, never()).newCall(any(), any()); + } + + // --- Concurrency tests --- + + @Test + public void testConcurrentRPCFailuresProduceConsistentFailover() throws Exception { + // Concurrent RPC failures from multiple threads should produce exactly one failover. + // After all threads complete, subsequent calls must consistently route to fallback. + int numThreads = 20; + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + + // Track all primary ClientCalls created so we can fire onClose on each + List> primaryCalls = Collections.synchronizedList(new ArrayList<>()); + when(mockChannel.newCall(any(), any())) + .thenAnswer( + inv -> { + ClientCall call = mock(ClientCall.class); + primaryCalls.add(call); + return call; + }); + when(mockFallbackChannel.newCall(any(), any())).thenAnswer(inv -> mock(ClientCall.class)); + // Ensure state-based fallback does not interfere + when(mockChannel.getState(false)).thenReturn(ConnectivityState.READY); + + FailoverChannel failoverChannel = createForTest(mockChannel, mockFallbackChannel); + + // Start N calls on primary and capture their listeners + List> listeners = new ArrayList<>(); + for (int i = 0; i < numThreads; i++) { + ClientCall call = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + call.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + } + for (ClientCall primaryCall : primaryCalls) { + ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(primaryCall).start(captor.capture(), any()); + listeners.add(captor.getValue()); + } + + // All threads fire UNAVAILABLE simultaneously + CyclicBarrier barrier = new CyclicBarrier(numThreads); + List threads = new ArrayList<>(); + for (int i = 0; i < numThreads; i++) { + final Listener listener = listeners.get(i); + Thread t = + new Thread( + () -> { + try { + barrier.await(); + listener.onClose(Status.UNAVAILABLE, new Metadata()); + } catch (Exception e) { + Thread.currentThread().interrupt(); + } + }); + t.start(); + threads.add(t); + } + for (Thread t : threads) { + t.join(5000); + } + + // All subsequent calls must consistently route to fallback + int subsequentCalls = 5; + for (int i = 0; i < subsequentCalls; i++) { + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + } + verify(mockFallbackChannel, atLeast(subsequentCalls)).newCall(any(), any()); + } + + @Test + public void testConcurrentNewCallsDuringRPCFailoverAreConsistent() throws Exception { + // Calls made concurrently while RPC failover is triggered must route consistently: + // none should be lost and each must go to either primary (before failover) or + // fallback (after). + int numThreads = 20; + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + ClientCall underlyingCall = mock(ClientCall.class); + + when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall); + when(mockFallbackChannel.newCall(any(), any())).thenAnswer(inv -> mock(ClientCall.class)); + when(mockChannel.getState(false)).thenReturn(ConnectivityState.READY); + + FailoverChannel failoverChannel = createForTest(mockChannel, mockFallbackChannel); + + // Set up an in-flight primary call whose failure will trigger RPC-based failover. + ClientCall triggerCall = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + triggerCall.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + ArgumentCaptor> listenerCaptor = + ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(underlyingCall).start(listenerCaptor.capture(), any()); + ClientCall.Listener wrappedListener = listenerCaptor.getValue(); + + // All threads (failover trigger + newCall callers) start simultaneously via a barrier. + CyclicBarrier barrier = new CyclicBarrier(numThreads + 1); + List> tasks = new ArrayList<>(); + tasks.add( + () -> { + barrier.await(); + wrappedListener.onClose(Status.UNAVAILABLE, new Metadata()); + return null; + }); + for (int i = 0; i < numThreads; i++) { + tasks.add( + () -> { + barrier.await(); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + return null; + }); + } + + ExecutorService executor = Executors.newFixedThreadPool(numThreads + 1); + executor.invokeAll(tasks, 5, TimeUnit.SECONDS); + executor.shutdown(); + + // After concurrent operations, state must be coherent: subsequent calls go to fallback. + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockFallbackChannel, atLeastOnce()).newCall(any(), any()); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/IsolationChannelTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/IsolationChannelTest.java index 20321bbd66c3..4eb31caf3501 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/IsolationChannelTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/IsolationChannelTest.java @@ -31,8 +31,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import java.io.IOException; -import java.io.InputStream; import java.util.concurrent.TimeUnit; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallOptions; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall; @@ -40,7 +38,6 @@ import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Metadata; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.MethodDescriptor; -import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.MethodDescriptor.Marshaller; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; import org.junit.Test; @@ -49,65 +46,17 @@ import org.mockito.ArgumentCaptor; import org.mockito.InOrder; -/** - * {@link NoopClientCall} is a class that is designed for use in tests. It is designed to be used in - * places where a scriptable call is necessary. By default, all methods are noops, and designed to - * be overridden. - */ -class NoopClientCall extends ClientCall { - - /** - * {@link NoopClientCall.NoopClientCallListener} is a class that is designed for use in tests. It - * is designed to be used in places where a scriptable call listener is necessary. By default, all - * methods are noops, and designed to be overridden. - */ - public static class NoopClientCallListener extends ClientCall.Listener {} - - @Override - public void start(ClientCall.Listener listener, Metadata headers) {} - - @Override - public void request(int numMessages) {} - - @Override - public void cancel(String message, Throwable cause) {} - - @Override - public void halfClose() {} - - @Override - public void sendMessage(ReqT message) {} -} - @RunWith(JUnit4.class) public class IsolationChannelTest { private Supplier channelSupplier = mock(Supplier.class); - private static class NoopMarshaller implements Marshaller { - - @Override - public InputStream stream(Object o) { - return new InputStream() { - @Override - public int read() throws IOException { - return 0; - } - }; - } - - @Override - public Object parse(InputStream inputStream) { - return null; - } - }; - private MethodDescriptor methodDescriptor = MethodDescriptor.newBuilder() .setType(MethodDescriptor.MethodType.UNARY) .setFullMethodName(MethodDescriptor.generateFullMethodName("test", "test")) - .setRequestMarshaller(new NoopMarshaller()) - .setResponseMarshaller(new NoopMarshaller()) + .setRequestMarshaller(new NoopClientCall.NoopMarshaller()) + .setResponseMarshaller(new NoopClientCall.NoopMarshaller()) .build(); @Test @@ -414,19 +363,18 @@ public void testAwaitTermination() throws Exception { when(mockChannel.shutdown()).thenReturn(mockChannel); when(mockChannel.isTerminated()).thenReturn(false, false, false, true, true); - when(mockChannel.awaitTermination(longThat(l -> l < 2_000_000), eq(TimeUnit.NANOSECONDS))) + when(mockChannel.awaitTermination(longThat(l -> l > 0), eq(TimeUnit.NANOSECONDS))) .thenReturn(false, true); isolationChannel.shutdown(); - assertFalse(isolationChannel.awaitTermination(1, TimeUnit.MILLISECONDS)); - assertTrue(isolationChannel.awaitTermination(1, TimeUnit.MILLISECONDS)); + assertFalse(isolationChannel.awaitTermination(10, TimeUnit.SECONDS)); + assertTrue(isolationChannel.awaitTermination(10, TimeUnit.SECONDS)); assertTrue(isolationChannel.isTerminated()); verify(channelSupplier, times(1)).get(); verify(mockChannel, times(1)).shutdown(); verify(mockChannel, times(5)).isTerminated(); - verify(mockChannel, times(2)) - .awaitTermination(longThat(l -> l < 2_000_000), eq(TimeUnit.NANOSECONDS)); + verify(mockChannel, times(2)).awaitTermination(longThat(l -> l > 0), eq(TimeUnit.NANOSECONDS)); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/NoopClientCall.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/NoopClientCall.java new file mode 100644 index 000000000000..1f62aa57f57b --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/NoopClientCall.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs; + +import java.io.IOException; +import java.io.InputStream; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Metadata; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.MethodDescriptor.Marshaller; + +/** + * {@link NoopClientCall} is a class that is designed for use in tests. It is designed to be used in + * places where a scriptable call is necessary. By default, all methods are noops, and designed to + * be overridden. + */ +public class NoopClientCall extends ClientCall { + + /** + * {@link NoopClientCall.NoopClientCallListener} is a class that is designed for use in tests. It + * is designed to be used in places where a scriptable call listener is necessary. By default, all + * methods are noops, and designed to be overridden. + */ + public static class NoopClientCallListener extends ClientCall.Listener {} + + public static class NoopMarshaller implements Marshaller { + + @Override + public InputStream stream(Object o) { + return new InputStream() { + @Override + public int read() throws IOException { + return 0; + } + }; + } + + @Override + public Object parse(InputStream inputStream) { + return null; + } + } + + @Override + public void start(ClientCall.Listener listener, Metadata headers) {} + + @Override + public void request(int numMessages) {} + + @Override + public void cancel(String message, Throwable cause) {} + + @Override + public void halfClose() {} + + @Override + public void sendMessage(ReqT message) {} +} diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto index a4b3df906dd9..6286b2d67110 100644 --- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto +++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto @@ -470,6 +470,8 @@ message GetWorkRequest { optional string project_id = 7; optional int64 max_items = 2 [default = 0xffffffff]; optional int64 max_bytes = 3 [default = 0x7fffffffffffffff]; + repeated string computation_id_filter = 8; + optional string backend_worker_token = 9; reserved 6; }