diff --git a/google/genai/_automatic_function_calling_util.py b/google/genai/_automatic_function_calling_util.py index ec7f9a702..277417e72 100644 --- a/google/genai/_automatic_function_calling_util.py +++ b/google/genai/_automatic_function_calling_util.py @@ -14,9 +14,7 @@ # import inspect -import sys import types as builtin_types -import typing from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Optional, Union # type: ignore[attr-defined] import pydantic @@ -25,10 +23,7 @@ from . import types -if sys.version_info >= (3, 10): - VersionedUnionType = builtin_types.UnionType -else: - VersionedUnionType = typing._UnionGenericAlias # type: ignore[attr-defined] +_UNION_TYPES = (Union, builtin_types.UnionType) __all__ = [ @@ -53,6 +48,10 @@ } +def _is_union_annotation(annotation: inspect.Parameter.annotation) -> bool: # type: ignore[valid-type] + return get_origin(annotation) in _UNION_TYPES # type: ignore[comparison-overlap] + + def _raise_for_unsupported_param( param: inspect.Parameter, func_name: str, exception: Union[Exception, type[Exception]] ) -> None: @@ -102,10 +101,10 @@ def _is_default_value_compatible( if ( isinstance(annotation, _GenericAlias) or isinstance(annotation, builtin_types.GenericAlias) - or isinstance(annotation, VersionedUnionType) + or _is_union_annotation(annotation) ): origin = get_origin(annotation) - if origin in (Union, VersionedUnionType): # type: ignore[comparison-overlap] + if origin in _UNION_TYPES: # type: ignore[comparison-overlap] return any( _is_default_value_compatible(default_value, arg) for arg in get_args(annotation) @@ -160,7 +159,7 @@ def _parse_schema_from_parameter( # type: ignore[return] schema.type = _py_builtin_type_to_schema_type[param.annotation] return schema if ( - isinstance(param.annotation, VersionedUnionType) + _is_union_annotation(param.annotation) # only parse simple UnionType, example int | str | float | bool # complex UnionType will be invoked in raise branch and all( @@ -199,8 +198,10 @@ def _parse_schema_from_parameter( # type: ignore[return] raise ValueError(default_value_error_msg) schema.default = param.default return schema - if isinstance(param.annotation, _GenericAlias) or isinstance( - param.annotation, builtin_types.GenericAlias + if ( + isinstance(param.annotation, _GenericAlias) + or isinstance(param.annotation, builtin_types.GenericAlias) + or _is_union_annotation(param.annotation) ): origin = get_origin(param.annotation) args = get_args(param.annotation) @@ -239,7 +240,7 @@ def _parse_schema_from_parameter( # type: ignore[return] raise ValueError(default_value_error_msg) schema.default = param.default return schema - if origin is Union: + if origin in _UNION_TYPES: schema.any_of = [] schema.type = _py_builtin_type_to_schema_type[dict] unique_types = set() diff --git a/google/genai/_extra_utils.py b/google/genai/_extra_utils.py index 129c05f7d..394aece88 100644 --- a/google/genai/_extra_utils.py +++ b/google/genai/_extra_utils.py @@ -19,8 +19,8 @@ import inspect import io import logging -import sys import typing +from types import UnionType from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin import mimetypes import os @@ -34,11 +34,6 @@ from ._adapters import McpToGenAiToolAdapter -if sys.version_info >= (3, 10): - from types import UnionType -else: - UnionType = typing._UnionGenericAlias # type: ignore[attr-defined] - if typing.TYPE_CHECKING: from mcp import ClientSession as McpClientSession from mcp.types import Tool as McpTool diff --git a/google/genai/_transformers.py b/google/genai/_transformers.py index 0e1a9c41c..de5f1aee8 100644 --- a/google/genai/_transformers.py +++ b/google/genai/_transformers.py @@ -22,11 +22,10 @@ import io import logging import re -import sys import time import types as builtin_types import typing -from typing import Any, GenericAlias, List, Optional, Sequence, Union # type: ignore[attr-defined] +from typing import Any, GenericAlias, List, Optional, Sequence, TypeGuard, Union # type: ignore[attr-defined] from ._mcp_utils import mcp_to_gemini_tool from ._common import get_value_by_path as getv from ._common import is_duck_type_of @@ -42,14 +41,7 @@ logger = logging.getLogger('google_genai._transformers') -if sys.version_info >= (3, 10): - VersionedUnionType = builtin_types.UnionType - _UNION_TYPES = (typing.Union, builtin_types.UnionType) - from typing import TypeGuard -else: - VersionedUnionType = typing._UnionGenericAlias # type: ignore[attr-defined] - _UNION_TYPES = (typing.Union,) - from typing_extensions import TypeGuard +_UNION_TYPES = (typing.Union, builtin_types.UnionType) if typing.TYPE_CHECKING: from mcp import ClientSession as McpClientSession @@ -881,7 +873,6 @@ def t_schema( elif ( isinstance(origin, GenericAlias) or isinstance(origin, type) - or isinstance(origin, VersionedUnionType) or typing.get_origin(origin) in _UNION_TYPES ): diff --git a/google/genai/types.py b/google/genai/types.py index 4bc37eb98..93a1df902 100644 --- a/google/genai/types.py +++ b/google/genai/types.py @@ -25,7 +25,7 @@ import sys import types as builtin_types import typing -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union, _UnionGenericAlias # type: ignore +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union import pydantic from pydantic import ConfigDict, Field, PrivateAttr, model_validator from typing_extensions import Self, TypedDict @@ -38,14 +38,9 @@ ) -if sys.version_info >= (3, 10): - # Supports both Union[t1, t2] and t1 | t2 - VersionedUnionType = Union[builtin_types.UnionType, _UnionGenericAlias] - _UNION_TYPES = (typing.Union, builtin_types.UnionType) -else: - # Supports only Union[t1, t2] - VersionedUnionType = _UnionGenericAlias - _UNION_TYPES = (typing.Union,) +# Supports both Union[t1, t2] and t1 | t2. +VersionedUnionType = builtin_types.UnionType +_UNION_TYPES = (typing.Union, builtin_types.UnionType) _is_pillow_image_imported = False if typing.TYPE_CHECKING: @@ -5958,6 +5953,18 @@ class GenerateContentConfig(_common.BaseModel): @pydantic.field_validator('response_schema', mode='before') @classmethod def _convert_literal_to_enum(cls, value: Any) -> Union[Any, EnumMeta]: + # Normalize typing.Union[...] to PEP 604 unions on Python <= 3.13, where + # typing.Union[...] is not an instance of types.UnionType. + if ( + typing.get_origin(value) is typing.Union + and not isinstance(value, builtin_types.UnionType) + ): + union_args = typing.get_args(value) + if union_args: + normalized_union = union_args[0] + for union_arg in union_args[1:]: + normalized_union = normalized_union | union_arg + return normalized_union if typing.get_origin(value) is typing.Literal: enum_vals = typing.get_args(value) if not all(isinstance(arg, str) for arg in enum_vals):