Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions google/genai/_automatic_function_calling_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
#

import inspect
import sys
import types as builtin_types
import typing
from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Optional, Union # type: ignore[attr-defined]

import pydantic
Expand All @@ -25,10 +23,7 @@
from . import types


if sys.version_info >= (3, 10):
VersionedUnionType = builtin_types.UnionType
else:
VersionedUnionType = typing._UnionGenericAlias # type: ignore[attr-defined]
_UNION_TYPES = (Union, builtin_types.UnionType)


__all__ = [
Expand All @@ -53,6 +48,10 @@
}


def _is_union_annotation(annotation: inspect.Parameter.annotation) -> bool: # type: ignore[valid-type]
return get_origin(annotation) in _UNION_TYPES # type: ignore[comparison-overlap]


def _raise_for_unsupported_param(
param: inspect.Parameter, func_name: str, exception: Union[Exception, type[Exception]]
) -> None:
Expand Down Expand Up @@ -102,10 +101,10 @@ def _is_default_value_compatible(
if (
isinstance(annotation, _GenericAlias)
or isinstance(annotation, builtin_types.GenericAlias)
or isinstance(annotation, VersionedUnionType)
or _is_union_annotation(annotation)
):
origin = get_origin(annotation)
if origin in (Union, VersionedUnionType): # type: ignore[comparison-overlap]
if origin in _UNION_TYPES: # type: ignore[comparison-overlap]
return any(
_is_default_value_compatible(default_value, arg)
for arg in get_args(annotation)
Expand Down Expand Up @@ -160,7 +159,7 @@ def _parse_schema_from_parameter( # type: ignore[return]
schema.type = _py_builtin_type_to_schema_type[param.annotation]
return schema
if (
isinstance(param.annotation, VersionedUnionType)
_is_union_annotation(param.annotation)
# only parse simple UnionType, example int | str | float | bool
# complex UnionType will be invoked in raise branch
and all(
Expand Down Expand Up @@ -199,8 +198,10 @@ def _parse_schema_from_parameter( # type: ignore[return]
raise ValueError(default_value_error_msg)
schema.default = param.default
return schema
if isinstance(param.annotation, _GenericAlias) or isinstance(
param.annotation, builtin_types.GenericAlias
if (
isinstance(param.annotation, _GenericAlias)
or isinstance(param.annotation, builtin_types.GenericAlias)
or _is_union_annotation(param.annotation)
):
origin = get_origin(param.annotation)
args = get_args(param.annotation)
Expand Down Expand Up @@ -239,7 +240,7 @@ def _parse_schema_from_parameter( # type: ignore[return]
raise ValueError(default_value_error_msg)
schema.default = param.default
return schema
if origin is Union:
if origin in _UNION_TYPES:
schema.any_of = []
schema.type = _py_builtin_type_to_schema_type[dict]
unique_types = set()
Expand Down
7 changes: 1 addition & 6 deletions google/genai/_extra_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import inspect
import io
import logging
import sys
import typing
from types import UnionType
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
import mimetypes
import os
Expand All @@ -34,11 +34,6 @@
from ._adapters import McpToGenAiToolAdapter


if sys.version_info >= (3, 10):
from types import UnionType
else:
UnionType = typing._UnionGenericAlias # type: ignore[attr-defined]

if typing.TYPE_CHECKING:
from mcp import ClientSession as McpClientSession
from mcp.types import Tool as McpTool
Expand Down
13 changes: 2 additions & 11 deletions google/genai/_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@
import io
import logging
import re
import sys
import time
import types as builtin_types
import typing
from typing import Any, GenericAlias, List, Optional, Sequence, Union # type: ignore[attr-defined]
from typing import Any, GenericAlias, List, Optional, Sequence, TypeGuard, Union # type: ignore[attr-defined]
from ._mcp_utils import mcp_to_gemini_tool
from ._common import get_value_by_path as getv
from ._common import is_duck_type_of
Expand All @@ -42,14 +41,7 @@

logger = logging.getLogger('google_genai._transformers')

if sys.version_info >= (3, 10):
VersionedUnionType = builtin_types.UnionType
_UNION_TYPES = (typing.Union, builtin_types.UnionType)
from typing import TypeGuard
else:
VersionedUnionType = typing._UnionGenericAlias # type: ignore[attr-defined]
_UNION_TYPES = (typing.Union,)
from typing_extensions import TypeGuard
_UNION_TYPES = (typing.Union, builtin_types.UnionType)

if typing.TYPE_CHECKING:
from mcp import ClientSession as McpClientSession
Expand Down Expand Up @@ -881,7 +873,6 @@ def t_schema(
elif (
isinstance(origin, GenericAlias)
or isinstance(origin, type)
or isinstance(origin, VersionedUnionType)
or typing.get_origin(origin) in _UNION_TYPES
):

Expand Down
25 changes: 16 additions & 9 deletions google/genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import sys
import types as builtin_types
import typing
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union, _UnionGenericAlias # type: ignore
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union
import pydantic
from pydantic import ConfigDict, Field, PrivateAttr, model_validator
from typing_extensions import Self, TypedDict
Expand All @@ -38,14 +38,9 @@
)


if sys.version_info >= (3, 10):
# Supports both Union[t1, t2] and t1 | t2
VersionedUnionType = Union[builtin_types.UnionType, _UnionGenericAlias]
_UNION_TYPES = (typing.Union, builtin_types.UnionType)
else:
# Supports only Union[t1, t2]
VersionedUnionType = _UnionGenericAlias
_UNION_TYPES = (typing.Union,)
# Supports both Union[t1, t2] and t1 | t2.
VersionedUnionType = builtin_types.UnionType
_UNION_TYPES = (typing.Union, builtin_types.UnionType)

_is_pillow_image_imported = False
if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -5958,6 +5953,18 @@ class GenerateContentConfig(_common.BaseModel):
@pydantic.field_validator('response_schema', mode='before')
@classmethod
def _convert_literal_to_enum(cls, value: Any) -> Union[Any, EnumMeta]:
# Normalize typing.Union[...] to PEP 604 unions on Python <= 3.13, where
# typing.Union[...] is not an instance of types.UnionType.
if (
typing.get_origin(value) is typing.Union
and not isinstance(value, builtin_types.UnionType)
):
union_args = typing.get_args(value)
if union_args:
normalized_union = union_args[0]
for union_arg in union_args[1:]:
normalized_union = normalized_union | union_arg
return normalized_union
if typing.get_origin(value) is typing.Literal:
enum_vals = typing.get_args(value)
if not all(isinstance(arg, str) for arg in enum_vals):
Expand Down
Loading