Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
"github.com/anthropics/anthropic-sdk-go/shared/constant"
"github.com/coder/aibridge"
"github.com/coder/aibridge/mcp"
"github.com/google/uuid"
Expand Down Expand Up @@ -1352,6 +1353,152 @@ func TestStableRequestEncoding(t *testing.T) {
}
}

// TestAnthropicToolChoiceParallelDisabled verifies that parallel tool use is
// correctly disabled based on the tool_choice parameter in the request.
// See https://github.com/coder/aibridge/issues/2
func TestAnthropicToolChoiceParallelDisabled(t *testing.T) {
t.Parallel()

var (
toolChoiceAuto = string(constant.ValueOf[constant.Auto]())
toolChoiceAny = string(constant.ValueOf[constant.Any]())
toolChoiceNone = string(constant.ValueOf[constant.None]())
toolChoiceTool = string(constant.ValueOf[constant.Tool]())
)

cases := []struct {
name string
toolChoice any // nil, or map with "type" key.
expectDisableParallel bool
expectToolChoiceTypeInRequest string
}{
{
name: "no tool_choice defined defaults to auto",
toolChoice: nil,
expectDisableParallel: true,
expectToolChoiceTypeInRequest: toolChoiceAuto,
},
{
name: "tool_choice auto",
toolChoice: map[string]any{"type": toolChoiceAuto},
expectDisableParallel: true,
expectToolChoiceTypeInRequest: toolChoiceAuto,
},
{
name: "tool_choice any",
toolChoice: map[string]any{"type": toolChoiceAny},
expectDisableParallel: true,
expectToolChoiceTypeInRequest: toolChoiceAny,
},
{
name: "tool_choice tool",
toolChoice: map[string]any{"type": toolChoiceTool, "name": "some_tool"},
expectDisableParallel: true,
expectToolChoiceTypeInRequest: toolChoiceTool,
},
{
name: "tool_choice none",
toolChoice: map[string]any{"type": toolChoiceNone},
expectDisableParallel: false,
expectToolChoiceTypeInRequest: toolChoiceNone,
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)

// Configure the bridge.
mcpMgr := mcp.NewServerProxyManager(nil, testTracer)
require.NoError(t, mcpMgr.Init(ctx))

arc := txtar.Parse(antSimple)
files := filesMap(arc)
require.Contains(t, files, fixtureRequest)
require.Contains(t, files, fixtureNonStreamingResponse)

// Prepare request body with tool_choice set.
var reqJSON map[string]any
require.NoError(t, json.Unmarshal(files[fixtureRequest], &reqJSON))
if tc.toolChoice != nil {
reqJSON["tool_choice"] = tc.toolChoice
}
reqBody, err := json.Marshal(reqJSON)
require.NoError(t, err)

var receivedRequest map[string]any

// Create a mock server that captures the request body sent upstream.
mockSrv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Capture the raw request body.
raw, err := io.ReadAll(r.Body)
defer r.Body.Close()
require.NoError(t, err)

require.NoError(t, json.Unmarshal(raw, &receivedRequest))

// Return a valid API response.
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(files[fixtureNonStreamingResponse])
}))
mockSrv.Config.BaseContext = func(_ net.Listener) context.Context {
return ctx
}
mockSrv.Start()
t.Cleanup(mockSrv.Close)

recorder := &mockRecorderClient{}
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(mockSrv.URL, apiKey), nil)}
bridge, err := aibridge.NewRequestBridge(ctx, providers, recorder, mcpMgr, logger, nil, testTracer)
require.NoError(t, err)

// Invoke request to mocked API via aibridge.
bridgeSrv := httptest.NewUnstartedServer(bridge)
bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context {
return aibridge.AsActor(ctx, userID, nil)
}
bridgeSrv.Start()
t.Cleanup(bridgeSrv.Close)

req := createAnthropicMessagesReq(t, bridgeSrv.URL, reqBody)
client := &http.Client{}
resp, err := client.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
_ = resp.Body.Close()

// Verify tool_choice in the upstream request.
require.NotNil(t, receivedRequest)
toolChoice, ok := receivedRequest["tool_choice"].(map[string]any)
require.True(t, ok, "expected tool_choice in upstream request")

// Verify the type matches expectation.
assert.Equal(t, tc.expectToolChoiceTypeInRequest, toolChoice["type"])

// Verify name is preserved for tool_choice=tool.
if tc.expectToolChoiceTypeInRequest == toolChoiceTool {
assert.Equal(t, "some_tool", toolChoice["name"])
}

// Verify disable_parallel_tool_use based on expectations.
// See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use
disableParallel, hasDisableParallel := toolChoice["disable_parallel_tool_use"].(bool)

if tc.expectDisableParallel {
require.True(t, hasDisableParallel, "expected disable_parallel_tool_use in tool_choice")
assert.True(t, disableParallel, "expected disable_parallel_tool_use to be true")
} else {
assert.False(t, hasDisableParallel, "expected disable_parallel_tool_use to not be set")
}
})
}
}

func TestEnvironmentDoNotLeak(t *testing.T) {
// NOTE: Cannot use t.Parallel() here because subtests use t.Setenv which requires sequential execution.

Expand Down
29 changes: 24 additions & 5 deletions intercept_anthropic_messages_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/bedrock"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/anthropics/anthropic-sdk-go/shared/constant"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/coder/aibridge/mcp"
Expand Down Expand Up @@ -96,11 +97,29 @@ func (i *AnthropicMessagesInterceptionBase) injectTools() {
}

// Note: Parallel tool calls are disabled to avoid tool_use/tool_result block mismatches.
i.req.ToolChoice = anthropic.ToolChoiceUnionParam{
OfAny: &anthropic.ToolChoiceAnyParam{
Type: "auto",
DisableParallelToolUse: anthropic.Bool(true),
},
// https://github.com/coder/aibridge/issues/2
toolChoiceType := i.req.ToolChoice.GetType()
var toolChoiceTypeStr string
if toolChoiceType != nil {
toolChoiceTypeStr = *toolChoiceType
}

switch toolChoiceTypeStr {
// If no tool_choice was defined, assume auto.
// See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use.
case "", string(constant.ValueOf[constant.Auto]()):
// We only set OfAuto if no tool_choice was provided (the default).
// "auto" is the default when a zero value is provided, so we can safely disable parallel checks on it.
if i.req.ToolChoice.OfAuto == nil {
i.req.ToolChoice.OfAuto = &anthropic.ToolChoiceAutoParam{}
}
i.req.ToolChoice.OfAuto.DisableParallelToolUse = anthropic.Bool(true)
case string(constant.ValueOf[constant.Any]()):
i.req.ToolChoice.OfAny.DisableParallelToolUse = anthropic.Bool(true)
case string(constant.ValueOf[constant.Tool]()):
i.req.ToolChoice.OfTool.DisableParallelToolUse = anthropic.Bool(true)
case string(constant.ValueOf[constant.None]()):
// No-op; if tool_choice=none then tools are not used at all.
}
}

Expand Down