@@ -171,21 +171,24 @@ func TestAnthropicMessages(t *testing.T) {
171171 // One for message_start, one for message_delta.
172172 expectedTokenRecordings = 2
173173 }
174- require .Len (t , recorderClient .tokenUsages , expectedTokenRecordings )
174+ tokenUsages := recorderClient .RecordedTokenUsages ()
175+ require .Len (t , tokenUsages , expectedTokenRecordings )
175176
176- assert .EqualValues (t , tc .expectedInputTokens , calculateTotalInputTokens (recorderClient . tokenUsages ), "input tokens miscalculated" )
177- assert .EqualValues (t , tc .expectedOutputTokens , calculateTotalOutputTokens (recorderClient . tokenUsages ), "output tokens miscalculated" )
177+ assert .EqualValues (t , tc .expectedInputTokens , calculateTotalInputTokens (tokenUsages ), "input tokens miscalculated" )
178+ assert .EqualValues (t , tc .expectedOutputTokens , calculateTotalOutputTokens (tokenUsages ), "output tokens miscalculated" )
178179
179- require .Len (t , recorderClient .toolUsages , 1 )
180- assert .Equal (t , "Read" , recorderClient .toolUsages [0 ].Tool )
181- require .IsType (t , json.RawMessage {}, recorderClient .toolUsages [0 ].Args )
180+ toolUsages := recorderClient .RecordedToolUsages ()
181+ require .Len (t , toolUsages , 1 )
182+ assert .Equal (t , "Read" , toolUsages [0 ].Tool )
183+ require .IsType (t , json.RawMessage {}, toolUsages [0 ].Args )
182184 var args map [string ]any
183- require .NoError (t , json .Unmarshal (recorderClient . toolUsages [0 ].Args .(json.RawMessage ), & args ))
185+ require .NoError (t , json .Unmarshal (toolUsages [0 ].Args .(json.RawMessage ), & args ))
184186 require .Contains (t , args , "file_path" )
185187 assert .Equal (t , "/tmp/blah/foo" , args ["file_path" ])
186188
187- require .Len (t , recorderClient .userPrompts , 1 )
188- assert .Equal (t , "read the foo file" , recorderClient .userPrompts [0 ].Prompt )
189+ promptUsages := recorderClient .RecordedPromptUsages ()
190+ require .Len (t , promptUsages , 1 )
191+ assert .Equal (t , "read the foo file" , promptUsages [0 ].Prompt )
189192
190193 recorderClient .verifyAllInterceptionsEnded (t )
191194 })
@@ -346,8 +349,9 @@ func TestAWSBedrockIntegration(t *testing.T) {
346349 // and the interception data.
347350 require .Equal (t , requestCount , 1 )
348351 require .Equal (t , bedrockCfg .Model , receivedModelName )
349- require .Len (t , recorderClient .interceptions , 1 )
350- require .Equal (t , recorderClient .interceptions [0 ].Model , bedrockCfg .Model )
352+ interceptions := recorderClient .RecordedInterceptions ()
353+ require .Len (t , interceptions , 1 )
354+ require .Equal (t , interceptions [0 ].Model , bedrockCfg .Model )
351355 recorderClient .verifyAllInterceptionsEnded (t )
352356 })
353357 }
@@ -437,18 +441,21 @@ func TestOpenAIChatCompletions(t *testing.T) {
437441 assert .Equal (t , "[DONE]" , lastEvent .Data )
438442 }
439443
440- require .Len (t , recorderClient .tokenUsages , 1 )
441- assert .EqualValues (t , tc .expectedInputTokens , calculateTotalInputTokens (recorderClient .tokenUsages ), "input tokens miscalculated" )
442- assert .EqualValues (t , tc .expectedOutputTokens , calculateTotalOutputTokens (recorderClient .tokenUsages ), "output tokens miscalculated" )
444+ tokenUsages := recorderClient .RecordedTokenUsages ()
445+ require .Len (t , tokenUsages , 1 )
446+ assert .EqualValues (t , tc .expectedInputTokens , calculateTotalInputTokens (tokenUsages ), "input tokens miscalculated" )
447+ assert .EqualValues (t , tc .expectedOutputTokens , calculateTotalOutputTokens (tokenUsages ), "output tokens miscalculated" )
443448
444- require .Len (t , recorderClient .toolUsages , 1 )
445- assert .Equal (t , "read_file" , recorderClient .toolUsages [0 ].Tool )
446- require .IsType (t , map [string ]any {}, recorderClient .toolUsages [0 ].Args )
447- require .Contains (t , recorderClient .toolUsages [0 ].Args , "path" )
448- assert .Equal (t , "README.md" , recorderClient .toolUsages [0 ].Args .(map [string ]any )["path" ])
449+ toolUsages := recorderClient .RecordedToolUsages ()
450+ require .Len (t , toolUsages , 1 )
451+ assert .Equal (t , "read_file" , toolUsages [0 ].Tool )
452+ require .IsType (t , map [string ]any {}, toolUsages [0 ].Args )
453+ require .Contains (t , toolUsages [0 ].Args , "path" )
454+ assert .Equal (t , "README.md" , toolUsages [0 ].Args .(map [string ]any )["path" ])
449455
450- require .Len (t , recorderClient .userPrompts , 1 )
451- assert .Equal (t , "how large is the README.md file in my current path" , recorderClient .userPrompts [0 ].Prompt )
456+ promptUsages := recorderClient .RecordedPromptUsages ()
457+ require .Len (t , promptUsages , 1 )
458+ assert .Equal (t , "how large is the README.md file in my current path" , promptUsages [0 ].Prompt )
452459
453460 recorderClient .verifyAllInterceptionsEnded (t )
454461 })
@@ -605,8 +612,9 @@ func TestSimple(t *testing.T) {
605612 resp .Body = io .NopCloser (bytes .NewReader (bodyBytes ))
606613
607614 // Then: I expect the prompt to have been tracked.
608- require .NotEmpty (t , recorderClient .userPrompts , "no prompts tracked" )
609- assert .Contains (t , recorderClient .userPrompts [0 ].Prompt , "how many angels can dance on the head of a pin" )
615+ promptUsages := recorderClient .RecordedPromptUsages ()
616+ require .NotEmpty (t , promptUsages , "no prompts tracked" )
617+ assert .Contains (t , promptUsages [0 ].Prompt , "how many angels can dance on the head of a pin" )
610618
611619 // Validate that responses have their IDs overridden with a interception ID rather than the original ID from the upstream provider.
612620 // The reason for this is that Bridge may make multiple upstream requests (i.e. to invoke injected tools), and clients will not be expecting
@@ -615,8 +623,9 @@ func TestSimple(t *testing.T) {
615623 require .NoError (t , err , "failed to retrieve response ID" )
616624 require .Nilf (t , uuid .Validate (id ), "%s is not a valid UUID" , id )
617625
618- require .GreaterOrEqual (t , len (recorderClient .tokenUsages ), 1 )
619- require .Equal (t , recorderClient .tokenUsages [0 ].MsgID , tc .expectedMsgID )
626+ tokenUsages := recorderClient .RecordedTokenUsages ()
627+ require .GreaterOrEqual (t , len (tokenUsages ), 1 )
628+ require .Equal (t , tokenUsages [0 ].MsgID , tc .expectedMsgID )
620629
621630 recorderClient .verifyAllInterceptionsEnded (t )
622631 })
@@ -770,11 +779,12 @@ func TestAnthropicInjectedTools(t *testing.T) {
770779 recorderClient , mcpCalls , _ , resp := setupInjectedToolTest (t , antSingleInjectedTool , streaming , configureFn , createAnthropicMessagesReq )
771780
772781 // Ensure expected tool was invoked with expected input.
773- require .Len (t , recorderClient .toolUsages , 1 )
774- require .Equal (t , mockToolName , recorderClient .toolUsages [0 ].Tool )
782+ toolUsages := recorderClient .RecordedToolUsages ()
783+ require .Len (t , toolUsages , 1 )
784+ require .Equal (t , mockToolName , toolUsages [0 ].Tool )
775785 expected , err := json .Marshal (map [string ]any {"owner" : "admin" })
776786 require .NoError (t , err )
777- actual , err := json .Marshal (recorderClient . toolUsages [0 ].Args )
787+ actual , err := json .Marshal (toolUsages [0 ].Args )
778788 require .NoError (t , err )
779789 require .EqualValues (t , expected , actual )
780790 invocations := mcpCalls .getCallsByTool (mockToolName )
@@ -831,11 +841,13 @@ func TestAnthropicInjectedTools(t *testing.T) {
831841 assert .EqualValues (t , 204 , message .Usage .OutputTokens )
832842
833843 // Ensure tokens used during injected tool invocation are accounted for.
834- assert .EqualValues (t , 15308 , calculateTotalInputTokens (recorderClient .tokenUsages ))
835- assert .EqualValues (t , 204 , calculateTotalOutputTokens (recorderClient .tokenUsages ))
844+ tokenUsages := recorderClient .RecordedTokenUsages ()
845+ assert .EqualValues (t , 15308 , calculateTotalInputTokens (tokenUsages ))
846+ assert .EqualValues (t , 204 , calculateTotalOutputTokens (tokenUsages ))
836847
837848 // Ensure we received exactly one prompt.
838- require .Len (t , recorderClient .userPrompts , 1 )
849+ promptUsages := recorderClient .RecordedPromptUsages ()
850+ require .Len (t , promptUsages , 1 )
839851 })
840852 }
841853}
@@ -857,11 +869,12 @@ func TestOpenAIInjectedTools(t *testing.T) {
857869 recorderClient , mcpCalls , _ , resp := setupInjectedToolTest (t , oaiSingleInjectedTool , streaming , configureFn , createOpenAIChatCompletionsReq )
858870
859871 // Ensure expected tool was invoked with expected input.
860- require .Len (t , recorderClient .toolUsages , 1 )
861- require .Equal (t , mockToolName , recorderClient .toolUsages [0 ].Tool )
872+ toolUsages := recorderClient .RecordedToolUsages ()
873+ require .Len (t , toolUsages , 1 )
874+ require .Equal (t , mockToolName , toolUsages [0 ].Tool )
862875 expected , err := json .Marshal (map [string ]any {"owner" : "admin" })
863876 require .NoError (t , err )
864- actual , err := json .Marshal (recorderClient . toolUsages [0 ].Args )
877+ actual , err := json .Marshal (toolUsages [0 ].Args )
865878 require .NoError (t , err )
866879 require .EqualValues (t , expected , actual )
867880 invocations := mcpCalls .getCallsByTool (mockToolName )
@@ -933,11 +946,13 @@ func TestOpenAIInjectedTools(t *testing.T) {
933946 assert .EqualValues (t , 105 , message .Usage .CompletionTokens )
934947
935948 // Ensure tokens used during injected tool invocation are accounted for.
936- require .EqualValues (t , 5047 , calculateTotalInputTokens (recorderClient .tokenUsages ))
937- require .EqualValues (t , 105 , calculateTotalOutputTokens (recorderClient .tokenUsages ))
949+ tokenUsages := recorderClient .RecordedTokenUsages ()
950+ require .EqualValues (t , 5047 , calculateTotalInputTokens (tokenUsages ))
951+ require .EqualValues (t , 105 , calculateTotalOutputTokens (tokenUsages ))
938952
939953 // Ensure we received exactly one prompt.
940- require .Len (t , recorderClient .userPrompts , 1 )
954+ promptUsages := recorderClient .RecordedPromptUsages ()
955+ require .Len (t , promptUsages , 1 )
941956 })
942957 }
943958}
@@ -1822,6 +1837,40 @@ func (m *mockRecorderClient) RecordToolUsage(ctx context.Context, req *aibridge.
18221837 return nil
18231838}
18241839
1840+ // RecordedTokenUsages returns a copy of recorded token usages in a thread-safe manner.
1841+ // Note: This is a shallow clone - the slice is copied but the pointers reference the
1842+ // same underlying records. This is sufficient for our test assertions which only read
1843+ // the data and don't modify the records.
1844+ func (m * mockRecorderClient ) RecordedTokenUsages () []* aibridge.TokenUsageRecord {
1845+ m .mu .Lock ()
1846+ defer m .mu .Unlock ()
1847+ return slices .Clone (m .tokenUsages )
1848+ }
1849+
1850+ // RecordedPromptUsages returns a copy of recorded prompt usages in a thread-safe manner.
1851+ // Note: This is a shallow clone (see RecordedTokenUsages for details).
1852+ func (m * mockRecorderClient ) RecordedPromptUsages () []* aibridge.PromptUsageRecord {
1853+ m .mu .Lock ()
1854+ defer m .mu .Unlock ()
1855+ return slices .Clone (m .userPrompts )
1856+ }
1857+
1858+ // RecordedToolUsages returns a copy of recorded tool usages in a thread-safe manner.
1859+ // Note: This is a shallow clone (see RecordedTokenUsages for details).
1860+ func (m * mockRecorderClient ) RecordedToolUsages () []* aibridge.ToolUsageRecord {
1861+ m .mu .Lock ()
1862+ defer m .mu .Unlock ()
1863+ return slices .Clone (m .toolUsages )
1864+ }
1865+
1866+ // RecordedInterceptions returns a copy of recorded interceptions in a thread-safe manner.
1867+ // Note: This is a shallow clone (see RecordedTokenUsages for details).
1868+ func (m * mockRecorderClient ) RecordedInterceptions () []* aibridge.InterceptionRecord {
1869+ m .mu .Lock ()
1870+ defer m .mu .Unlock ()
1871+ return slices .Clone (m .interceptions )
1872+ }
1873+
18251874// verify all recorded interceptions has been marked as completed
18261875func (m * mockRecorderClient ) verifyAllInterceptionsEnded (t * testing.T ) {
18271876 t .Helper ()
0 commit comments