Skip to content

Commit 2f9dd49

Browse files
authored
fix: add thread-safe accessor methods to mockRecorderClient (#80)
Fixes #59
1 parent 9aec90b commit 2f9dd49

File tree

1 file changed

+86
-37
lines changed

1 file changed

+86
-37
lines changed

bridge_integration_test.go

Lines changed: 86 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -171,21 +171,24 @@ func TestAnthropicMessages(t *testing.T) {
171171
// One for message_start, one for message_delta.
172172
expectedTokenRecordings = 2
173173
}
174-
require.Len(t, recorderClient.tokenUsages, expectedTokenRecordings)
174+
tokenUsages := recorderClient.RecordedTokenUsages()
175+
require.Len(t, tokenUsages, expectedTokenRecordings)
175176

176-
assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(recorderClient.tokenUsages), "input tokens miscalculated")
177-
assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(recorderClient.tokenUsages), "output tokens miscalculated")
177+
assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated")
178+
assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated")
178179

179-
require.Len(t, recorderClient.toolUsages, 1)
180-
assert.Equal(t, "Read", recorderClient.toolUsages[0].Tool)
181-
require.IsType(t, json.RawMessage{}, recorderClient.toolUsages[0].Args)
180+
toolUsages := recorderClient.RecordedToolUsages()
181+
require.Len(t, toolUsages, 1)
182+
assert.Equal(t, "Read", toolUsages[0].Tool)
183+
require.IsType(t, json.RawMessage{}, toolUsages[0].Args)
182184
var args map[string]any
183-
require.NoError(t, json.Unmarshal(recorderClient.toolUsages[0].Args.(json.RawMessage), &args))
185+
require.NoError(t, json.Unmarshal(toolUsages[0].Args.(json.RawMessage), &args))
184186
require.Contains(t, args, "file_path")
185187
assert.Equal(t, "/tmp/blah/foo", args["file_path"])
186188

187-
require.Len(t, recorderClient.userPrompts, 1)
188-
assert.Equal(t, "read the foo file", recorderClient.userPrompts[0].Prompt)
189+
promptUsages := recorderClient.RecordedPromptUsages()
190+
require.Len(t, promptUsages, 1)
191+
assert.Equal(t, "read the foo file", promptUsages[0].Prompt)
189192

190193
recorderClient.verifyAllInterceptionsEnded(t)
191194
})
@@ -346,8 +349,9 @@ func TestAWSBedrockIntegration(t *testing.T) {
346349
// and the interception data.
347350
require.Equal(t, requestCount, 1)
348351
require.Equal(t, bedrockCfg.Model, receivedModelName)
349-
require.Len(t, recorderClient.interceptions, 1)
350-
require.Equal(t, recorderClient.interceptions[0].Model, bedrockCfg.Model)
352+
interceptions := recorderClient.RecordedInterceptions()
353+
require.Len(t, interceptions, 1)
354+
require.Equal(t, interceptions[0].Model, bedrockCfg.Model)
351355
recorderClient.verifyAllInterceptionsEnded(t)
352356
})
353357
}
@@ -437,18 +441,21 @@ func TestOpenAIChatCompletions(t *testing.T) {
437441
assert.Equal(t, "[DONE]", lastEvent.Data)
438442
}
439443

440-
require.Len(t, recorderClient.tokenUsages, 1)
441-
assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(recorderClient.tokenUsages), "input tokens miscalculated")
442-
assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(recorderClient.tokenUsages), "output tokens miscalculated")
444+
tokenUsages := recorderClient.RecordedTokenUsages()
445+
require.Len(t, tokenUsages, 1)
446+
assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated")
447+
assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated")
443448

444-
require.Len(t, recorderClient.toolUsages, 1)
445-
assert.Equal(t, "read_file", recorderClient.toolUsages[0].Tool)
446-
require.IsType(t, map[string]any{}, recorderClient.toolUsages[0].Args)
447-
require.Contains(t, recorderClient.toolUsages[0].Args, "path")
448-
assert.Equal(t, "README.md", recorderClient.toolUsages[0].Args.(map[string]any)["path"])
449+
toolUsages := recorderClient.RecordedToolUsages()
450+
require.Len(t, toolUsages, 1)
451+
assert.Equal(t, "read_file", toolUsages[0].Tool)
452+
require.IsType(t, map[string]any{}, toolUsages[0].Args)
453+
require.Contains(t, toolUsages[0].Args, "path")
454+
assert.Equal(t, "README.md", toolUsages[0].Args.(map[string]any)["path"])
449455

450-
require.Len(t, recorderClient.userPrompts, 1)
451-
assert.Equal(t, "how large is the README.md file in my current path", recorderClient.userPrompts[0].Prompt)
456+
promptUsages := recorderClient.RecordedPromptUsages()
457+
require.Len(t, promptUsages, 1)
458+
assert.Equal(t, "how large is the README.md file in my current path", promptUsages[0].Prompt)
452459

453460
recorderClient.verifyAllInterceptionsEnded(t)
454461
})
@@ -605,8 +612,9 @@ func TestSimple(t *testing.T) {
605612
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
606613

607614
// Then: I expect the prompt to have been tracked.
608-
require.NotEmpty(t, recorderClient.userPrompts, "no prompts tracked")
609-
assert.Contains(t, recorderClient.userPrompts[0].Prompt, "how many angels can dance on the head of a pin")
615+
promptUsages := recorderClient.RecordedPromptUsages()
616+
require.NotEmpty(t, promptUsages, "no prompts tracked")
617+
assert.Contains(t, promptUsages[0].Prompt, "how many angels can dance on the head of a pin")
610618

611619
// Validate that responses have their IDs overridden with a interception ID rather than the original ID from the upstream provider.
612620
// The reason for this is that Bridge may make multiple upstream requests (i.e. to invoke injected tools), and clients will not be expecting
@@ -615,8 +623,9 @@ func TestSimple(t *testing.T) {
615623
require.NoError(t, err, "failed to retrieve response ID")
616624
require.Nilf(t, uuid.Validate(id), "%s is not a valid UUID", id)
617625

618-
require.GreaterOrEqual(t, len(recorderClient.tokenUsages), 1)
619-
require.Equal(t, recorderClient.tokenUsages[0].MsgID, tc.expectedMsgID)
626+
tokenUsages := recorderClient.RecordedTokenUsages()
627+
require.GreaterOrEqual(t, len(tokenUsages), 1)
628+
require.Equal(t, tokenUsages[0].MsgID, tc.expectedMsgID)
620629

621630
recorderClient.verifyAllInterceptionsEnded(t)
622631
})
@@ -770,11 +779,12 @@ func TestAnthropicInjectedTools(t *testing.T) {
770779
recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq)
771780

772781
// Ensure expected tool was invoked with expected input.
773-
require.Len(t, recorderClient.toolUsages, 1)
774-
require.Equal(t, mockToolName, recorderClient.toolUsages[0].Tool)
782+
toolUsages := recorderClient.RecordedToolUsages()
783+
require.Len(t, toolUsages, 1)
784+
require.Equal(t, mockToolName, toolUsages[0].Tool)
775785
expected, err := json.Marshal(map[string]any{"owner": "admin"})
776786
require.NoError(t, err)
777-
actual, err := json.Marshal(recorderClient.toolUsages[0].Args)
787+
actual, err := json.Marshal(toolUsages[0].Args)
778788
require.NoError(t, err)
779789
require.EqualValues(t, expected, actual)
780790
invocations := mcpCalls.getCallsByTool(mockToolName)
@@ -831,11 +841,13 @@ func TestAnthropicInjectedTools(t *testing.T) {
831841
assert.EqualValues(t, 204, message.Usage.OutputTokens)
832842

833843
// Ensure tokens used during injected tool invocation are accounted for.
834-
assert.EqualValues(t, 15308, calculateTotalInputTokens(recorderClient.tokenUsages))
835-
assert.EqualValues(t, 204, calculateTotalOutputTokens(recorderClient.tokenUsages))
844+
tokenUsages := recorderClient.RecordedTokenUsages()
845+
assert.EqualValues(t, 15308, calculateTotalInputTokens(tokenUsages))
846+
assert.EqualValues(t, 204, calculateTotalOutputTokens(tokenUsages))
836847

837848
// Ensure we received exactly one prompt.
838-
require.Len(t, recorderClient.userPrompts, 1)
849+
promptUsages := recorderClient.RecordedPromptUsages()
850+
require.Len(t, promptUsages, 1)
839851
})
840852
}
841853
}
@@ -857,11 +869,12 @@ func TestOpenAIInjectedTools(t *testing.T) {
857869
recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq)
858870

859871
// Ensure expected tool was invoked with expected input.
860-
require.Len(t, recorderClient.toolUsages, 1)
861-
require.Equal(t, mockToolName, recorderClient.toolUsages[0].Tool)
872+
toolUsages := recorderClient.RecordedToolUsages()
873+
require.Len(t, toolUsages, 1)
874+
require.Equal(t, mockToolName, toolUsages[0].Tool)
862875
expected, err := json.Marshal(map[string]any{"owner": "admin"})
863876
require.NoError(t, err)
864-
actual, err := json.Marshal(recorderClient.toolUsages[0].Args)
877+
actual, err := json.Marshal(toolUsages[0].Args)
865878
require.NoError(t, err)
866879
require.EqualValues(t, expected, actual)
867880
invocations := mcpCalls.getCallsByTool(mockToolName)
@@ -933,11 +946,13 @@ func TestOpenAIInjectedTools(t *testing.T) {
933946
assert.EqualValues(t, 105, message.Usage.CompletionTokens)
934947

935948
// Ensure tokens used during injected tool invocation are accounted for.
936-
require.EqualValues(t, 5047, calculateTotalInputTokens(recorderClient.tokenUsages))
937-
require.EqualValues(t, 105, calculateTotalOutputTokens(recorderClient.tokenUsages))
949+
tokenUsages := recorderClient.RecordedTokenUsages()
950+
require.EqualValues(t, 5047, calculateTotalInputTokens(tokenUsages))
951+
require.EqualValues(t, 105, calculateTotalOutputTokens(tokenUsages))
938952

939953
// Ensure we received exactly one prompt.
940-
require.Len(t, recorderClient.userPrompts, 1)
954+
promptUsages := recorderClient.RecordedPromptUsages()
955+
require.Len(t, promptUsages, 1)
941956
})
942957
}
943958
}
@@ -1822,6 +1837,40 @@ func (m *mockRecorderClient) RecordToolUsage(ctx context.Context, req *aibridge.
18221837
return nil
18231838
}
18241839

1840+
// RecordedTokenUsages returns a copy of recorded token usages in a thread-safe manner.
1841+
// Note: This is a shallow clone - the slice is copied but the pointers reference the
1842+
// same underlying records. This is sufficient for our test assertions which only read
1843+
// the data and don't modify the records.
1844+
func (m *mockRecorderClient) RecordedTokenUsages() []*aibridge.TokenUsageRecord {
1845+
m.mu.Lock()
1846+
defer m.mu.Unlock()
1847+
return slices.Clone(m.tokenUsages)
1848+
}
1849+
1850+
// RecordedPromptUsages returns a copy of recorded prompt usages in a thread-safe manner.
1851+
// Note: This is a shallow clone (see RecordedTokenUsages for details).
1852+
func (m *mockRecorderClient) RecordedPromptUsages() []*aibridge.PromptUsageRecord {
1853+
m.mu.Lock()
1854+
defer m.mu.Unlock()
1855+
return slices.Clone(m.userPrompts)
1856+
}
1857+
1858+
// RecordedToolUsages returns a copy of recorded tool usages in a thread-safe manner.
1859+
// Note: This is a shallow clone (see RecordedTokenUsages for details).
1860+
func (m *mockRecorderClient) RecordedToolUsages() []*aibridge.ToolUsageRecord {
1861+
m.mu.Lock()
1862+
defer m.mu.Unlock()
1863+
return slices.Clone(m.toolUsages)
1864+
}
1865+
1866+
// RecordedInterceptions returns a copy of recorded interceptions in a thread-safe manner.
1867+
// Note: This is a shallow clone (see RecordedTokenUsages for details).
1868+
func (m *mockRecorderClient) RecordedInterceptions() []*aibridge.InterceptionRecord {
1869+
m.mu.Lock()
1870+
defer m.mu.Unlock()
1871+
return slices.Clone(m.interceptions)
1872+
}
1873+
18251874
// verify all recorded interceptions has been marked as completed
18261875
func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) {
18271876
t.Helper()

0 commit comments

Comments
 (0)