 
 from google.genai import types
 
+from lmnr.opentelemetry_lib.decorators import json_dumps
 from lmnr.opentelemetry_lib.tracing.context import (
     get_current_context,
     get_event_attributes_from_context,
@@ -20,9 +21,10 @@
 from .utils import (
     dont_throw,
     get_content,
+    process_content_union,
+    process_stream_chunk,
     role_from_content_union,
     set_span_attribute,
-    process_content_union,
     to_dict,
     with_tracer_wrapper,
 )
@@ -139,9 +141,7 @@ def _set_request_attributes(span, args, kwargs):
         try:
             set_span_attribute(
                 span,
-                # TODO: change to SpanAttributes.LLM_REQUEST_STRUCTURED_OUTPUT_SCHEMA
-                # when we upgrade to opentelemetry-semantic-conventions-ai>=0.4.10
-                "gen_ai.request.structured_output_schema",
+                SpanAttributes.LLM_REQUEST_STRUCTURED_OUTPUT_SCHEMA,
                 json.dumps(process_schema(schema), cls=SchemaJSONEncoder),
             )
         except Exception:
@@ -150,10 +150,8 @@ def _set_request_attributes(span, args, kwargs):
         try:
             set_span_attribute(
                 span,
-                # TODO: change to SpanAttributes.LLM_REQUEST_STRUCTURED_OUTPUT_SCHEMA
-                # when we upgrade to opentelemetry-semantic-conventions-ai>=0.4.10
-                "gen_ai.request.structured_output_schema",
-                json.dumps(json_schema),
+                SpanAttributes.LLM_REQUEST_STRUCTURED_OUTPUT_SCHEMA,
+                json_dumps(json_schema),
             )
         except Exception:
             pass
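
Note: json_dumps is the helper newly imported above from lmnr.opentelemetry_lib.decorators; its implementation is not part of this diff. A minimal sketch of what such a fallback-safe serializer might look like (hypothetical, for illustration only):

import json

def json_dumps(data) -> str:
    # Hypothetical sketch; the real lmnr.opentelemetry_lib.decorators.json_dumps
    # may differ.
    try:
        # default=str keeps non-JSON-serializable values (datetimes, pydantic
        # models, ...) from raising and dropping the span attribute entirely.
        return json.dumps(data, default=str)
    except Exception:
        return str(data)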
@@ -182,7 +180,7 @@ def _set_request_attributes(span, args, kwargs):
             set_span_attribute(
                 span,
                 f"{SpanAttributes.LLM_REQUEST_FUNCTIONS}.{tool_num}.parameters",
-                json.dumps(tool_dict.get("parameters")),
+                json_dumps(tool_dict.get("parameters")),
             )
 
     if should_send_prompts():
@@ -215,7 +213,7 @@ def _set_request_attributes(span, args, kwargs):
                 (
                     content_str
                     if isinstance(content_str, str)
-                    else json.dumps(content_str)
+                    else json_dumps(content_str)
                 ),
             )
             blocks = (
@@ -248,7 +246,7 @@ def _set_request_attributes(span, args, kwargs):
                 set_span_attribute(
                     span,
                     f"{gen_ai_attributes.GEN_AI_PROMPT}.{i}.tool_calls.{tool_call_index}.arguments",
-                    json.dumps(function_call.get("arguments")),
+                    json_dumps(function_call.get("arguments")),
                 )
                 tool_call_index += 1
 
@@ -300,22 +298,26 @@ def _set_response_attributes(span, response: types.GenerateContentResponse):
             span, f"{gen_ai_attributes.GEN_AI_COMPLETION}.0.role", "model"
         )
         candidates_list = candidates if isinstance(candidates, list) else [candidates]
-        for i, candidate in enumerate(candidates_list):
+        i = 0
+        for candidate in candidates_list:
+            has_content = False
             processed_content = process_content_union(candidate.content)
             content_str = get_content(processed_content)
 
             set_span_attribute(
                 span, f"{gen_ai_attributes.GEN_AI_COMPLETION}.{i}.role", "model"
             )
-            set_span_attribute(
-                span,
-                f"{gen_ai_attributes.GEN_AI_COMPLETION}.{i}.content",
-                (
-                    content_str
-                    if isinstance(content_str, str)
-                    else json.dumps(content_str)
-                ),
-            )
+            if content_str:
+                has_content = True
+                set_span_attribute(
+                    span,
+                    f"{gen_ai_attributes.GEN_AI_COMPLETION}.{i}.content",
+                    (
+                        content_str
+                        if isinstance(content_str, str)
+                        else json_dumps(content_str)
+                    ),
+                )
             blocks = (
                 processed_content
                 if isinstance(processed_content, list)
@@ -328,6 +330,7 @@ def _set_response_attributes(span, response: types.GenerateContentResponse):
                 if not block_dict.get("function_call"):
                     continue
                 function_call = to_dict(block_dict.get("function_call", {}))
+                has_content = True
                 set_span_attribute(
                     span,
                     f"{gen_ai_attributes.GEN_AI_COMPLETION}.{i}.tool_calls.{tool_call_index}.name",
@@ -345,9 +348,11 @@ def _set_response_attributes(span, response: types.GenerateContentResponse):
                 set_span_attribute(
                     span,
                     f"{gen_ai_attributes.GEN_AI_COMPLETION}.{i}.tool_calls.{tool_call_index}.arguments",
-                    json.dumps(function_call.get("arguments")),
+                    json_dumps(function_call.get("arguments")),
                 )
                 tool_call_index += 1
+        if has_content:
+            i += 1
 
 
 @dont_throw
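
With the new has_content / i bookkeeping, a candidate that yields neither text content nor tool calls no longer consumes a completion index, so the gen_ai.completion.{i}.* attributes stay contiguous. A toy illustration of the indexing behavior (stand-in data, not the instrumentation code itself):

# The second "candidate" is empty, so it is skipped and the recorded
# indices have no gap.
candidates = ["hello", "", "world"]

i = 0
for content in candidates:
    has_content = bool(content)
    if has_content:
        print(f"gen_ai.completion.{i}.content = {content}")
        i += 1  # only advance the index when something was recorded

# gen_ai.completion.0.content = hello
# gen_ai.completion.1.content = world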
@@ -359,54 +364,45 @@ def _build_from_streaming_response(
     aggregated_usage_metadata = defaultdict(int)
     model_version = None
     for chunk in response:
-        if chunk.model_version:
-            model_version = chunk.model_version
-
-        if chunk.candidates:
-            # Currently gemini throws an error if you pass more than one candidate
-            # with streaming
-            if chunk.candidates and len(chunk.candidates) > 0:
-                final_parts += chunk.candidates[0].content.parts or []
-                role = chunk.candidates[0].content.role or role
-        if chunk.usage_metadata:
-            usage_dict = to_dict(chunk.usage_metadata)
-            # prompt token count is sent in every chunk
-            # (and is less by 1 in the last chunk, so we set it once);
-            # total token count in every chunk is greater by prompt token count than it should be,
-            # thus this awkward logic here
-            if aggregated_usage_metadata.get("prompt_token_count") is None:
-                # or 0, not .get(key, 0), because sometimes the value is explicitly None
-                aggregated_usage_metadata["prompt_token_count"] = (
-                    usage_dict.get("prompt_token_count") or 0
-                )
-                aggregated_usage_metadata["total_token_count"] = (
-                    usage_dict.get("total_token_count") or 0
-                )
-            aggregated_usage_metadata["candidates_token_count"] += (
-                usage_dict.get("candidates_token_count") or 0
-            )
-            aggregated_usage_metadata["total_token_count"] += (
-                usage_dict.get("candidates_token_count") or 0
-            )
+        # Important: do all processing in a separate sync function that is
+        # wrapped in @dont_throw. If we did it here, the @dont_throw on top of
+        # this function would not be able to catch the errors, as they are
+        # raised later, after the generator is returned, while it is being
+        # consumed.
+        chunk_result = process_stream_chunk(
+            chunk,
+            role,
+            model_version,
+            aggregated_usage_metadata,
+            final_parts,
+        )
+        # Even though process_stream_chunk itself never returns None, the
+        # result can be None if the processing throws an error (see @dont_throw).
+        if chunk_result:
+            role = chunk_result["role"]
+            model_version = chunk_result["model_version"]
         yield chunk
 
-    compound_response = types.GenerateContentResponse(
-        candidates=[
-            {
-                "content": {
-                    "parts": final_parts,
-                    "role": role,
-                },
-            }
-        ],
-        usage_metadata=types.GenerateContentResponseUsageMetadataDict(
-            **aggregated_usage_metadata
-        ),
-        model_version=model_version,
-    )
-    if span.is_recording():
-        _set_response_attributes(span, compound_response)
-    span.end()
+    try:
+        compound_response = types.GenerateContentResponse(
+            candidates=[
+                {
+                    "content": {
+                        "parts": final_parts,
+                        "role": role,
+                    },
+                }
+            ],
+            usage_metadata=types.GenerateContentResponseUsageMetadataDict(
+                **aggregated_usage_metadata
+            ),
+            model_version=model_version,
+        )
+        if span.is_recording():
+            _set_response_attributes(span, compound_response)
+    finally:
+        if span.is_recording():
+            span.end()
 
 
 @dont_throw
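
process_stream_chunk is defined in .utils and is not shown in this diff. Judging from the inline logic it replaces, it presumably looks roughly like the sketch below; assumptions: it is itself decorated with @dont_throw, it mutates aggregated_usage_metadata and final_parts in place, and it returns the possibly updated role and model_version.

@dont_throw
def process_stream_chunk(
    chunk, role, model_version, aggregated_usage_metadata, final_parts
):
    # Reconstruction for illustration; the real helper may differ.
    if chunk.model_version:
        model_version = chunk.model_version
    if chunk.candidates:
        # Gemini currently streams a single candidate.
        final_parts += chunk.candidates[0].content.parts or []
        role = chunk.candidates[0].content.role or role
    if chunk.usage_metadata:
        usage_dict = to_dict(chunk.usage_metadata)
        # prompt_token_count repeats in every chunk, so record it only once;
        # candidates_token_count is incremental and must be summed.
        if aggregated_usage_metadata.get("prompt_token_count") is None:
            aggregated_usage_metadata["prompt_token_count"] = (
                usage_dict.get("prompt_token_count") or 0
            )
            aggregated_usage_metadata["total_token_count"] = (
                usage_dict.get("total_token_count") or 0
            )
        aggregated_usage_metadata["candidates_token_count"] += (
            usage_dict.get("candidates_token_count") or 0
        )
        aggregated_usage_metadata["total_token_count"] += (
            usage_dict.get("candidates_token_count") or 0
        )
    return {"role": role, "model_version": model_version}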
@@ -418,54 +414,45 @@ async def _abuild_from_streaming_response(
     aggregated_usage_metadata = defaultdict(int)
     model_version = None
     async for chunk in response:
-        if chunk.model_version:
-            model_version = chunk.model_version
-
-        if chunk.candidates:
-            # Currently gemini throws an error if you pass more than one candidate
-            # with streaming
-            if chunk.candidates and len(chunk.candidates) > 0:
-                final_parts += chunk.candidates[0].content.parts or []
-                role = chunk.candidates[0].content.role or role
-        if chunk.usage_metadata:
-            usage_dict = to_dict(chunk.usage_metadata)
-            # prompt token count is sent in every chunk
-            # (and is less by 1 in the last chunk, so we set it once);
-            # total token count in every chunk is greater by prompt token count than it should be,
-            # thus this awkward logic here
-            if aggregated_usage_metadata.get("prompt_token_count") is None:
-                # or 0, not .get(key, 0), because sometimes the value is explicitly None
-                aggregated_usage_metadata["prompt_token_count"] = (
-                    usage_dict.get("prompt_token_count") or 0
-                )
-                aggregated_usage_metadata["total_token_count"] = (
-                    usage_dict.get("total_token_count") or 0
-                )
-            aggregated_usage_metadata["candidates_token_count"] += (
-                usage_dict.get("candidates_token_count") or 0
-            )
-            aggregated_usage_metadata["total_token_count"] += (
-                usage_dict.get("candidates_token_count") or 0
-            )
+        # Important: do all processing in a separate sync function that is
+        # wrapped in @dont_throw. If we did it here, the @dont_throw on top of
+        # this function would not be able to catch the errors, as they are
+        # raised later, after the generator is returned, while it is being
+        # consumed.
+        chunk_result = process_stream_chunk(
+            chunk,
+            role,
+            model_version,
+            aggregated_usage_metadata,
+            final_parts,
+        )
+        # Even though process_stream_chunk itself never returns None, the
+        # result can be None if the processing throws an error (see @dont_throw).
+        if chunk_result:
+            role = chunk_result["role"]
+            model_version = chunk_result["model_version"]
         yield chunk
 
-    compound_response = types.GenerateContentResponse(
-        candidates=[
-            {
-                "content": {
-                    "parts": final_parts,
-                    "role": role,
-                },
-            }
-        ],
-        usage_metadata=types.GenerateContentResponseUsageMetadataDict(
-            **aggregated_usage_metadata
-        ),
-        model_version=model_version,
-    )
-    if span.is_recording():
-        _set_response_attributes(span, compound_response)
-    span.end()
+    try:
+        compound_response = types.GenerateContentResponse(
+            candidates=[
+                {
+                    "content": {
+                        "parts": final_parts,
+                        "role": role,
+                    },
+                }
+            ],
+            usage_metadata=types.GenerateContentResponseUsageMetadataDict(
+                **aggregated_usage_metadata
+            ),
+            model_version=model_version,
+        )
+        if span.is_recording():
+            _set_response_attributes(span, compound_response)
+    finally:
+        if span.is_recording():
+            span.end()
 
 
 @with_tracer_wrapper
@@ -502,7 +489,7 @@ def _wrap(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs):
         span.record_exception(e, attributes=attributes)
         span.set_status(Status(StatusCode.ERROR, str(e)))
         span.end()
-        raise e
+        raise
 
 
 @with_tracer_wrapper
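
The sync and async wrappers both switch from raise e to a bare raise. Inside an except block, the bare form re-raises the active exception without adding the re-raise site as an extra frame in the traceback, so the propagated exception still points at the original failure:

try:
    1 / 0
except ZeroDivisionError as e:
    # raise e  -> traceback gains an extra frame at this line
    raise       # -> traceback shows only the original `1 / 0`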
@@ -541,7 +528,7 @@ async def _awrap(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs):
         span.record_exception(e, attributes=attributes)
         span.set_status(Status(StatusCode.ERROR, str(e)))
         span.end()
-        raise e
+        raise
 
 
 class GoogleGenAiSdkInstrumentor(BaseInstrumentor):