Skip to content

Commit 272e0cd

Browse files
authored
mcp: better handling for streamable context cancellation (#677)
After walking through our handling of streamable client context cancellation (due to encountering a shutdown deadlock), I think I've settled on a more coherent strategy for handling call cancellation: - In our call handler, retire the request if the call exits due to cancellation: the caller will never see the actual result anyway. - In connectSSE, use the actual request context (the same context used in Write) for the client request, so that it terminates when the context is cancelled. Thread through the initialization context for the standalone SSE request. Also, a couple minor improvements: - Use a detached context for the background context of the client connection. We want to preserve context values (see #513), but it is not right to cancel the connection after Connect has already returned, if the context times out. - Don't use Last-Event-ID != "" as the signal for whether the connectSSE call is initial: if the standalone SSE stream disconnects without an event ID, we'll still reconnect it, and don't want to do so without a delay. + tests, updating the streamable client connection test harness to accommodate the new aspects being exercised. Fixes #662
1 parent 8adc1e3 commit 272e0cd

File tree

5 files changed

+284
-68
lines changed

5 files changed

+284
-68
lines changed

internal/jsonrpc2/conn.go

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -361,19 +361,26 @@ func (c *Connection) Call(ctx context.Context, method string, params any) *Async
361361
if err := c.write(ctx, call); err != nil {
362362
// Sending failed. We will never get a response, so deliver a fake one if it
363363
// wasn't already retired by the connection breaking.
364-
c.updateInFlight(func(s *inFlightState) {
365-
if s.outgoingCalls[ac.id] == ac {
366-
delete(s.outgoingCalls, ac.id)
367-
ac.retire(&Response{ID: id, Error: err})
368-
} else {
369-
// ac was already retired by the readIncoming goroutine:
370-
// perhaps our write raced with the Read side of the connection breaking.
371-
}
372-
})
364+
c.Retire(ac, err)
373365
}
374366
return ac
375367
}
376368

369+
// Retire stops tracking the call, and reports err as its terminal error.
370+
//
371+
// Retire is safe to call multiple times: if the call is already no longer
372+
// tracked, Retire is a no op.
373+
func (c *Connection) Retire(ac *AsyncCall, err error) {
374+
c.updateInFlight(func(s *inFlightState) {
375+
if s.outgoingCalls[ac.id] == ac {
376+
delete(s.outgoingCalls, ac.id)
377+
ac.retire(&Response{ID: ac.id, Error: err})
378+
} else {
379+
// ac was already retired elsewhere.
380+
}
381+
})
382+
}
383+
377384
// Async, signals that the current jsonrpc2 request may be handled
378385
// asynchronously to subsequent requests, when ctx is the request context.
379386
//
@@ -437,6 +444,9 @@ func (ac *AsyncCall) IsReady() bool {
437444
}
438445

439446
// retire processes the response to the call.
447+
//
448+
// It is an error to call retire more than once: retire is guarded by the
449+
// connection's outgoingCalls map.
440450
func (ac *AsyncCall) retire(response *Response) {
441451
select {
442452
case <-ac.ready:
@@ -450,6 +460,9 @@ func (ac *AsyncCall) retire(response *Response) {
450460

451461
// Await waits for (and decodes) the results of a Call.
452462
// The response will be unmarshaled from JSON into the result.
463+
//
464+
// If the call is cancelled due to context cancellation, the result is
465+
// ctx.Err().
453466
func (ac *AsyncCall) Await(ctx context.Context, result any) error {
454467
select {
455468
case <-ctx.Done():

mcp/streamable.go

Lines changed: 71 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525

2626
"github.com/modelcontextprotocol/go-sdk/auth"
2727
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
28+
"github.com/modelcontextprotocol/go-sdk/internal/xcontext"
2829
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
2930
)
3031

@@ -1336,12 +1337,17 @@ const (
13361337
// A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
13371338
// It must be 1.0 or greater if MaxRetries is greater than 0.
13381339
reconnectGrowFactor = 1.5
1339-
// reconnectInitialDelay is the base delay for the first reconnect attempt.
1340-
reconnectInitialDelay = 1 * time.Second
13411340
// reconnectMaxDelay caps the backoff delay, preventing it from growing indefinitely.
13421341
reconnectMaxDelay = 30 * time.Second
13431342
)
13441343

1344+
var (
1345+
// reconnectInitialDelay is the base delay for the first reconnect attempt.
1346+
//
1347+
// Mutable for testing.
1348+
reconnectInitialDelay = 1 * time.Second
1349+
)
1350+
13451351
// Connect implements the [Transport] interface.
13461352
//
13471353
// The resulting [Connection] writes messages via POST requests to the
@@ -1364,7 +1370,20 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
13641370
// Create a new cancellable context that will manage the connection's lifecycle.
13651371
// This is crucial for cleanly shutting down the background SSE listener by
13661372
// cancelling its blocking network operations, which prevents hangs on exit.
1367-
connCtx, cancel := context.WithCancel(ctx)
1373+
//
1374+
// This context should be detached from the incoming context: the standalone
1375+
// SSE request should not break when the context passed to Connect is done.
1376+
//
1377+
// For example, consider that the user may want to wait at most 5s to connect
1378+
// to the server, and therefore uses a context with a 5s timeout when calling
1379+
// client.Connect. Let's suppose that Connect returns after 1s, and the user
1380+
// starts using the resulting session. If we didn't detach here, the session
1381+
// would break after 4s, when the background SSE stream is terminated.
1382+
//
1383+
// Instead, creating a cancellable context detached from the incoming context
1384+
// allows us to preserve context values (which may be necessary for auth
1385+
// middleware), yet only cancel the standalone stream when the connection is closed.
1386+
connCtx, cancel := context.WithCancel(xcontext.Detach(ctx))
13681387
conn := &streamableClientConn{
13691388
url: t.Endpoint,
13701389
client: client,
@@ -1383,8 +1402,8 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
13831402
type streamableClientConn struct {
13841403
url string
13851404
client *http.Client
1386-
ctx context.Context
1387-
cancel context.CancelFunc
1405+
ctx context.Context // connection context, detached from Connect
1406+
cancel context.CancelFunc // cancels ctx
13881407
incoming chan jsonrpc.Message
13891408
maxRetries int
13901409
strict bool // from [StreamableClientTransport.strict]
@@ -1447,9 +1466,13 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
14471466
}
14481467

14491468
func (c *streamableClientConn) connectStandaloneSSE() {
1450-
resp, err := c.connectSSE("", 0)
1469+
resp, err := c.connectSSE(c.ctx, "", 0, true)
14511470
if err != nil {
1452-
c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err))
1471+
// If the client didn't cancel the request, any failure breaks the logical
1472+
// session.
1473+
if c.ctx.Err() == nil {
1474+
c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err))
1475+
}
14531476
return
14541477
}
14551478

@@ -1481,7 +1504,7 @@ func (c *streamableClientConn) connectStandaloneSSE() {
14811504
c.fail(err)
14821505
return
14831506
}
1484-
go c.handleSSE(summary, resp, true, nil)
1507+
go c.handleSSE(c.ctx, summary, resp, true, nil)
14851508
}
14861509

14871510
// fail handles an asynchronous error while reading.
@@ -1616,7 +1639,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
16161639
forCall = jsonReq
16171640
}
16181641
// TODO: should we cancel this logical SSE request if/when jsonReq is canceled?
1619-
go c.handleSSE(requestSummary, resp, false, forCall)
1642+
go c.handleSSE(ctx, requestSummary, resp, false, forCall)
16201643

16211644
default:
16221645
resp.Body.Close()
@@ -1668,15 +1691,17 @@ func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Resp
16681691
//
16691692
// If forCall is set, it is the call that initiated the stream, and the
16701693
// stream is complete when we receive its response.
1671-
func (c *streamableClientConn) handleSSE(requestSummary string, resp *http.Response, persistent bool, forCall *jsonrpc2.Request) {
1694+
func (c *streamableClientConn) handleSSE(ctx context.Context, requestSummary string, resp *http.Response, persistent bool, forCall *jsonrpc2.Request) {
16721695
for {
16731696
// Connection was successful. Continue the loop with the new response.
1674-
// TODO: we should set a reasonable limit on the number of times we'll try
1675-
// getting a response for a given request.
1697+
//
1698+
// TODO(#679): we should set a reasonable limit on the number of times
1699+
// we'll try getting a response for a given request, or enforce that we
1700+
// actually make progress.
16761701
//
16771702
// Eventually, if we don't get the response, we should stop trying and
16781703
// fail the request.
1679-
lastEventID, reconnectDelay, clientClosed := c.processStream(requestSummary, resp, forCall)
1704+
lastEventID, reconnectDelay, clientClosed := c.processStream(ctx, requestSummary, resp, forCall)
16801705

16811706
// If the connection was closed by the client, we're done.
16821707
if clientClosed {
@@ -1689,12 +1714,17 @@ func (c *streamableClientConn) handleSSE(requestSummary string, resp *http.Respo
16891714
}
16901715

16911716
// The stream was interrupted or ended by the server. Attempt to reconnect.
1692-
newResp, err := c.connectSSE(lastEventID, reconnectDelay)
1717+
newResp, err := c.connectSSE(ctx, lastEventID, reconnectDelay, false)
16931718
if err != nil {
1694-
// All reconnection attempts failed: fail the connection.
1695-
c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err))
1719+
// If the client didn't cancel this request, any failure to execute it
1720+
// breaks the logical MCP session.
1721+
if ctx.Err() == nil {
1722+
// All reconnection attempts failed: fail the connection.
1723+
c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err))
1724+
}
16961725
return
16971726
}
1727+
16981728
resp = newResp
16991729
if err := c.checkResponse(requestSummary, resp); err != nil {
17001730
c.fail(err)
@@ -1731,11 +1761,13 @@ func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.R
17311761
// incoming channel. It returns the ID of the last processed event and a flag
17321762
// indicating if the connection was closed by the client. If resp is nil, it
17331763
// returns "", false.
1734-
func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, reconnectDelay time.Duration, clientClosed bool) {
1764+
func (c *streamableClientConn) processStream(ctx context.Context, requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, reconnectDelay time.Duration, clientClosed bool) {
17351765
defer resp.Body.Close()
17361766
for evt, err := range scanEvents(resp.Body) {
17371767
if err != nil {
1738-
// TODO: we should differentiate EOF from other errors here.
1768+
if ctx.Err() != nil {
1769+
return "", 0, true // don't reconnect: client cancelled
1770+
}
17391771
break
17401772
}
17411773

@@ -1768,6 +1800,7 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
17681800
return "", 0, true
17691801
}
17701802
}
1803+
17711804
case <-c.done:
17721805
// The connection was closed by the client; exit gracefully.
17731806
return "", 0, true
@@ -1777,6 +1810,9 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
17771810
//
17781811
// If the lastEventID is "", the stream is not retryable and we should
17791812
// report a synthetic error for the call.
1813+
//
1814+
// Note that this is different from the cancellation case above, since the
1815+
// caller is still waiting for a response that will never come.
17801816
if lastEventID == "" && forCall != nil {
17811817
errmsg := &jsonrpc2.Response{
17821818
ID: forCall.ID,
@@ -1800,12 +1836,20 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
18001836
//
18011837
// reconnectDelay is the delay set by the server using the SSE retry field, or
18021838
// 0.
1803-
func (c *streamableClientConn) connectSSE(lastEventID string, reconnectDelay time.Duration) (*http.Response, error) {
1839+
//
1840+
// If initial is set, this is the initial attempt.
1841+
//
1842+
// If connectSSE exits due to context cancellation, the result is (nil, ctx.Err()).
1843+
func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID string, reconnectDelay time.Duration, initial bool) (*http.Response, error) {
18041844
var finalErr error
1805-
// If lastEventID is set, we've already connected successfully once, so
1806-
// consider that to be the first attempt.
18071845
attempt := 0
1808-
if lastEventID != "" {
1846+
if !initial {
1847+
// We've already connected successfully once, so delay subsequent
1848+
// reconnections. Otherwise, if the server returns 200 but terminates the
1849+
// connection, we'll reconnect as fast as we can, ad infinitum.
1850+
//
1851+
// TODO: we should consider also setting a limit on total attempts for one
1852+
// logical request.
18091853
attempt = 1
18101854
}
18111855
delay := calculateReconnectDelay(attempt)
@@ -1816,16 +1860,14 @@ func (c *streamableClientConn) connectSSE(lastEventID string, reconnectDelay tim
18161860
select {
18171861
case <-c.done:
18181862
return nil, fmt.Errorf("connection closed by client during reconnect")
1819-
case <-c.ctx.Done():
1863+
1864+
case <-ctx.Done():
18201865
// If the request context is canceled, the request below will not
18211866
// succeed anyway.
1822-
//
1823-
// TODO(#662): we should not be using the connection context for
1824-
// reconnection: we should instead be using the call context (from
1825-
// Write).
1826-
return nil, fmt.Errorf("connection context closed")
1867+
return nil, ctx.Err()
1868+
18271869
case <-time.After(delay):
1828-
req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil)
1870+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil)
18291871
if err != nil {
18301872
return nil, err
18311873
}

0 commit comments

Comments
 (0)