diff --git a/.fernignore b/.fernignore index f8d1e81..a14a77f 100644 --- a/.fernignore +++ b/.fernignore @@ -19,6 +19,11 @@ src/main/java/com/deepgram/core/ClientOptions.java # Transport abstraction (pluggable transport for SageMaker, etc.) src/main/java/com/deepgram/core/transport/ +# Bug fixes for maxRetries(0) semantics ("connect once, don't retry") and a +# configurable connectionTimeoutMs on ReconnectOptions (was hardcoded 4000ms). +# Pull this back out once the fixes are upstreamed into the Fern generator. +src/main/java/com/deepgram/core/ReconnectingWebSocketListener.java + # Build and project configuration build.gradle settings.gradle diff --git a/AGENTS.md b/AGENTS.md index a9d607d..e1af8d2 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -48,6 +48,7 @@ How to identify: Current temporarily frozen files: - `src/main/java/com/deepgram/core/ClientOptions.java` - preserves release-please version markers and correct SDK header constants that Fern currently overwrites; use the standard `.bak` swap/restore workflow during regen review +- `src/main/java/com/deepgram/core/ReconnectingWebSocketListener.java` - carries bug fixes for `maxRetries(0)` semantics ("connect once, don't retry") and a configurable `connectionTimeoutMs` field (was hardcoded 4000ms), plus an `applyOptionsOverride(...)` hook used by `TransportWebSocketFactory` to apply per-transport reconnect policy; pull this back out once the fixes are upstreamed into the Fern generator. Use the standard `.bak` swap/restore workflow during regen review. ### Prepare repo for regeneration diff --git a/src/main/java/com/deepgram/core/ReconnectingWebSocketListener.java b/src/main/java/com/deepgram/core/ReconnectingWebSocketListener.java index 0ca455a..e10e5af 100644 --- a/src/main/java/com/deepgram/core/ReconnectingWebSocketListener.java +++ b/src/main/java/com/deepgram/core/ReconnectingWebSocketListener.java @@ -25,16 +25,21 @@ * Provides production-ready resilience for WebSocket connections. */ public abstract class ReconnectingWebSocketListener extends WebSocketListener { - private final long minReconnectionDelayMs; + // Option-derived fields are volatile (not final) so {@link #applyOptionsOverride} can rewire them + // after construction — used by {@code TransportWebSocketFactory} to honour + // {@code DeepgramTransportFactory.reconnectOptions()} without editing the generated WS clients. + private volatile long minReconnectionDelayMs; - private final long maxReconnectionDelayMs; + private volatile long maxReconnectionDelayMs; - private final double reconnectionDelayGrowFactor; + private volatile double reconnectionDelayGrowFactor; - private final int maxRetries; + private volatile int maxRetries; private final int maxEnqueuedMessages; + private volatile long connectionTimeoutMs; + private final AtomicInteger retryCount = new AtomicInteger(0); private final AtomicBoolean connectLock = new AtomicBoolean(false); @@ -66,16 +71,44 @@ public ReconnectingWebSocketListener( this.reconnectionDelayGrowFactor = options.reconnectionDelayGrowFactor; this.maxRetries = options.maxRetries; this.maxEnqueuedMessages = options.maxEnqueuedMessages; + this.connectionTimeoutMs = options.connectionTimeoutMs; this.connectionSupplier = connectionSupplier; } + /** + * Replaces the option-derived parameters on this listener at runtime. Used by + * {@code TransportWebSocketFactory} to apply {@code DeepgramTransportFactory.reconnectOptions()} + * without requiring edits to the generated per-resource WebSocket clients. {@code maxEnqueuedMessages} + * is intentionally not overridden — the message queue is sized at construction. + * + *
Thread-safety: option-derived fields are volatile; reads observe the latest write. The + * initial connect() call may have already started before the override lands, so for the very + * first attempt the original options apply; the override takes effect from the next attempt + * onwards. For the SageMaker storm-suppression case ({@code maxRetries(0)}) this is fine + * because the initial attempt's gate ({@code retryCount > maxRetries} with {@code retryCount=0}) + * always passes regardless. + * + * @param options replacement options; {@code null} is a no-op. + */ + public void applyOptionsOverride(ReconnectOptions options) { + if (options == null) { + return; + } + this.minReconnectionDelayMs = options.minReconnectionDelayMs; + this.maxReconnectionDelayMs = options.maxReconnectionDelayMs; + this.reconnectionDelayGrowFactor = options.reconnectionDelayGrowFactor; + this.maxRetries = options.maxRetries; + this.connectionTimeoutMs = options.connectionTimeoutMs; + } + /** * Initiates a WebSocket connection with automatic reconnection enabled. * * Connection behavior: - * - Times out after 4000 milliseconds + * - Times out after {@code ReconnectOptions.connectionTimeoutMs} (default 4000ms) * - Thread-safe via atomic lock (returns immediately if connection in progress) - * - Retry count not incremented for initial connection attempt + * - {@code maxRetries} counts retries only — the initial attempt always proceeds. + * {@code maxRetries(0)} means "connect once, don't retry" (not "refuse to connect"). * * Error handling: * - TimeoutException: Includes retry attempt context @@ -86,18 +119,21 @@ public void connect() { if (!connectLock.compareAndSet(false, true)) { return; } - if (retryCount.get() >= maxRetries) { + // retryCount is incremented inside scheduleReconnect() before re-entering connect(), + // so on the initial call retryCount == 0 and we always proceed. The cap applies to + // retries only — maxRetries(0) blocks retries but allows the initial attempt. + if (retryCount.get() > maxRetries) { connectLock.set(false); return; } try { CompletableFuture extends WebSocket> connectionFuture = CompletableFuture.supplyAsync(connectionSupplier); try { - webSocket = connectionFuture.get(4000, MILLISECONDS); + webSocket = connectionFuture.get(connectionTimeoutMs, MILLISECONDS); } catch (TimeoutException e) { connectionFuture.cancel(true); TimeoutException timeoutError = - new TimeoutException("WebSocket connection timeout after " + 4000 + " milliseconds" + new TimeoutException("WebSocket connection timeout after " + connectionTimeoutMs + " milliseconds" + (retryCount.get() > 0 ? " (retry attempt #" + retryCount.get() : " (initial connection attempt)")); @@ -399,12 +435,15 @@ public static final class ReconnectOptions { public final int maxEnqueuedMessages; + public final long connectionTimeoutMs; + private ReconnectOptions(Builder builder) { this.minReconnectionDelayMs = builder.minReconnectionDelayMs; this.maxReconnectionDelayMs = builder.maxReconnectionDelayMs; this.reconnectionDelayGrowFactor = builder.reconnectionDelayGrowFactor; this.maxRetries = builder.maxRetries; this.maxEnqueuedMessages = builder.maxEnqueuedMessages; + this.connectionTimeoutMs = builder.connectionTimeoutMs; } public static Builder builder() { @@ -422,12 +461,15 @@ public static final class Builder { private int maxEnqueuedMessages; + private long connectionTimeoutMs; + public Builder() { this.minReconnectionDelayMs = 1000; this.maxReconnectionDelayMs = 10000; this.reconnectionDelayGrowFactor = 1.3; this.maxRetries = 2147483647; this.maxEnqueuedMessages = 1000; + this.connectionTimeoutMs = 4000; } public Builder minReconnectionDelayMs(long minReconnectionDelayMs) { @@ -455,6 +497,16 @@ public Builder maxEnqueuedMessages(int maxEnqueuedMessages) { return this; } + /** + * Sets the per-attempt connection timeout in milliseconds. Defaults to {@code 4000}. + * Each call to {@link ReconnectingWebSocketListener#connect()} will wait at most + * this long for the underlying WebSocket factory to produce a connected socket. + */ + public Builder connectionTimeoutMs(long connectionTimeoutMs) { + this.connectionTimeoutMs = connectionTimeoutMs; + return this; + } + /** * Builds the ReconnectOptions with validation. * @@ -463,6 +515,7 @@ public Builder maxEnqueuedMessages(int maxEnqueuedMessages) { * - minReconnectionDelayMs <= maxReconnectionDelayMs * - reconnectionDelayGrowFactor >= 1.0 * - maxRetries and maxEnqueuedMessages are non-negative + * - connectionTimeoutMs is positive * * @return The validated ReconnectOptions instance * @throws IllegalArgumentException if configuration is invalid @@ -487,6 +540,9 @@ public ReconnectOptions build() { if (maxEnqueuedMessages < 0) { throw new IllegalArgumentException("maxEnqueuedMessages must be non-negative"); } + if (connectionTimeoutMs <= 0) { + throw new IllegalArgumentException("connectionTimeoutMs must be positive"); + } return new ReconnectOptions(this); } } diff --git a/src/main/java/com/deepgram/core/transport/DeepgramTransportFactory.java b/src/main/java/com/deepgram/core/transport/DeepgramTransportFactory.java index f42dd6a..54d8f65 100644 --- a/src/main/java/com/deepgram/core/transport/DeepgramTransportFactory.java +++ b/src/main/java/com/deepgram/core/transport/DeepgramTransportFactory.java @@ -1,5 +1,6 @@ package com.deepgram.core.transport; +import com.deepgram.core.ReconnectingWebSocketListener; import java.util.Map; /** @@ -19,7 +20,6 @@ *
When a transport factory is set, all WebSocket clients (Listen, Speak, Agent) will use it
* instead of the default OkHttp WebSocket connection.
*/
-@FunctionalInterface
public interface DeepgramTransportFactory {
/**
@@ -31,4 +31,19 @@ public interface DeepgramTransportFactory {
* @return a connected or connecting transport instance
*/
DeepgramTransport create(String url, Map Plugins that own their own connection lifecycle and retry/backoff (e.g. SageMaker bidi
+ * streaming) should return {@code ReconnectOptions.builder().maxRetries(0).build()} so the
+ * SDK's wrapper-level reconnect doesn't compound their internal retries into a storm.
+ *
+ * @return reconnect options to apply, or {@code null} for SDK defaults
+ */
+ default ReconnectingWebSocketListener.ReconnectOptions reconnectOptions() {
+ return null;
+ }
}
diff --git a/src/main/java/com/deepgram/core/transport/TransportWebSocketFactory.java b/src/main/java/com/deepgram/core/transport/TransportWebSocketFactory.java
index 8df13d1..93c73ad 100644
--- a/src/main/java/com/deepgram/core/transport/TransportWebSocketFactory.java
+++ b/src/main/java/com/deepgram/core/transport/TransportWebSocketFactory.java
@@ -1,5 +1,6 @@
package com.deepgram.core.transport;
+import com.deepgram.core.ReconnectingWebSocketListener;
import com.deepgram.core.WebSocketFactory;
import java.util.LinkedHashMap;
import java.util.Map;
@@ -31,6 +32,13 @@ public TransportWebSocketFactory(DeepgramTransportFactory transportFactory) {
@Override
public WebSocket create(Request request, WebSocketListener listener) {
+ // Apply the plugin-declared reconnect policy to the SDK's wrapping listener. Plugins that
+ // own their own retry/backoff (SageMaker) return maxRetries(0) here so the wrapper-level
+ // reconnect doesn't compound their internal retries into a storm.
+ if (listener instanceof ReconnectingWebSocketListener) {
+ ((ReconnectingWebSocketListener) listener).applyOptionsOverride(transportFactory.reconnectOptions());
+ }
+
String url = request.url().toString();
// Restore wss:// scheme — OkHttp's HttpUrl normalizes to https://
if (url.startsWith("https://")) {
diff --git a/src/test/java/com/deepgram/core/ReconnectingWebSocketListenerTest.java b/src/test/java/com/deepgram/core/ReconnectingWebSocketListenerTest.java
new file mode 100644
index 0000000..da8001e
--- /dev/null
+++ b/src/test/java/com/deepgram/core/ReconnectingWebSocketListenerTest.java
@@ -0,0 +1,190 @@
+package com.deepgram.core;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+import com.deepgram.core.ReconnectingWebSocketListener.ReconnectOptions;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Supplier;
+import okhttp3.Response;
+import okhttp3.WebSocket;
+import okio.ByteString;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Nested;
+import org.junit.jupiter.api.Test;
+
+/** Unit tests for {@link ReconnectingWebSocketListener} bug fixes. */
+class ReconnectingWebSocketListenerTest {
+
+ /**
+ * Test fixture: counts how many times the supplier was invoked, and whether each call
+ * "succeeded" by simulating a fake WebSocket. Failures are simulated by throwing.
+ */
+ private static final class CountingSupplier implements Supplier