Skip to content

Commit 8730501

Browse files
authored
Merge branch 'main' into fix/pydantic-v2-swagger-ui
2 parents 9572478 + 5d5708b commit 8730501

File tree

9 files changed

+243
-33
lines changed

9 files changed

+243
-33
lines changed

src/google/adk/flows/llm_flows/instructions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from typing import AsyncGenerator
2020
from typing import TYPE_CHECKING
2121

22-
from google.genai import _transformers
2322
from typing_extensions import override
2423

2524
from ...agents.readonly_context import ReadonlyContext
@@ -85,6 +84,8 @@ async def run_async(
8584

8685
# Handle static_instruction - add via append_instructions
8786
if agent.static_instruction:
87+
from google.genai import _transformers
88+
8889
# Convert ContentUnion to Content using genai transformer
8990
static_content = _transformers.t_content(agent.static_instruction)
9091
llm_request.append_instructions(static_content)

src/google/adk/models/apigee_llm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,22 @@
1818
from functools import cached_property
1919
import logging
2020
import os
21-
import re
2221
from typing import Optional
2322
from typing import TYPE_CHECKING
2423

2524
from google.adk import version as adk_version
26-
from google.genai import Client
2725
from google.genai import types
2826
from typing_extensions import override
2927

3028
from ..utils.env_utils import is_env_enabled
3129
from .google_llm import Gemini
3230

3331
if TYPE_CHECKING:
32+
from google.genai import Client
33+
3434
from .llm_request import LlmRequest
3535

36+
3637
logger = logging.getLogger('google_adk.' + __name__)
3738

3839
_APIGEE_PROXY_URL_ENV_VARIABLE_NAME = 'APIGEE_PROXY_URL'
@@ -137,6 +138,7 @@ def api_client(self) -> Client:
137138
Returns:
138139
The api client.
139140
"""
141+
from google.genai import Client
140142

141143
kwargs_for_http_options = {}
142144
if self._api_version:

src/google/adk/models/gemini_context_cache_manager.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
import logging
2222
import time
2323
from typing import Optional
24+
from typing import TYPE_CHECKING
2425

25-
from google.genai import Client
2626
from google.genai import types
2727

2828
from ..utils.feature_decorator import experimental
@@ -32,6 +32,9 @@
3232

3333
logger = logging.getLogger("google_adk." + __name__)
3434

35+
if TYPE_CHECKING:
36+
from google.genai import Client
37+
3538

3639
@experimental
3740
class GeminiContextCacheManager:

src/google/adk/models/gemini_llm_connection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from typing import AsyncGenerator
1919
from typing import Union
2020

21-
from google.genai import live
2221
from google.genai import types
2322

2423
from ..utils.context_utils import Aclosing
@@ -28,6 +27,10 @@
2827
logger = logging.getLogger('google_adk.' + __name__)
2928

3029
RealtimeInput = Union[types.Blob, types.ActivityStart, types.ActivityEnd]
30+
from typing import TYPE_CHECKING
31+
32+
if TYPE_CHECKING:
33+
from google.genai import live
3134

3235

3336
class GeminiLlmConnection(BaseLlmConnection):
@@ -58,6 +61,7 @@ async def send_history(self, history: list[types.Content]):
5861
for content in history
5962
if content.parts and content.parts[0].text
6063
]
64+
logger.debug('Sending history to live connection: %s', contents)
6165

6266
if contents:
6367
await self._gemini_session.send(

src/google/adk/models/google_llm.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from typing import TYPE_CHECKING
2828
from typing import Union
2929

30-
from google.genai import Client
3130
from google.genai import types
3231
from typing_extensions import override
3332

@@ -41,6 +40,8 @@
4140
from .llm_response import LlmResponse
4241

4342
if TYPE_CHECKING:
43+
from google.genai import Client
44+
4445
from .llm_request import LlmRequest
4546

4647
logger = logging.getLogger('google_adk.' + __name__)
@@ -200,6 +201,8 @@ def api_client(self) -> Client:
200201
Returns:
201202
The api client.
202203
"""
204+
from google.genai import Client
205+
203206
return Client(
204207
http_options=types.HttpOptions(
205208
headers=self._tracking_headers,
@@ -239,6 +242,8 @@ def _live_api_version(self) -> str:
239242

240243
@cached_property
241244
def _live_api_client(self) -> Client:
245+
from google.genai import Client
246+
242247
return Client(
243248
http_options=types.HttpOptions(
244249
headers=self._tracking_headers, api_version=self._live_api_version
@@ -283,6 +288,7 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
283288
llm_request.live_connect_config.tools = llm_request.config.tools
284289
logger.info('Connecting to live for model: %s', llm_request.model)
285290
logger.debug('Connecting to live with llm_request:%s', llm_request)
291+
logger.debug('Live connect config: %s', llm_request.live_connect_config)
286292
async with self._live_api_client.aio.live.connect(
287293
model=llm_request.model, config=llm_request.live_connect_config
288294
) as live_session:

src/google/adk/tools/_automatic_function_calling_util.py

Lines changed: 80 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -296,20 +296,55 @@ def from_function_with_options(
296296
) -> 'types.FunctionDeclaration':
297297

298298
parameters_properties = {}
299-
for name, param in inspect.signature(func).parameters.items():
300-
if param.kind in (
301-
inspect.Parameter.POSITIONAL_OR_KEYWORD,
302-
inspect.Parameter.KEYWORD_ONLY,
303-
inspect.Parameter.POSITIONAL_ONLY,
304-
):
305-
# This snippet catches the case when type hints are stored as strings
306-
if isinstance(param.annotation, str):
307-
param = param.replace(annotation=typing.get_type_hints(func)[name])
308-
309-
schema = _function_parameter_parse_util._parse_schema_from_parameter(
310-
variant, param, func.__name__
311-
)
312-
parameters_properties[name] = schema
299+
parameters_json_schema = {}
300+
annotation_under_future = typing.get_type_hints(func)
301+
try:
302+
for name, param in inspect.signature(func).parameters.items():
303+
if param.kind in (
304+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
305+
inspect.Parameter.KEYWORD_ONLY,
306+
inspect.Parameter.POSITIONAL_ONLY,
307+
):
308+
param = _function_parameter_parse_util._handle_params_as_deferred_annotations(
309+
param, annotation_under_future, name
310+
)
311+
312+
schema = _function_parameter_parse_util._parse_schema_from_parameter(
313+
variant, param, func.__name__
314+
)
315+
parameters_properties[name] = schema
316+
except ValueError:
317+
# If the function has complex parameter types that fail in _parse_schema_from_parameter,
318+
# we try to generate a json schema for the parameter using pydantic.TypeAdapter.
319+
parameters_properties = {}
320+
for name, param in inspect.signature(func).parameters.items():
321+
if param.kind in (
322+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
323+
inspect.Parameter.KEYWORD_ONLY,
324+
inspect.Parameter.POSITIONAL_ONLY,
325+
):
326+
try:
327+
if param.annotation == inspect.Parameter.empty:
328+
param = param.replace(annotation=Any)
329+
330+
param = _function_parameter_parse_util._handle_params_as_deferred_annotations(
331+
param, annotation_under_future, name
332+
)
333+
334+
_function_parameter_parse_util._raise_for_invalid_enum_value(param)
335+
336+
json_schema_dict = _function_parameter_parse_util._generate_json_schema_for_parameter(
337+
param
338+
)
339+
340+
parameters_json_schema[name] = types.Schema.model_validate(
341+
json_schema_dict
342+
)
343+
except Exception as e:
344+
_function_parameter_parse_util._raise_for_unsupported_param(
345+
param, func.__name__, e
346+
)
347+
313348
declaration = types.FunctionDeclaration(
314349
name=func.__name__,
315350
description=func.__doc__,
@@ -324,6 +359,12 @@ def from_function_with_options(
324359
declaration.parameters
325360
)
326361
)
362+
elif parameters_json_schema:
363+
declaration.parameters = types.Schema(
364+
type='OBJECT',
365+
properties=parameters_json_schema,
366+
)
367+
327368
if variant == GoogleLLMVariant.GEMINI_API:
328369
return declaration
329370

@@ -372,17 +413,35 @@ def from_function_with_options(
372413
inspect.Parameter.POSITIONAL_OR_KEYWORD,
373414
annotation=return_annotation,
374415
)
375-
# This snippet catches the case when type hints are stored as strings
376416
if isinstance(return_value.annotation, str):
377417
return_value = return_value.replace(
378418
annotation=typing.get_type_hints(func)['return']
379419
)
380420

381-
declaration.response = (
382-
_function_parameter_parse_util._parse_schema_from_parameter(
383-
variant,
384-
return_value,
385-
func.__name__,
421+
response_schema: Optional[types.Schema] = None
422+
response_json_schema: Optional[Union[Dict[str, Any], types.Schema]] = None
423+
try:
424+
response_schema = (
425+
_function_parameter_parse_util._parse_schema_from_parameter(
426+
variant,
427+
return_value,
428+
func.__name__,
429+
)
430+
)
431+
except ValueError:
432+
try:
433+
response_json_schema = (
434+
_function_parameter_parse_util._generate_json_schema_for_parameter(
435+
return_value
436+
)
386437
)
387-
)
438+
response_json_schema = types.Schema.model_validate(response_json_schema)
439+
except Exception as e:
440+
_function_parameter_parse_util._raise_for_unsupported_param(
441+
return_value, func.__name__, e
442+
)
443+
if response_schema:
444+
declaration.response = response_schema
445+
elif response_json_schema:
446+
declaration.response = response_json_schema
388447
return declaration

src/google/adk/tools/_function_parameter_parse_util.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,91 @@
4949
logger = logging.getLogger('google_adk.' + __name__)
5050

5151

52+
def _handle_params_as_deferred_annotations(
53+
param: inspect.Parameter, annotation_under_future: dict[str, Any], name: str
54+
) -> inspect.Parameter:
55+
"""Catches the case when type hints are stored as strings."""
56+
if isinstance(param.annotation, str):
57+
param = param.replace(annotation=annotation_under_future[name])
58+
return param
59+
60+
61+
def _add_unevaluated_items_to_fixed_len_tuple_schema(
62+
json_schema: dict[str, Any],
63+
) -> dict[str, Any]:
64+
"""Adds 'unevaluatedItems': False to schemas for fixed-length tuples.
65+
66+
For example, the schema for a parameter of type `tuple[float, float]` would
67+
be:
68+
{
69+
"type": "array",
70+
"prefixItems": [
71+
{
72+
"type": "number"
73+
},
74+
{
75+
"type": "number"
76+
},
77+
],
78+
"minItems": 2,
79+
"maxItems": 2,
80+
"unevaluatedItems": False
81+
}
82+
83+
"""
84+
if (
85+
json_schema.get('maxItems')
86+
and (
87+
json_schema.get('prefixItems')
88+
and len(json_schema['prefixItems']) == json_schema['maxItems']
89+
)
90+
and json_schema.get('type') == 'array'
91+
):
92+
json_schema['unevaluatedItems'] = False
93+
return json_schema
94+
95+
96+
def _raise_for_unsupported_param(
97+
param: inspect.Parameter,
98+
func_name: str,
99+
exception: Exception,
100+
) -> None:
101+
raise ValueError(
102+
f'Failed to parse the parameter {param} of function {func_name} for'
103+
' automatic function calling.Automatic function calling works best with'
104+
' simpler function signature schema, consider manually parsing your'
105+
f' function declaration for function {func_name}.'
106+
) from exception
107+
108+
109+
def _raise_for_invalid_enum_value(param: inspect.Parameter):
110+
"""Raises an error if the default value is not a valid enum value."""
111+
if inspect.isclass(param.annotation) and issubclass(param.annotation, Enum):
112+
if param.default is not inspect.Parameter.empty and param.default not in [
113+
e.value for e in param.annotation
114+
]:
115+
raise ValueError(
116+
f'Default value {param.default} is not a valid enum value for'
117+
f' {param.annotation}.'
118+
)
119+
120+
121+
def _generate_json_schema_for_parameter(
122+
param: inspect.Parameter,
123+
) -> dict[str, Any]:
124+
"""Generates a JSON schema for a parameter using pydantic.TypeAdapter."""
125+
126+
param_schema_adapter = pydantic.TypeAdapter(
127+
param.annotation,
128+
config=pydantic.ConfigDict(arbitrary_types_allowed=True),
129+
)
130+
json_schema_dict = param_schema_adapter.json_schema()
131+
json_schema_dict = _add_unevaluated_items_to_fixed_len_tuple_schema(
132+
json_schema_dict
133+
)
134+
return json_schema_dict
135+
136+
52137
def _is_builtin_primitive_or_compound(
53138
annotation: inspect.Parameter.annotation,
54139
) -> bool:

0 commit comments

Comments
 (0)