Skip to content

Commit fef3bdb

Browse files
committed
refactor: improve error handling and cleanup HTTP client transports
- Add proper exception handling with CompletableFuture.exceptionallyCompose for async HTTP operations - Add test for specific exception type handling in resiliency tests This change makes the HTTP client transports more robust by ensuring exceptions are properly propagated. Signed-off-by: Christian Tzolov <[email protected]>
1 parent ae387a6 commit fef3bdb

File tree

3 files changed

+100
-122
lines changed

3 files changed

+100
-122
lines changed

mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import java.net.http.HttpRequest;
1010
import java.net.http.HttpResponse;
1111
import java.time.Duration;
12+
import java.util.concurrent.CompletableFuture;
1213
import java.util.concurrent.atomic.AtomicReference;
1314
import java.util.function.Consumer;
1415
import java.util.function.Function;
@@ -332,10 +333,13 @@ public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> h
332333
.GET()
333334
.build();
334335

335-
Disposable connection = Flux
336-
.<ResponseEvent>create(sseSink -> this.httpClient.sendAsync(request,
337-
responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)))
338-
.flatMap(responseEvent -> {
336+
Disposable connection = Flux.<ResponseEvent>create(sseSink -> this.httpClient
337+
.sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink))
338+
.exceptionallyCompose(e -> {
339+
logger.warn("Error sending message", e);
340+
sseSink.error(e);
341+
return CompletableFuture.failedFuture(e);
342+
})).flatMap(responseEvent -> {
339343
if (isClosing) {
340344
return Mono.empty();
341345
}
@@ -375,24 +379,19 @@ else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) {
375379
return Flux.<McpSchema.JSONRPCMessage>error(
376380
new RuntimeException("Failed to send message: " + responseEvent));
377381

378-
})
379-
.flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage)))
380-
.onErrorResume(t -> {
382+
}).flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))).onErrorResume(t -> {
381383
if (!isClosing) {
382384
logger.error("SSE connection error", t);
383385
sink.error(t);
384386
}
385387
return Mono.empty();
386388

387-
})
388-
.doFinally(s -> {
389+
}).doFinally(s -> {
389390
Disposable ref = this.sseSubscription.getAndSet(null);
390391
if (ref != null && !ref.isDisposed()) {
391392
ref.dispose();
392393
}
393-
})
394-
.contextWrite(sink.contextView())
395-
.subscribe();
394+
}).contextWrite(sink.contextView()).subscribe();
396395

397396
this.sseSubscription.set(connection);
398397
});
@@ -460,7 +459,11 @@ private Mono<HttpResponse<Void>> sendHttpPost(final String endpoint, final Strin
460459
.POST(HttpRequest.BodyPublishers.ofString(body))
461460
.build();
462461

463-
return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()));
462+
return Mono.fromFuture(
463+
httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).exceptionallyCompose(e -> {
464+
logger.warn("Error sending message", e);
465+
return CompletableFuture.failedFuture(e);
466+
}));
464467
}
465468

466469
/**

mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java

Lines changed: 65 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import java.time.Duration;
1414
import java.util.List;
1515
import java.util.Optional;
16+
import java.util.concurrent.CompletableFuture;
1617
import java.util.concurrent.atomic.AtomicReference;
1718
import java.util.function.Consumer;
1819
import java.util.function.Function;
@@ -160,9 +161,12 @@ private Publisher<Void> createDelete(String sessionId) {
160161
.DELETE()
161162
.build();
162163

163-
return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()))
164-
.doOnError(e -> logger.warn("Got error when closing transport", e))
165-
.then();
164+
return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())
165+
.exceptionallyCompose(e -> {
166+
logger.warn("Error sending message", e);
167+
168+
return CompletableFuture.failedFuture(e);
169+
})).doOnError(e -> logger.warn("Got error when closing transport", e)).then();
166170
});
167171
}
168172

@@ -227,86 +231,63 @@ private Mono<Disposable> reconnect(McpTransportStream<Disposable> stream) {
227231
.GET()
228232
.build();
229233

230-
Disposable connection = Flux.<ResponseEvent>create(sseSink -> this.httpClient.sendAsync(request,
231-
responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink))
232-
// .whenComplete((response, throwable) -> {
233-
// if (throwable != null) {
234-
// sseSink.error(throwable);
235-
// } else {
236-
// int status = response.statusCode();
237-
// if (status == METHOD_NOT_ALLOWED) { // NotAllowed
238-
// logger.debug("The server does not support SSE streams, using
239-
// request-response mode.");
240-
// sseSink.complete();
241-
// } else if (status == NOT_FOUND || status == BAD_REQUEST) { // NotFound
242-
// String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession);
243-
// sseSink.error(new McpTransportSessionNotFoundException(
244-
// "Session not found for session ID: " + sessionIdRepresentation));
245-
// } else if (!isEventStream(response)) {
246-
// String message = "Failed to connect to SSE stream. HTTP " +
247-
// response.statusCode();
248-
// if (response.body() != null) {
249-
// message += ": " + response.body();
250-
// }
251-
// logger.info("Opening an SSE stream failed. This can be safely ignored." +
252-
// message);
253-
// sseSink.error(new RuntimeException(message));
254-
// }
255-
// // If status is OK, the lineSubscriber will handle the
256-
// // stream
257-
// logger.debug("Established SSE stream via GET");
258-
// }
259-
// })
260-
).flatMap(responseEvent -> {
261-
int statusCode = responseEvent.responseInfo().statusCode();
262-
263-
if (statusCode >= 200 && statusCode < 300) {
264-
265-
if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) {
266-
try {
267-
// We don't support batching ATM and probably won't since the
268-
// next version considers removing it.
269-
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper,
270-
responseEvent.sseEvent().data());
271-
272-
Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>> idWithMessages = Tuples
273-
.of(Optional.ofNullable(responseEvent.sseEvent().id()), List.of(message));
274-
275-
McpTransportStream<Disposable> sessionStream = stream != null ? stream
276-
: new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect);
277-
logger.debug("Connected stream {}", sessionStream.streamId());
278-
279-
return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages)));
280-
281-
}
282-
catch (IOException ioException) {
283-
return Flux.<McpSchema.JSONRPCMessage>error(
284-
new McpError("Error parsing JSON-RPC message: " + responseEvent.sseEvent().data()));
234+
Disposable connection = Flux.<ResponseEvent>create(sseSink -> this.httpClient
235+
.sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink))
236+
.exceptionallyCompose(e -> {
237+
logger.warn("Error sending message", e);
238+
sseSink.error(e);
239+
return CompletableFuture.failedFuture(e);
240+
})).flatMap(responseEvent -> {
241+
int statusCode = responseEvent.responseInfo().statusCode();
242+
243+
if (statusCode >= 200 && statusCode < 300) {
244+
245+
if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) {
246+
try {
247+
// We don't support batching ATM and probably won't since
248+
// the
249+
// next version considers removing it.
250+
McpSchema.JSONRPCMessage message = McpSchema
251+
.deserializeJsonRpcMessage(this.objectMapper, responseEvent.sseEvent().data());
252+
253+
Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>> idWithMessages = Tuples
254+
.of(Optional.ofNullable(responseEvent.sseEvent().id()), List.of(message));
255+
256+
McpTransportStream<Disposable> sessionStream = stream != null ? stream
257+
: new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect);
258+
logger.debug("Connected stream {}", sessionStream.streamId());
259+
260+
return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages)));
261+
262+
}
263+
catch (IOException ioException) {
264+
return Flux.<McpSchema.JSONRPCMessage>error(new McpError(
265+
"Error parsing JSON-RPC message: " + responseEvent.sseEvent().data()));
266+
}
285267
}
286268
}
287-
}
288-
else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed
289-
logger.debug("The server does not support SSE streams, using request-response mode.");
290-
return Flux.empty();
291-
}
292-
else if (statusCode == NOT_FOUND) {
293-
String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession);
294-
McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException(
295-
"Session not found for session ID: " + sessionIdRepresentation);
296-
return Flux.<McpSchema.JSONRPCMessage>error(exception);
297-
}
298-
else if (statusCode == BAD_REQUEST) {
299-
String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession);
300-
McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException(
301-
"Session not found for session ID: " + sessionIdRepresentation);
302-
return Flux.<McpSchema.JSONRPCMessage>error(exception);
303-
}
269+
else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed
270+
logger.debug("The server does not support SSE streams, using request-response mode.");
271+
return Flux.empty();
272+
}
273+
else if (statusCode == NOT_FOUND) {
274+
String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession);
275+
McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException(
276+
"Session not found for session ID: " + sessionIdRepresentation);
277+
return Flux.<McpSchema.JSONRPCMessage>error(exception);
278+
}
279+
else if (statusCode == BAD_REQUEST) {
280+
String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession);
281+
McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException(
282+
"Session not found for session ID: " + sessionIdRepresentation);
283+
return Flux.<McpSchema.JSONRPCMessage>error(exception);
284+
}
304285

305-
return Flux.<McpSchema.JSONRPCMessage>error(
306-
new McpError("Received unrecognized SSE event type: " + responseEvent.sseEvent().event()));
286+
return Flux.<McpSchema.JSONRPCMessage>error(
287+
new McpError("Received unrecognized SSE event type: " + responseEvent.sseEvent().event()));
307288

308-
}).<McpSchema
309-
.JSONRPCMessage>flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage)))
289+
}).<McpSchema
290+
.JSONRPCMessage>flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage)))
310291
.onErrorComplete(t -> {
311292
this.handleException(t);
312293
return true;
@@ -327,13 +308,6 @@ else if (statusCode == BAD_REQUEST) {
327308

328309
}
329310

330-
// private static boolean isEventStream(HttpResponse<Void> response) {
331-
// String contentType =
332-
// response.headers().firstValue("Content-Type").orElse("").toLowerCase();
333-
// return response.statusCode() >= 200 && response.statusCode() < 300 &&
334-
// contentType.contains(TEXT_EVENT_STREAM);
335-
// }
336-
337311
private BodyHandler<Void> toSendMessageBodySubscriber(FluxSink<ResponseEvent> sink) {
338312

339313
BodyHandler<Void> responseBodyHandler = responseInfo -> {
@@ -395,29 +369,11 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage sendMessage) {
395369

396370
// Create the async request with proper body subscriber selection
397371
Mono.fromFuture(this.httpClient.sendAsync(request, this.toSendMessageBodySubscriber(responseEventSink))
398-
// .whenComplete((res, e) -> {
399-
// if (e != null) {
400-
// logger.warn("Error sending message", e);
401-
// responseEventSink.error(e);
402-
// } else if (res.statusCode() == NOT_FOUND) {
403-
// String sessionIdRepresentation =
404-
// sessionIdOrPlaceholder(transportSession);
405-
// McpTransportSessionNotFoundException exception = new
406-
// McpTransportSessionNotFoundException(
407-
// "Session not found for session ID: " + sessionIdRepresentation);
408-
// this.handleException(exception);
409-
// responseEventSink.error(exception);
410-
// } else if (res.statusCode() == BAD_REQUEST) {
411-
// System.out.println("BAD_REQUEST");
412-
// } else {
413-
// logger.debug("whenComplete complete: resp: {}, reqBode: {}", request,
414-
// jsonBody);
415-
// }
416-
// })).doOnSubscribe(sub -> {
417-
// logger.debug("OnSubscribe: {}, Sending message to server: {}", sub,
418-
// jsonBody);
419-
// }
420-
).subscribe();
372+
.exceptionallyCompose(e -> {
373+
logger.warn("Error sending message", e);
374+
responseEventSink.error(e);
375+
return CompletableFuture.failedFuture(e);
376+
})).subscribe();
421377

422378
}).flatMap(responseEvent -> {
423379
if (transportSession.markInitialized(

mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44

55
package io.modelcontextprotocol.client;
66

7+
import java.util.concurrent.CompletionException;
8+
9+
import org.junit.jupiter.api.Test;
710
import org.junit.jupiter.api.Timeout;
811

912
import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
1013
import io.modelcontextprotocol.spec.McpClientTransport;
14+
import reactor.test.StepVerifier;
1115

1216
@Timeout(15)
1317
public class HttpClientStreamableHttpAsyncClientResiliencyTests extends AbstractMcpAsyncClientResiliencyTests {
@@ -17,4 +21,19 @@ protected McpClientTransport createMcpTransport() {
1721
return HttpClientStreamableHttpTransport.builder(host).build();
1822
}
1923

24+
@Test
25+
void testPingWithEaxctExceptionType() {
26+
withClient(createMcpTransport(), mcpAsyncClient -> {
27+
StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete();
28+
29+
disconnect();
30+
31+
StepVerifier.create(mcpAsyncClient.ping()).expectError(CompletionException.class).verify();
32+
33+
reconnect();
34+
35+
StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete();
36+
});
37+
}
38+
2039
}

0 commit comments

Comments
 (0)