Skip to content

Commit 9aec90b

Browse files
authored
fix: correctly set disable_parallel_tool_use (#78)
Signed-off-by: Danny Kopping <[email protected]>
1 parent b202549 commit 9aec90b

File tree

2 files changed

+171
-5
lines changed

2 files changed

+171
-5
lines changed

bridge_integration_test.go

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"cdr.dev/slog/sloggers/slogtest"
2323
"github.com/anthropics/anthropic-sdk-go"
2424
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
25+
"github.com/anthropics/anthropic-sdk-go/shared/constant"
2526
"github.com/coder/aibridge"
2627
"github.com/coder/aibridge/mcp"
2728
"github.com/google/uuid"
@@ -1352,6 +1353,152 @@ func TestStableRequestEncoding(t *testing.T) {
13521353
}
13531354
}
13541355

1356+
// TestAnthropicToolChoiceParallelDisabled verifies that parallel tool use is
1357+
// correctly disabled based on the tool_choice parameter in the request.
1358+
// See https://github.com/coder/aibridge/issues/2
1359+
func TestAnthropicToolChoiceParallelDisabled(t *testing.T) {
1360+
t.Parallel()
1361+
1362+
var (
1363+
toolChoiceAuto = string(constant.ValueOf[constant.Auto]())
1364+
toolChoiceAny = string(constant.ValueOf[constant.Any]())
1365+
toolChoiceNone = string(constant.ValueOf[constant.None]())
1366+
toolChoiceTool = string(constant.ValueOf[constant.Tool]())
1367+
)
1368+
1369+
cases := []struct {
1370+
name string
1371+
toolChoice any // nil, or map with "type" key.
1372+
expectDisableParallel bool
1373+
expectToolChoiceTypeInRequest string
1374+
}{
1375+
{
1376+
name: "no tool_choice defined defaults to auto",
1377+
toolChoice: nil,
1378+
expectDisableParallel: true,
1379+
expectToolChoiceTypeInRequest: toolChoiceAuto,
1380+
},
1381+
{
1382+
name: "tool_choice auto",
1383+
toolChoice: map[string]any{"type": toolChoiceAuto},
1384+
expectDisableParallel: true,
1385+
expectToolChoiceTypeInRequest: toolChoiceAuto,
1386+
},
1387+
{
1388+
name: "tool_choice any",
1389+
toolChoice: map[string]any{"type": toolChoiceAny},
1390+
expectDisableParallel: true,
1391+
expectToolChoiceTypeInRequest: toolChoiceAny,
1392+
},
1393+
{
1394+
name: "tool_choice tool",
1395+
toolChoice: map[string]any{"type": toolChoiceTool, "name": "some_tool"},
1396+
expectDisableParallel: true,
1397+
expectToolChoiceTypeInRequest: toolChoiceTool,
1398+
},
1399+
{
1400+
name: "tool_choice none",
1401+
toolChoice: map[string]any{"type": toolChoiceNone},
1402+
expectDisableParallel: false,
1403+
expectToolChoiceTypeInRequest: toolChoiceNone,
1404+
},
1405+
}
1406+
1407+
for _, tc := range cases {
1408+
t.Run(tc.name, func(t *testing.T) {
1409+
t.Parallel()
1410+
1411+
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
1412+
t.Cleanup(cancel)
1413+
1414+
// Configure the bridge.
1415+
mcpMgr := mcp.NewServerProxyManager(nil, testTracer)
1416+
require.NoError(t, mcpMgr.Init(ctx))
1417+
1418+
arc := txtar.Parse(antSimple)
1419+
files := filesMap(arc)
1420+
require.Contains(t, files, fixtureRequest)
1421+
require.Contains(t, files, fixtureNonStreamingResponse)
1422+
1423+
// Prepare request body with tool_choice set.
1424+
var reqJSON map[string]any
1425+
require.NoError(t, json.Unmarshal(files[fixtureRequest], &reqJSON))
1426+
if tc.toolChoice != nil {
1427+
reqJSON["tool_choice"] = tc.toolChoice
1428+
}
1429+
reqBody, err := json.Marshal(reqJSON)
1430+
require.NoError(t, err)
1431+
1432+
var receivedRequest map[string]any
1433+
1434+
// Create a mock server that captures the request body sent upstream.
1435+
mockSrv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1436+
// Capture the raw request body.
1437+
raw, err := io.ReadAll(r.Body)
1438+
defer r.Body.Close()
1439+
require.NoError(t, err)
1440+
1441+
require.NoError(t, json.Unmarshal(raw, &receivedRequest))
1442+
1443+
// Return a valid API response.
1444+
w.Header().Set("Content-Type", "application/json")
1445+
w.WriteHeader(http.StatusOK)
1446+
_, _ = w.Write(files[fixtureNonStreamingResponse])
1447+
}))
1448+
mockSrv.Config.BaseContext = func(_ net.Listener) context.Context {
1449+
return ctx
1450+
}
1451+
mockSrv.Start()
1452+
t.Cleanup(mockSrv.Close)
1453+
1454+
recorder := &mockRecorderClient{}
1455+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
1456+
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(mockSrv.URL, apiKey), nil)}
1457+
bridge, err := aibridge.NewRequestBridge(ctx, providers, recorder, mcpMgr, logger, nil, testTracer)
1458+
require.NoError(t, err)
1459+
1460+
// Invoke request to mocked API via aibridge.
1461+
bridgeSrv := httptest.NewUnstartedServer(bridge)
1462+
bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context {
1463+
return aibridge.AsActor(ctx, userID, nil)
1464+
}
1465+
bridgeSrv.Start()
1466+
t.Cleanup(bridgeSrv.Close)
1467+
1468+
req := createAnthropicMessagesReq(t, bridgeSrv.URL, reqBody)
1469+
client := &http.Client{}
1470+
resp, err := client.Do(req)
1471+
require.NoError(t, err)
1472+
require.Equal(t, http.StatusOK, resp.StatusCode)
1473+
_ = resp.Body.Close()
1474+
1475+
// Verify tool_choice in the upstream request.
1476+
require.NotNil(t, receivedRequest)
1477+
toolChoice, ok := receivedRequest["tool_choice"].(map[string]any)
1478+
require.True(t, ok, "expected tool_choice in upstream request")
1479+
1480+
// Verify the type matches expectation.
1481+
assert.Equal(t, tc.expectToolChoiceTypeInRequest, toolChoice["type"])
1482+
1483+
// Verify name is preserved for tool_choice=tool.
1484+
if tc.expectToolChoiceTypeInRequest == toolChoiceTool {
1485+
assert.Equal(t, "some_tool", toolChoice["name"])
1486+
}
1487+
1488+
// Verify disable_parallel_tool_use based on expectations.
1489+
// See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use
1490+
disableParallel, hasDisableParallel := toolChoice["disable_parallel_tool_use"].(bool)
1491+
1492+
if tc.expectDisableParallel {
1493+
require.True(t, hasDisableParallel, "expected disable_parallel_tool_use in tool_choice")
1494+
assert.True(t, disableParallel, "expected disable_parallel_tool_use to be true")
1495+
} else {
1496+
assert.False(t, hasDisableParallel, "expected disable_parallel_tool_use to not be set")
1497+
}
1498+
})
1499+
}
1500+
}
1501+
13551502
func TestEnvironmentDoNotLeak(t *testing.T) {
13561503
// NOTE: Cannot use t.Parallel() here because subtests use t.Setenv which requires sequential execution.
13571504

intercept_anthropic_messages_base.go

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/anthropics/anthropic-sdk-go"
1313
"github.com/anthropics/anthropic-sdk-go/bedrock"
1414
"github.com/anthropics/anthropic-sdk-go/option"
15+
"github.com/anthropics/anthropic-sdk-go/shared/constant"
1516
"github.com/aws/aws-sdk-go-v2/config"
1617
"github.com/aws/aws-sdk-go-v2/credentials"
1718
"github.com/coder/aibridge/mcp"
@@ -96,11 +97,29 @@ func (i *AnthropicMessagesInterceptionBase) injectTools() {
9697
}
9798

9899
// Note: Parallel tool calls are disabled to avoid tool_use/tool_result block mismatches.
99-
i.req.ToolChoice = anthropic.ToolChoiceUnionParam{
100-
OfAny: &anthropic.ToolChoiceAnyParam{
101-
Type: "auto",
102-
DisableParallelToolUse: anthropic.Bool(true),
103-
},
100+
// https://github.com/coder/aibridge/issues/2
101+
toolChoiceType := i.req.ToolChoice.GetType()
102+
var toolChoiceTypeStr string
103+
if toolChoiceType != nil {
104+
toolChoiceTypeStr = *toolChoiceType
105+
}
106+
107+
switch toolChoiceTypeStr {
108+
// If no tool_choice was defined, assume auto.
109+
// See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use.
110+
case "", string(constant.ValueOf[constant.Auto]()):
111+
// We only set OfAuto if no tool_choice was provided (the default).
112+
// "auto" is the default when a zero value is provided, so we can safely disable parallel checks on it.
113+
if i.req.ToolChoice.OfAuto == nil {
114+
i.req.ToolChoice.OfAuto = &anthropic.ToolChoiceAutoParam{}
115+
}
116+
i.req.ToolChoice.OfAuto.DisableParallelToolUse = anthropic.Bool(true)
117+
case string(constant.ValueOf[constant.Any]()):
118+
i.req.ToolChoice.OfAny.DisableParallelToolUse = anthropic.Bool(true)
119+
case string(constant.ValueOf[constant.Tool]()):
120+
i.req.ToolChoice.OfTool.DisableParallelToolUse = anthropic.Bool(true)
121+
case string(constant.ValueOf[constant.None]()):
122+
// No-op; if tool_choice=none then tools are not used at all.
104123
}
105124
}
106125

0 commit comments

Comments
 (0)