Skip to content

Commit 67a6b52

Browse files
committed
mcp: better handling for streamable context cancellation
After walking through our handling of streamable client context cancellation (due to encountering a shutdown deadlock), I think I've settled on a more coherent strategy for handling call cancellation: - In our call handler, retire the request if the call exits due to cancellation: the caller will never see the actual result anyway. - In connectSSE, use the actual request context (the same context used in Write) for the client request, so that it terminates when the context is cancelled. Thread through the initialization context for the standalone SSE request. Also, a couple minor improvements: - Use a detached context for the background context of the client connection. We want to preserve context values (see #513), but it is not right to cancel the connection after Connect has already returned, if the context times out. - Don't use Last-Event-ID != "" as the signal for whether the connectSSE call is initial: if the standalone SSE stream disconnects without an event ID, we'll still reconnect it, and don't want to do so without a delay. + tests, updating the streamable client connection test harness to accommodate the new aspects being exercised. Fixes #662
1 parent 6237801 commit 67a6b52

File tree

5 files changed

+268
-66
lines changed

5 files changed

+268
-66
lines changed

internal/jsonrpc2/conn.go

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -361,19 +361,26 @@ func (c *Connection) Call(ctx context.Context, method string, params any) *Async
361361
if err := c.write(ctx, call); err != nil {
362362
// Sending failed. We will never get a response, so deliver a fake one if it
363363
// wasn't already retired by the connection breaking.
364-
c.updateInFlight(func(s *inFlightState) {
365-
if s.outgoingCalls[ac.id] == ac {
366-
delete(s.outgoingCalls, ac.id)
367-
ac.retire(&Response{ID: id, Error: err})
368-
} else {
369-
// ac was already retired by the readIncoming goroutine:
370-
// perhaps our write raced with the Read side of the connection breaking.
371-
}
372-
})
364+
c.Retire(ac, err)
373365
}
374366
return ac
375367
}
376368

369+
// Retire stops tracking the outgoing call ac and delivers err as its response error.
370+
//
371+
// Retire is safe to call multiple times: if ac is no longer in the connection's
372+
// outgoingCalls map, Retire is a no-op.
373+
func (c *Connection) Retire(ac *AsyncCall, err error) {
374+
c.updateInFlight(func(s *inFlightState) {
375+
if s.outgoingCalls[ac.id] == ac {
376+
delete(s.outgoingCalls, ac.id)
377+
ac.retire(&Response{ID: ac.id, Error: err})
378+
} else {
379+
// ac is no longer tracked: it was already retired elsewhere, so do nothing.
380+
}
381+
})
382+
}
383+
377384
// Async signals that the current jsonrpc2 request may be handled
378385
// asynchronously to subsequent requests, when ctx is the request context.
379386
//
@@ -437,6 +444,9 @@ func (ac *AsyncCall) IsReady() bool {
437444
}
438445

439446
// retire processes the response to the call.
447+
//
448+
// It is an error to call retire more than once: retire is guarded by the
449+
// connection's outgoingCalls map.
440450
func (ac *AsyncCall) retire(response *Response) {
441451
select {
442452
case <-ac.ready:
@@ -450,6 +460,9 @@ func (ac *AsyncCall) retire(response *Response) {
450460

451461
// Await waits for (and decodes) the results of a Call.
452462
// The response will be unmarshaled from JSON into the result.
463+
//
464+
// If the call is cancelled due to context cancellation, the result is
465+
// ctx.Err().
453466
func (ac *AsyncCall) Await(ctx context.Context, result any) error {
454467
select {
455468
case <-ctx.Done():

mcp/streamable.go

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525

2626
"github.com/modelcontextprotocol/go-sdk/auth"
2727
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
28+
"github.com/modelcontextprotocol/go-sdk/internal/xcontext"
2829
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
2930
)
3031

@@ -1336,12 +1337,17 @@ const (
13361337
// A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
13371338
// It must be 1.0 or greater if MaxRetries is greater than 0.
13381339
reconnectGrowFactor = 1.5
1339-
// reconnectInitialDelay is the base delay for the first reconnect attempt.
1340-
reconnectInitialDelay = 1 * time.Second
13411340
// reconnectMaxDelay caps the backoff delay, preventing it from growing indefinitely.
13421341
reconnectMaxDelay = 30 * time.Second
13431342
)
13441343

1344+
var (
1345+
// reconnectInitialDelay is the base delay for the first reconnect attempt.
1346+
//
1347+
// It is a variable rather than a constant so that tests can override it.
1348+
reconnectInitialDelay = 1 * time.Second
1349+
)
1350+
13451351
// Connect implements the [Transport] interface.
13461352
//
13471353
// The resulting [Connection] writes messages via POST requests to the
@@ -1364,7 +1370,10 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
13641370
// Create a new cancellable context that will manage the connection's lifecycle.
13651371
// This is crucial for cleanly shutting down the background SSE listener by
13661372
// cancelling its blocking network operations, which prevents hangs on exit.
1367-
connCtx, cancel := context.WithCancel(ctx)
1373+
//
1374+
// This context should be detached, to decouple the standalone SSE from the
1375+
// call to Connect.
1376+
connCtx, cancel := context.WithCancel(xcontext.Detach(ctx))
13681377
conn := &streamableClientConn{
13691378
url: t.Endpoint,
13701379
client: client,
@@ -1383,8 +1392,8 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
13831392
type streamableClientConn struct {
13841393
url string
13851394
client *http.Client
1386-
ctx context.Context
1387-
cancel context.CancelFunc
1395+
ctx context.Context // connection context, detached from Connect
1396+
cancel context.CancelFunc // cancels ctx
13881397
incoming chan jsonrpc.Message
13891398
maxRetries int
13901399
strict bool // from [StreamableClientTransport.strict]
@@ -1447,9 +1456,13 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
14471456
}
14481457

14491458
func (c *streamableClientConn) connectStandaloneSSE() {
1450-
resp, err := c.connectSSE("", 0)
1459+
resp, err := c.connectSSE(c.ctx, "", 0, true)
14511460
if err != nil {
1452-
c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err))
1461+
// If the client didn't cancel the request, any failure breaks the logical
1462+
// session.
1463+
if c.ctx.Err() == nil {
1464+
c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err))
1465+
}
14531466
return
14541467
}
14551468

@@ -1481,7 +1494,7 @@ func (c *streamableClientConn) connectStandaloneSSE() {
14811494
c.fail(err)
14821495
return
14831496
}
1484-
go c.handleSSE(summary, resp, true, nil)
1497+
go c.handleSSE(c.ctx, summary, resp, true, nil)
14851498
}
14861499

14871500
// fail handles an asynchronous error while reading.
@@ -1616,7 +1629,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
16161629
forCall = jsonReq
16171630
}
16181631
// TODO: should we cancel this logical SSE request if/when jsonReq is canceled?
1619-
go c.handleSSE(requestSummary, resp, false, forCall)
1632+
go c.handleSSE(ctx, requestSummary, resp, false, forCall)
16201633

16211634
default:
16221635
resp.Body.Close()
@@ -1668,15 +1681,15 @@ func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Resp
16681681
//
16691682
// If forCall is set, it is the call that initiated the stream, and the
16701683
// stream is complete when we receive its response.
1671-
func (c *streamableClientConn) handleSSE(requestSummary string, resp *http.Response, persistent bool, forCall *jsonrpc2.Request) {
1684+
func (c *streamableClientConn) handleSSE(ctx context.Context, requestSummary string, resp *http.Response, persistent bool, forCall *jsonrpc2.Request) {
16721685
for {
16731686
// Connection was successful. Continue the loop with the new response.
16741687
// TODO: we should set a reasonable limit on the number of times we'll try
16751688
// getting a response for a given request.
16761689
//
16771690
// Eventually, if we don't get the response, we should stop trying and
16781691
// fail the request.
1679-
lastEventID, reconnectDelay, clientClosed := c.processStream(requestSummary, resp, forCall)
1692+
lastEventID, reconnectDelay, clientClosed := c.processStream(ctx, requestSummary, resp, forCall)
16801693

16811694
// If the connection was closed by the client, we're done.
16821695
if clientClosed {
@@ -1689,12 +1702,17 @@ func (c *streamableClientConn) handleSSE(requestSummary string, resp *http.Respo
16891702
}
16901703

16911704
// The stream was interrupted or ended by the server. Attempt to reconnect.
1692-
newResp, err := c.connectSSE(lastEventID, reconnectDelay)
1705+
newResp, err := c.connectSSE(ctx, lastEventID, reconnectDelay, false)
16931706
if err != nil {
1694-
// All reconnection attempts failed: fail the connection.
1695-
c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err))
1707+
// If the client didn't cancel this request, any failure to execute it
1708+
// breaks the logical MCP session.
1709+
if ctx.Err() == nil {
1710+
// All reconnection attempts failed: fail the connection.
1711+
c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err))
1712+
}
16961713
return
16971714
}
1715+
16981716
resp = newResp
16991717
if err := c.checkResponse(requestSummary, resp); err != nil {
17001718
c.fail(err)
@@ -1731,11 +1749,13 @@ func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.R
17311749
// incoming channel. It returns the ID of the last processed event and a flag
17321750
// indicating if the connection was closed by the client. If resp is nil, it
17331751
// returns "", false.
1734-
func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, reconnectDelay time.Duration, clientClosed bool) {
1752+
func (c *streamableClientConn) processStream(ctx context.Context, requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, reconnectDelay time.Duration, clientClosed bool) {
17351753
defer resp.Body.Close()
17361754
for evt, err := range scanEvents(resp.Body) {
17371755
if err != nil {
1738-
// TODO: we should differentiate EOF from other errors here.
1756+
if ctx.Err() != nil {
1757+
return "", 0, true // don't reconnect: client cancelled
1758+
}
17391759
break
17401760
}
17411761

@@ -1768,6 +1788,7 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
17681788
return "", 0, true
17691789
}
17701790
}
1791+
17711792
case <-c.done:
17721793
// The connection was closed by the client; exit gracefully.
17731794
return "", 0, true
@@ -1777,6 +1798,9 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
17771798
//
17781799
// If the lastEventID is "", the stream is not retryable and we should
17791800
// report a synthetic error for the call.
1801+
//
1802+
// Note that this is different from the cancellation case above, since the
1803+
// caller is still waiting for a response that will never come.
17801804
if lastEventID == "" && forCall != nil {
17811805
errmsg := &jsonrpc2.Response{
17821806
ID: forCall.ID,
@@ -1800,12 +1824,20 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
18001824
//
18011825
// reconnectDelay is the delay set by the server using the SSE retry field, or
18021826
// 0.
1803-
func (c *streamableClientConn) connectSSE(lastEventID string, reconnectDelay time.Duration) (*http.Response, error) {
1827+
//
1828+
// If initial is set, this is the initial attempt.
1829+
//
1830+
// If connectSSE exits due to context cancellation, the result is (nil, ctx.Err()).
1831+
func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID string, reconnectDelay time.Duration, initial bool) (*http.Response, error) {
18041832
var finalErr error
1805-
// If lastEventID is set, we've already connected successfully once, so
1806-
// consider that to be the first attempt.
18071833
attempt := 0
1808-
if lastEventID != "" {
1834+
if !initial {
1835+
// We've already connected successfully once, so delay subsequent
1836+
// reconnections. Otherwise, if the server returns 200 but terminates the
1837+
// connection, we'll reconnect as fast as we can, ad infinitum.
1838+
//
1839+
// TODO: we should consider also setting a limit on total attempts for one
1840+
// logical request.
18091841
attempt = 1
18101842
}
18111843
delay := calculateReconnectDelay(attempt)
@@ -1816,16 +1848,14 @@ func (c *streamableClientConn) connectSSE(lastEventID string, reconnectDelay tim
18161848
select {
18171849
case <-c.done:
18181850
return nil, fmt.Errorf("connection closed by client during reconnect")
1819-
case <-c.ctx.Done():
1851+
1852+
case <-ctx.Done():
18201853
// If the context is canceled, the request below will not
18211854
// succeed anyway.
1822-
//
1823-
// TODO(#662): we should not be using the connection context for
1824-
// reconnection: we should instead be using the call context (from
1825-
// Write).
1826-
return nil, fmt.Errorf("connection context closed")
1855+
return nil, ctx.Err()
1856+
18271857
case <-time.After(delay):
1828-
req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil)
1858+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil)
18291859
if err != nil {
18301860
return nil, err
18311861
}

0 commit comments

Comments
 (0)