@@ -22,6 +22,7 @@ import (
2222 "cdr.dev/slog/sloggers/slogtest"
2323 "github.com/anthropics/anthropic-sdk-go"
2424 "github.com/anthropics/anthropic-sdk-go/packages/ssestream"
25+ "github.com/anthropics/anthropic-sdk-go/shared/constant"
2526 "github.com/coder/aibridge"
2627 "github.com/coder/aibridge/mcp"
2728 "github.com/google/uuid"
@@ -1352,6 +1353,152 @@ func TestStableRequestEncoding(t *testing.T) {
13521353 }
13531354}
13541355
1356+ // TestAnthropicToolChoiceParallelDisabled verifies that parallel tool use is
1357+ // correctly disabled based on the tool_choice parameter in the request.
1358+ // See https://github.com/coder/aibridge/issues/2
1359+ func TestAnthropicToolChoiceParallelDisabled (t * testing.T ) {
1360+ t .Parallel ()
1361+
1362+ var (
1363+ toolChoiceAuto = string (constant .ValueOf [constant.Auto ]())
1364+ toolChoiceAny = string (constant .ValueOf [constant.Any ]())
1365+ toolChoiceNone = string (constant .ValueOf [constant.None ]())
1366+ toolChoiceTool = string (constant .ValueOf [constant.Tool ]())
1367+ )
1368+
1369+ cases := []struct {
1370+ name string
1371+ toolChoice any // nil, or map with "type" key.
1372+ expectDisableParallel bool
1373+ expectToolChoiceTypeInRequest string
1374+ }{
1375+ {
1376+ name : "no tool_choice defined defaults to auto" ,
1377+ toolChoice : nil ,
1378+ expectDisableParallel : true ,
1379+ expectToolChoiceTypeInRequest : toolChoiceAuto ,
1380+ },
1381+ {
1382+ name : "tool_choice auto" ,
1383+ toolChoice : map [string ]any {"type" : toolChoiceAuto },
1384+ expectDisableParallel : true ,
1385+ expectToolChoiceTypeInRequest : toolChoiceAuto ,
1386+ },
1387+ {
1388+ name : "tool_choice any" ,
1389+ toolChoice : map [string ]any {"type" : toolChoiceAny },
1390+ expectDisableParallel : true ,
1391+ expectToolChoiceTypeInRequest : toolChoiceAny ,
1392+ },
1393+ {
1394+ name : "tool_choice tool" ,
1395+ toolChoice : map [string ]any {"type" : toolChoiceTool , "name" : "some_tool" },
1396+ expectDisableParallel : true ,
1397+ expectToolChoiceTypeInRequest : toolChoiceTool ,
1398+ },
1399+ {
1400+ name : "tool_choice none" ,
1401+ toolChoice : map [string ]any {"type" : toolChoiceNone },
1402+ expectDisableParallel : false ,
1403+ expectToolChoiceTypeInRequest : toolChoiceNone ,
1404+ },
1405+ }
1406+
1407+ for _ , tc := range cases {
1408+ t .Run (tc .name , func (t * testing.T ) {
1409+ t .Parallel ()
1410+
1411+ ctx , cancel := context .WithTimeout (t .Context (), time .Second * 30 )
1412+ t .Cleanup (cancel )
1413+
1414+ // Configure the bridge.
1415+ mcpMgr := mcp .NewServerProxyManager (nil , testTracer )
1416+ require .NoError (t , mcpMgr .Init (ctx ))
1417+
1418+ arc := txtar .Parse (antSimple )
1419+ files := filesMap (arc )
1420+ require .Contains (t , files , fixtureRequest )
1421+ require .Contains (t , files , fixtureNonStreamingResponse )
1422+
1423+ // Prepare request body with tool_choice set.
1424+ var reqJSON map [string ]any
1425+ require .NoError (t , json .Unmarshal (files [fixtureRequest ], & reqJSON ))
1426+ if tc .toolChoice != nil {
1427+ reqJSON ["tool_choice" ] = tc .toolChoice
1428+ }
1429+ reqBody , err := json .Marshal (reqJSON )
1430+ require .NoError (t , err )
1431+
1432+ var receivedRequest map [string ]any
1433+
1434+ // Create a mock server that captures the request body sent upstream.
1435+ mockSrv := httptest .NewUnstartedServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
1436+ // Capture the raw request body.
1437+ raw , err := io .ReadAll (r .Body )
1438+ defer r .Body .Close ()
1439+ require .NoError (t , err )
1440+
1441+ require .NoError (t , json .Unmarshal (raw , & receivedRequest ))
1442+
1443+ // Return a valid API response.
1444+ w .Header ().Set ("Content-Type" , "application/json" )
1445+ w .WriteHeader (http .StatusOK )
1446+ _ , _ = w .Write (files [fixtureNonStreamingResponse ])
1447+ }))
1448+ mockSrv .Config .BaseContext = func (_ net.Listener ) context.Context {
1449+ return ctx
1450+ }
1451+ mockSrv .Start ()
1452+ t .Cleanup (mockSrv .Close )
1453+
1454+ recorder := & mockRecorderClient {}
1455+ logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
1456+ providers := []aibridge.Provider {aibridge .NewAnthropicProvider (anthropicCfg (mockSrv .URL , apiKey ), nil )}
1457+ bridge , err := aibridge .NewRequestBridge (ctx , providers , recorder , mcpMgr , logger , nil , testTracer )
1458+ require .NoError (t , err )
1459+
1460+ // Invoke request to mocked API via aibridge.
1461+ bridgeSrv := httptest .NewUnstartedServer (bridge )
1462+ bridgeSrv .Config .BaseContext = func (_ net.Listener ) context.Context {
1463+ return aibridge .AsActor (ctx , userID , nil )
1464+ }
1465+ bridgeSrv .Start ()
1466+ t .Cleanup (bridgeSrv .Close )
1467+
1468+ req := createAnthropicMessagesReq (t , bridgeSrv .URL , reqBody )
1469+ client := & http.Client {}
1470+ resp , err := client .Do (req )
1471+ require .NoError (t , err )
1472+ require .Equal (t , http .StatusOK , resp .StatusCode )
1473+ _ = resp .Body .Close ()
1474+
1475+ // Verify tool_choice in the upstream request.
1476+ require .NotNil (t , receivedRequest )
1477+ toolChoice , ok := receivedRequest ["tool_choice" ].(map [string ]any )
1478+ require .True (t , ok , "expected tool_choice in upstream request" )
1479+
1480+ // Verify the type matches expectation.
1481+ assert .Equal (t , tc .expectToolChoiceTypeInRequest , toolChoice ["type" ])
1482+
1483+ // Verify name is preserved for tool_choice=tool.
1484+ if tc .expectToolChoiceTypeInRequest == toolChoiceTool {
1485+ assert .Equal (t , "some_tool" , toolChoice ["name" ])
1486+ }
1487+
1488+ // Verify disable_parallel_tool_use based on expectations.
1489+ // See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use
1490+ disableParallel , hasDisableParallel := toolChoice ["disable_parallel_tool_use" ].(bool )
1491+
1492+ if tc .expectDisableParallel {
1493+ require .True (t , hasDisableParallel , "expected disable_parallel_tool_use in tool_choice" )
1494+ assert .True (t , disableParallel , "expected disable_parallel_tool_use to be true" )
1495+ } else {
1496+ assert .False (t , hasDisableParallel , "expected disable_parallel_tool_use to not be set" )
1497+ }
1498+ })
1499+ }
1500+ }
1501+
13551502func TestEnvironmentDoNotLeak (t * testing.T ) {
13561503 // NOTE: Cannot use t.Parallel() here because subtests use t.Setenv which requires sequential execution.
13571504
0 commit comments