Skip to content

Commit

Permalink
pg.typing to support typing.Generic.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 550112605
  • Loading branch information
daiyip authored and pyglove authors committed Jul 22, 2023
1 parent 9e2a772 commit 4eabef3
Show file tree
Hide file tree
Showing 9 changed files with 380 additions and 15 deletions.
8 changes: 8 additions & 0 deletions pyglove/core/typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,14 @@ class Foo(pg.Object):
from pyglove.core.typing.type_conversion import get_first_applicable_converter
from pyglove.core.typing.type_conversion import get_json_value_converter

# Generic helpers.
from pyglove.core.typing.generic import is_subclass
from pyglove.core.typing.generic import is_instance
from pyglove.core.typing.generic import get_type
from pyglove.core.typing.generic import get_type_args
from pyglove.core.typing.generic import is_generic
from pyglove.core.typing.generic import has_generic_bases

# Annotation conversion.
import pyglove.core.typing.annotation_conversion

Expand Down
14 changes: 12 additions & 2 deletions pyglove/core/typing/annotation_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from pyglove.core import object_utils
from pyglove.core.typing import class_schema
from pyglove.core.typing import generic
from pyglove.core.typing import key_specs as ks
from pyglove.core.typing import value_specs as vs

Expand Down Expand Up @@ -149,6 +150,12 @@ def _value_spec_from_type_annotation(
return_spec = _value_spec_from_type_annotation(
args[1], accept_value_as_annotation=False)
return vs.Callable(arg_specs, returns=return_spec)
# Handling type
elif origin is type or (annotation in (typing.Type, type)):
if not args:
return vs.Type(typing.Any)
assert len(args) == 1, (annotation, args)
return vs.Type(args[0])
# Handling union.
elif origin is typing.Union or (_UnionType and origin is _UnionType):
optional = _NoneType in args
Expand All @@ -162,8 +169,11 @@ def _value_spec_from_type_annotation(
spec = spec.noneable()
return spec
# Handling class.
elif (inspect.isclass(annotation)
or (isinstance(annotation, str) and not accept_value_as_annotation)):
elif (
inspect.isclass(annotation)
or generic.is_generic(annotation)
or (isinstance(annotation, str) and not accept_value_as_annotation)
):
return vs.Object(annotation)

if accept_value_as_annotation:
Expand Down
24 changes: 24 additions & 0 deletions pyglove/core/typing/annotation_conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,30 @@ class Foo:
self.assertEqual(
ValueSpec.from_annotation(Foo, False), vs.Any(annotation=Foo))

def test_generic_class(self):
X = typing.TypeVar('X')
Y = typing.TypeVar('Y')

class Foo(typing.Generic[X, Y]):
pass

self.assertEqual(
ValueSpec.from_annotation(Foo[int, str], True), vs.Object(Foo[int, str])
)

def test_type(self):
class Foo:
pass

self.assertEqual(
ValueSpec.from_annotation(typing.Type[Foo], True), vs.Type(Foo)
)
self.assertEqual(ValueSpec.from_annotation(type[Foo], True), vs.Type(Foo))
self.assertEqual(
ValueSpec.from_annotation(typing.Type, True), vs.Type(typing.Any)
)
self.assertEqual(ValueSpec.from_annotation(type, True), vs.Type(typing.Any))

def test_optional(self):
self.assertEqual(
ValueSpec.from_annotation(typing.Optional[int], True),
Expand Down
115 changes: 115 additions & 0 deletions pyglove/core/typing/generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright 2023 The PyGlove Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility module for inspecting generics types."""

import typing
from typing import Any, Optional, Tuple, Type, Union


def is_instance(value: Any, target: Union[Type[Any], Tuple[Type[Any]]]) -> bool:
"""An isinstance extension that supports Any and generic types."""
return is_subclass(type(value), target)


def is_subclass(
src: Type[Any], target: Union[Type[Any], Tuple[Type[Any]]]
) -> bool:
"""An issubclass extension that supports Any and generic types."""

def _is_subclass(src: Type[Any], target: Type[Any]) -> bool:
if target is Any:
return True
elif src is Any:
return False

orig_target = typing.get_origin(target)
orig_src = typing.get_origin(src)

if orig_target is None:
if orig_src is None:
# Both soure and target is not a generic class.
return issubclass(src, target)
# Source class is generic but not the target class.
return issubclass(orig_src, target)
elif orig_src is None:
# Target class is generic but not the source class.
if not issubclass(src, orig_target):
return False
elif not issubclass(orig_src, orig_target):
# Both are generic, but the source is not a subclass of the target.
return False

# Check type args.
t_args = get_type_args(target)
if not t_args:
return True

s_args = get_type_args(src, base=orig_target)
if s_args:
assert len(s_args) == len(t_args), (s_args, t_args)
for s_arg, t_arg in zip(s_args, t_args):
if not _is_subclass(s_arg, t_arg):
return False
return True
else:
# A class could inherit multiple generic types. However it does not
# provide the type arguments for the target generic base. E.g.
#
# class A(Generic[X, Y]):
# class B(A, Generic[X, Y]) :
# B[int, int] is not a subclass of A[int, int].
return False

if isinstance(target, tuple):
return any(_is_subclass(src, x) for x in target)
return _is_subclass(src, target)


def is_generic(maybe_generic: Type[Any]) -> bool:
"""Returns True if a type is a generic class."""
return typing.get_origin(maybe_generic) is not None


def has_generic_bases(maybe_generic: Type[Any]) -> bool:
"""Returns True if a type is a generic subclass."""
return bool(getattr(maybe_generic, '__orig_bases__', None))


def get_type(maybe_type: Any) -> Type[Any]:
"""Gets the type of a maybe generic type."""
if isinstance(maybe_type, type):
return maybe_type
origin = typing.get_origin(maybe_type)
if origin is not None:
return origin
else:
raise TypeError(f'{maybe_type!r} is not a type.')


def get_type_args(
maybe_generic: Type[Any], base: Optional[Type[Any]] = None
) -> Tuple[Type[Any], ...]:
"""Gets generic type args conditioned on an optional base class."""
if base is None:
return typing.get_args(maybe_generic)
else:
orig_cls = typing.get_origin(maybe_generic)
if orig_cls is not None:
orig_bases = (maybe_generic,)
else:
orig_bases = getattr(maybe_generic, '__orig_bases__', ())
for orig_base in orig_bases:
if get_type(orig_base) is base:
return typing.get_args(orig_base)
return ()
146 changes: 146 additions & 0 deletions pyglove/core/typing/generic_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright 2023 The PyGlove Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for generic type utility."""

from typing import Any, Generic, TypeVar
import unittest

from pyglove.core.typing import generic

XType = TypeVar('XType')
YType = TypeVar('YType')


class Str(str):
pass


class A(Generic[XType, YType]):
pass


class B(Generic[XType]):
pass


class B1(B, Generic[YType]):
pass


class B2(B1[int]):
pass


class C(A[str, int], B[Str]):
pass


class D(C):
pass


class GenericTest(unittest.TestCase):

def test_issubclass(self):
# Any.
self.assertTrue(generic.is_subclass(int, Any))
self.assertFalse(generic.is_subclass(Any, int))

# Non-generic vs. non-generic.
self.assertTrue(generic.is_subclass(int, object))
self.assertTrue(generic.is_subclass(int, int))
self.assertTrue(generic.is_subclass(Str, str))
self.assertFalse(generic.is_subclass(str, Str))
self.assertTrue(generic.is_subclass(B1, B))
self.assertFalse(generic.is_subclass(B, B1))
self.assertTrue(generic.is_subclass(B1, Generic))
self.assertTrue(generic.is_subclass(C, C))
self.assertTrue(generic.is_subclass(D, C))
self.assertFalse(generic.is_subclass(C, D))

# Non-generic vs. generic.
self.assertTrue(generic.is_subclass(C, A[str, int]))
self.assertFalse(generic.is_subclass(C, A[int, int]))
self.assertFalse(generic.is_subclass(Str, A[int, int]))

self.assertTrue(generic.is_subclass(C, B[Str]))
self.assertTrue(generic.is_subclass(C, B[str]))
self.assertFalse(generic.is_subclass(C, B[int]))
self.assertTrue(generic.is_subclass(D, B[Str]))

# B1 is a subclass of B without type args.
self.assertFalse(generic.is_subclass(B1, B[Any]))

# Non-generic vs. generic type.
self.assertTrue(generic.is_subclass(C, A[str, int]))
self.assertTrue(generic.is_subclass(C, A[Any, Any]))
self.assertTrue(generic.is_subclass(C, B[Str]))
self.assertTrue(generic.is_subclass(C, B[str]))

# Generic vs. non-generic type:
self.assertTrue(generic.is_subclass(B[Str], B))
self.assertTrue(generic.is_subclass(B1, B))
self.assertTrue(generic.is_subclass(B1[Any], B))

# Generic type vs. generic type.
self.assertTrue(generic.is_subclass(B[Str], B[str]))
self.assertTrue(generic.is_subclass(B[str], B[str]))
self.assertFalse(generic.is_subclass(B1[Str], B[str]))
self.assertFalse(generic.is_subclass(B1[Str], A[str, int]))

# Test tuple cases.
self.assertTrue(generic.is_subclass(int, (str, int)))
self.assertTrue(generic.is_subclass(C, (int, A[str, int])))

def test_isinstance(self):
self.assertTrue(generic.is_instance('abc', str))
self.assertTrue(generic.is_instance('abc', Any))
self.assertTrue(generic.is_instance('abc', (int, str)))

self.assertTrue(generic.is_instance(D(), Any))
self.assertTrue(generic.is_instance(D(), A[str, int]))
self.assertTrue(generic.is_instance(D(), B[str]))

def test_is_generic(self):
self.assertFalse(generic.is_generic(str))
self.assertFalse(generic.is_generic(Any))
self.assertFalse(generic.is_generic(A))
self.assertTrue(generic.is_generic(A[str, int]))

def test_has_generic_bases(self):
self.assertFalse(generic.has_generic_bases(str))
self.assertFalse(generic.has_generic_bases(Any))
self.assertTrue(generic.has_generic_bases(A))
self.assertTrue(generic.has_generic_bases(C))

def test_get_type(self):
self.assertIs(generic.get_type(str), str)
self.assertIs(generic.get_type(A), A)
self.assertIs(generic.get_type(A[str, int]), A)
with self.assertRaisesRegex(TypeError, '.* is not a type.'):
generic.get_type(Any)

def test_get_type_args(self):
self.assertEqual(generic.get_type_args(str), ())
self.assertEqual(generic.get_type_args(A), ())
self.assertEqual(generic.get_type_args(A[str, int]), (str, int))
self.assertEqual(generic.get_type_args(B1, A), ())
self.assertEqual(generic.get_type_args(B1[str], A), ())
self.assertEqual(generic.get_type_args(C), ())
self.assertEqual(generic.get_type_args(C, A), (str, int))
self.assertEqual(generic.get_type_args(C, B), (Str,))


if __name__ == '__main__':
unittest.main()
20 changes: 13 additions & 7 deletions pyglove/core/typing/type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Any, Callable, Optional, Tuple, Type, Union

from pyglove.core import object_utils
from pyglove.core.typing import generic


class _TypeConverterRegistry:
Expand All @@ -35,8 +36,12 @@ def register(
dest: Union[Type[Any], Tuple[Type[Any], ...]],
convert_fn: Callable[[Any], Any]) -> None: # pyformat: disable pylint: disable=line-too-long
"""Register a converter from src type to dest type."""
if (not isinstance(src, (tuple, type)) or
not isinstance(dest, (tuple, type))):
if (
not isinstance(src, (tuple, type))
and not generic.is_generic(src)
or not isinstance(dest, (tuple, type))
and not generic.is_generic(dest)
):
raise TypeError('Argument \'src\' and \'dest\' must be a type or '
'tuple of types.')
if isinstance(dest, tuple):
Expand All @@ -54,17 +59,19 @@ def get_converter(
# NOTE(daiyip): We do reverse lookup since usually subclass converter
# is register after base class.
for src_type, dest_type, converter, _ in reversed(self._converter_list):
dest_type = dest_type if isinstance(dest_type, tuple) else (dest_type,)
if issubclass(src, src_type) and dest in dest_type:
return converter
if generic.is_subclass(src, src_type):
dest_types = dest_type if isinstance(dest_type, tuple) else (dest_type,)
for dest_type in dest_types:
if generic.is_subclass(dest_type, dest):
return converter
return None

def get_json_value_converter(
self, src: Type[Any]) -> Optional[Callable[[Any], Any]]:
"""Get converter from source type to a JSON simple type."""
for src_type, _, converter, json_value_convertible in reversed(
self._converter_list):
if issubclass(src, src_type) and json_value_convertible:
if generic.is_subclass(src, src_type) and json_value_convertible:
return converter
return None

Expand Down Expand Up @@ -139,4 +146,3 @@ def _register_builtin_converters():

_register_builtin_converters()
object_utils.JSONConvertible.TYPE_CONVERTER = get_json_value_converter

Loading

0 comments on commit 4eabef3

Please sign in to comment.