diff --git a/.fernignore b/.fernignore index f8d1e81..a14a77f 100644 --- a/.fernignore +++ b/.fernignore @@ -19,6 +19,11 @@ src/main/java/com/deepgram/core/ClientOptions.java # Transport abstraction (pluggable transport for SageMaker, etc.) src/main/java/com/deepgram/core/transport/ +# Bug fixes for maxRetries(0) semantics ("connect once, don't retry") and a +# configurable connectionTimeoutMs on ReconnectOptions (was hardcoded 4000ms). +# Pull this back out once the fixes are upstreamed into the Fern generator. +src/main/java/com/deepgram/core/ReconnectingWebSocketListener.java + # Build and project configuration build.gradle settings.gradle diff --git a/AGENTS.md b/AGENTS.md index a9d607d..e1af8d2 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -48,6 +48,7 @@ How to identify: Current temporarily frozen files: - `src/main/java/com/deepgram/core/ClientOptions.java` - preserves release-please version markers and correct SDK header constants that Fern currently overwrites; use the standard `.bak` swap/restore workflow during regen review +- `src/main/java/com/deepgram/core/ReconnectingWebSocketListener.java` - carries bug fixes for `maxRetries(0)` semantics ("connect once, don't retry") and a configurable `connectionTimeoutMs` field (was hardcoded 4000ms), plus an `applyOptionsOverride(...)` hook used by `TransportWebSocketFactory` to apply per-transport reconnect policy; pull this back out once the fixes are upstreamed into the Fern generator. Use the standard `.bak` swap/restore workflow during regen review. ### Prepare repo for regeneration diff --git a/src/main/java/com/deepgram/core/ReconnectingWebSocketListener.java b/src/main/java/com/deepgram/core/ReconnectingWebSocketListener.java index 0ca455a..e10e5af 100644 --- a/src/main/java/com/deepgram/core/ReconnectingWebSocketListener.java +++ b/src/main/java/com/deepgram/core/ReconnectingWebSocketListener.java @@ -25,16 +25,21 @@ * Provides production-ready resilience for WebSocket connections. */ public abstract class ReconnectingWebSocketListener extends WebSocketListener { - private final long minReconnectionDelayMs; + // Option-derived fields are volatile (not final) so {@link #applyOptionsOverride} can rewire them + // after construction — used by {@code TransportWebSocketFactory} to honour + // {@code DeepgramTransportFactory.reconnectOptions()} without editing the generated WS clients. + private volatile long minReconnectionDelayMs; - private final long maxReconnectionDelayMs; + private volatile long maxReconnectionDelayMs; - private final double reconnectionDelayGrowFactor; + private volatile double reconnectionDelayGrowFactor; - private final int maxRetries; + private volatile int maxRetries; private final int maxEnqueuedMessages; + private volatile long connectionTimeoutMs; + private final AtomicInteger retryCount = new AtomicInteger(0); private final AtomicBoolean connectLock = new AtomicBoolean(false); @@ -66,16 +71,44 @@ public ReconnectingWebSocketListener( this.reconnectionDelayGrowFactor = options.reconnectionDelayGrowFactor; this.maxRetries = options.maxRetries; this.maxEnqueuedMessages = options.maxEnqueuedMessages; + this.connectionTimeoutMs = options.connectionTimeoutMs; this.connectionSupplier = connectionSupplier; } + /** + * Replaces the option-derived parameters on this listener at runtime. Used by + * {@code TransportWebSocketFactory} to apply {@code DeepgramTransportFactory.reconnectOptions()} + * without requiring edits to the generated per-resource WebSocket clients. {@code maxEnqueuedMessages} + * is intentionally not overridden — the message queue is sized at construction. + * + *

Thread-safety: option-derived fields are volatile; reads observe the latest write. The + * initial connect() call may have already started before the override lands, so for the very + * first attempt the original options apply; the override takes effect from the next attempt + * onwards. For the SageMaker storm-suppression case ({@code maxRetries(0)}) this is fine + * because the initial attempt's gate ({@code retryCount > maxRetries} with {@code retryCount=0}) + * always passes regardless. + * + * @param options replacement options; {@code null} is a no-op. + */ + public void applyOptionsOverride(ReconnectOptions options) { + if (options == null) { + return; + } + this.minReconnectionDelayMs = options.minReconnectionDelayMs; + this.maxReconnectionDelayMs = options.maxReconnectionDelayMs; + this.reconnectionDelayGrowFactor = options.reconnectionDelayGrowFactor; + this.maxRetries = options.maxRetries; + this.connectionTimeoutMs = options.connectionTimeoutMs; + } + /** * Initiates a WebSocket connection with automatic reconnection enabled. * * Connection behavior: - * - Times out after 4000 milliseconds + * - Times out after {@code ReconnectOptions.connectionTimeoutMs} (default 4000ms) * - Thread-safe via atomic lock (returns immediately if connection in progress) - * - Retry count not incremented for initial connection attempt + * - {@code maxRetries} counts retries only — the initial attempt always proceeds. + * {@code maxRetries(0)} means "connect once, don't retry" (not "refuse to connect"). * * Error handling: * - TimeoutException: Includes retry attempt context @@ -86,18 +119,21 @@ public void connect() { if (!connectLock.compareAndSet(false, true)) { return; } - if (retryCount.get() >= maxRetries) { + // retryCount is incremented inside scheduleReconnect() before re-entering connect(), + // so on the initial call retryCount == 0 and we always proceed. The cap applies to + // retries only — maxRetries(0) blocks retries but allows the initial attempt. + if (retryCount.get() > maxRetries) { connectLock.set(false); return; } try { CompletableFuture connectionFuture = CompletableFuture.supplyAsync(connectionSupplier); try { - webSocket = connectionFuture.get(4000, MILLISECONDS); + webSocket = connectionFuture.get(connectionTimeoutMs, MILLISECONDS); } catch (TimeoutException e) { connectionFuture.cancel(true); TimeoutException timeoutError = - new TimeoutException("WebSocket connection timeout after " + 4000 + " milliseconds" + new TimeoutException("WebSocket connection timeout after " + connectionTimeoutMs + " milliseconds" + (retryCount.get() > 0 ? " (retry attempt #" + retryCount.get() : " (initial connection attempt)")); @@ -399,12 +435,15 @@ public static final class ReconnectOptions { public final int maxEnqueuedMessages; + public final long connectionTimeoutMs; + private ReconnectOptions(Builder builder) { this.minReconnectionDelayMs = builder.minReconnectionDelayMs; this.maxReconnectionDelayMs = builder.maxReconnectionDelayMs; this.reconnectionDelayGrowFactor = builder.reconnectionDelayGrowFactor; this.maxRetries = builder.maxRetries; this.maxEnqueuedMessages = builder.maxEnqueuedMessages; + this.connectionTimeoutMs = builder.connectionTimeoutMs; } public static Builder builder() { @@ -422,12 +461,15 @@ public static final class Builder { private int maxEnqueuedMessages; + private long connectionTimeoutMs; + public Builder() { this.minReconnectionDelayMs = 1000; this.maxReconnectionDelayMs = 10000; this.reconnectionDelayGrowFactor = 1.3; this.maxRetries = 2147483647; this.maxEnqueuedMessages = 1000; + this.connectionTimeoutMs = 4000; } public Builder minReconnectionDelayMs(long minReconnectionDelayMs) { @@ -455,6 +497,16 @@ public Builder maxEnqueuedMessages(int maxEnqueuedMessages) { return this; } + /** + * Sets the per-attempt connection timeout in milliseconds. Defaults to {@code 4000}. + * Each call to {@link ReconnectingWebSocketListener#connect()} will wait at most + * this long for the underlying WebSocket factory to produce a connected socket. + */ + public Builder connectionTimeoutMs(long connectionTimeoutMs) { + this.connectionTimeoutMs = connectionTimeoutMs; + return this; + } + /** * Builds the ReconnectOptions with validation. * @@ -463,6 +515,7 @@ public Builder maxEnqueuedMessages(int maxEnqueuedMessages) { * - minReconnectionDelayMs <= maxReconnectionDelayMs * - reconnectionDelayGrowFactor >= 1.0 * - maxRetries and maxEnqueuedMessages are non-negative + * - connectionTimeoutMs is positive * * @return The validated ReconnectOptions instance * @throws IllegalArgumentException if configuration is invalid @@ -487,6 +540,9 @@ public ReconnectOptions build() { if (maxEnqueuedMessages < 0) { throw new IllegalArgumentException("maxEnqueuedMessages must be non-negative"); } + if (connectionTimeoutMs <= 0) { + throw new IllegalArgumentException("connectionTimeoutMs must be positive"); + } return new ReconnectOptions(this); } } diff --git a/src/main/java/com/deepgram/core/transport/DeepgramTransportFactory.java b/src/main/java/com/deepgram/core/transport/DeepgramTransportFactory.java index f42dd6a..54d8f65 100644 --- a/src/main/java/com/deepgram/core/transport/DeepgramTransportFactory.java +++ b/src/main/java/com/deepgram/core/transport/DeepgramTransportFactory.java @@ -1,5 +1,6 @@ package com.deepgram.core.transport; +import com.deepgram.core.ReconnectingWebSocketListener; import java.util.Map; /** @@ -19,7 +20,6 @@ *

When a transport factory is set, all WebSocket clients (Listen, Speak, Agent) will use it * instead of the default OkHttp WebSocket connection. */ -@FunctionalInterface public interface DeepgramTransportFactory { /** @@ -31,4 +31,19 @@ public interface DeepgramTransportFactory { * @return a connected or connecting transport instance */ DeepgramTransport create(String url, Map headers); + + /** + * Reconnect policy the SDK should apply when wrapping connections produced by this factory. + * Returning {@code null} (the default) leaves the SDK's {@link ReconnectingWebSocketListener} + * defaults in place. + * + *

Plugins that own their own connection lifecycle and retry/backoff (e.g. SageMaker bidi + * streaming) should return {@code ReconnectOptions.builder().maxRetries(0).build()} so the + * SDK's wrapper-level reconnect doesn't compound their internal retries into a storm. + * + * @return reconnect options to apply, or {@code null} for SDK defaults + */ + default ReconnectingWebSocketListener.ReconnectOptions reconnectOptions() { + return null; + } } diff --git a/src/main/java/com/deepgram/core/transport/TransportWebSocketFactory.java b/src/main/java/com/deepgram/core/transport/TransportWebSocketFactory.java index 8df13d1..93c73ad 100644 --- a/src/main/java/com/deepgram/core/transport/TransportWebSocketFactory.java +++ b/src/main/java/com/deepgram/core/transport/TransportWebSocketFactory.java @@ -1,5 +1,6 @@ package com.deepgram.core.transport; +import com.deepgram.core.ReconnectingWebSocketListener; import com.deepgram.core.WebSocketFactory; import java.util.LinkedHashMap; import java.util.Map; @@ -31,6 +32,13 @@ public TransportWebSocketFactory(DeepgramTransportFactory transportFactory) { @Override public WebSocket create(Request request, WebSocketListener listener) { + // Apply the plugin-declared reconnect policy to the SDK's wrapping listener. Plugins that + // own their own retry/backoff (SageMaker) return maxRetries(0) here so the wrapper-level + // reconnect doesn't compound their internal retries into a storm. + if (listener instanceof ReconnectingWebSocketListener) { + ((ReconnectingWebSocketListener) listener).applyOptionsOverride(transportFactory.reconnectOptions()); + } + String url = request.url().toString(); // Restore wss:// scheme — OkHttp's HttpUrl normalizes to https:// if (url.startsWith("https://")) { diff --git a/src/test/java/com/deepgram/core/ReconnectingWebSocketListenerTest.java b/src/test/java/com/deepgram/core/ReconnectingWebSocketListenerTest.java new file mode 100644 index 0000000..da8001e --- /dev/null +++ b/src/test/java/com/deepgram/core/ReconnectingWebSocketListenerTest.java @@ -0,0 +1,190 @@ +package com.deepgram.core; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import com.deepgram.core.ReconnectingWebSocketListener.ReconnectOptions; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; +import okhttp3.Response; +import okhttp3.WebSocket; +import okio.ByteString; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +/** Unit tests for {@link ReconnectingWebSocketListener} bug fixes. */ +class ReconnectingWebSocketListenerTest { + + /** + * Test fixture: counts how many times the supplier was invoked, and whether each call + * "succeeded" by simulating a fake WebSocket. Failures are simulated by throwing. + */ + private static final class CountingSupplier implements Supplier { + private final AtomicInteger calls = new AtomicInteger(0); + private final boolean shouldFail; + + CountingSupplier(boolean shouldFail) { + this.shouldFail = shouldFail; + } + + @Override + public WebSocket get() { + calls.incrementAndGet(); + if (shouldFail) { + throw new RuntimeException("simulated connect failure"); + } + return new FakeWebSocket(); + } + } + + private static final class FakeWebSocket implements WebSocket { + @Override + public okhttp3.Request request() { + return new okhttp3.Request.Builder().url("ws://localhost/").build(); + } + + @Override + public long queueSize() { + return 0; + } + + @Override + public boolean send(String text) { + return true; + } + + @Override + public boolean send(ByteString bytes) { + return true; + } + + @Override + public boolean close(int code, String reason) { + return true; + } + + @Override + public void cancel() {} + } + + /** Concrete listener that records callback invocations for assertions. */ + private static final class TestListener extends ReconnectingWebSocketListener { + final AtomicInteger failures = new AtomicInteger(0); + + TestListener(ReconnectOptions options, Supplier supplier) { + super(options, supplier); + } + + @Override + protected void onWebSocketOpen(WebSocket webSocket, Response response) {} + + @Override + protected void onWebSocketMessage(WebSocket webSocket, String text) {} + + @Override + protected void onWebSocketBinaryMessage(WebSocket webSocket, ByteString bytes) {} + + @Override + protected void onWebSocketFailure(WebSocket webSocket, Throwable t, Response response) { + failures.incrementAndGet(); + } + + @Override + protected void onWebSocketClosed(WebSocket webSocket, int code, String reason) {} + } + + @Nested + @DisplayName("ReconnectOptions builder") + class BuilderTests { + @Test + @DisplayName("connectionTimeoutMs defaults to 4000") + void connectionTimeoutDefaultsTo4000() { + ReconnectOptions opts = ReconnectOptions.builder().build(); + assertThat(opts.connectionTimeoutMs).isEqualTo(4000L); + } + + @Test + @DisplayName("connectionTimeoutMs can be customized") + void connectionTimeoutCustomizable() { + ReconnectOptions opts = + ReconnectOptions.builder().connectionTimeoutMs(15_000L).build(); + assertThat(opts.connectionTimeoutMs).isEqualTo(15_000L); + } + + @Test + @DisplayName("connectionTimeoutMs must be positive") + void connectionTimeoutValidatedPositive() { + assertThatThrownBy(() -> ReconnectOptions.builder().connectionTimeoutMs(0).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("connectionTimeoutMs"); + assertThatThrownBy(() -> ReconnectOptions.builder().connectionTimeoutMs(-1).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("connectionTimeoutMs"); + } + + @Test + @DisplayName("maxRetries(0) is allowed (no retries, but initial attempt still allowed)") + void maxRetriesZeroAllowed() { + ReconnectOptions opts = ReconnectOptions.builder().maxRetries(0).build(); + assertThat(opts.maxRetries).isZero(); + } + } + + @Nested + @DisplayName("connect() with maxRetries(0)") + class MaxRetriesZeroTests { + @Test + @DisplayName("allows the initial attempt to proceed (regression: previously refused to connect)") + void initialAttemptProceedsWhenMaxRetriesIsZero() { + CountingSupplier supplier = new CountingSupplier(false); + ReconnectOptions opts = ReconnectOptions.builder().maxRetries(0).build(); + TestListener listener = new TestListener(opts, supplier); + + listener.connect(); + + assertThat(supplier.calls.get()) + .as("initial connect attempt must run even with maxRetries(0)") + .isEqualTo(1); + } + } + + @Nested + @DisplayName("applyOptionsOverride") + class ApplyOverrideTests { + @Test + @DisplayName("null override is a no-op") + void nullOverrideIsNoop() { + CountingSupplier supplier = new CountingSupplier(false); + ReconnectOptions opts = + ReconnectOptions.builder().maxRetries(7).build(); + TestListener listener = new TestListener(opts, supplier); + + listener.applyOptionsOverride(null); + listener.connect(); + + assertThat(supplier.calls.get()).isEqualTo(1); + } + + @Test + @DisplayName("override before connect() applies maxRetries(0) on subsequent retry attempts") + void overrideAppliesToRetryGate() { + // Arrange: large initial maxRetries so the listener would normally retry forever. + CountingSupplier supplier = new CountingSupplier(true /* always fail */); + ReconnectOptions opts = + ReconnectOptions.builder().maxRetries(Integer.MAX_VALUE).build(); + TestListener listener = new TestListener(opts, supplier); + + // Apply the override BEFORE the first connect call. The override survives the connect() + // call and gates any scheduled retries. + listener.applyOptionsOverride( + ReconnectOptions.builder().maxRetries(0).build()); + + listener.connect(); + // The initial attempt still runs (gate is `retryCount > maxRetries`, retryCount=0 → false). + assertThat(supplier.calls.get()).isEqualTo(1); + + listener.disconnect(); + } + } +}