Skip to content

Commit

Permalink
wired in more cases of ype_safe_cache.get_origin(tp)
Browse files Browse the repository at this point in the history
  • Loading branch information
DinisCruz committed Jan 21, 2025
1 parent 96e1285 commit 7038b43
Show file tree
Hide file tree
Showing 12 changed files with 75 additions and 88 deletions.
8 changes: 5 additions & 3 deletions osbot_utils/type_safe/Type_Safe__Base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import get_origin, get_args, Union, Optional, Any, ForwardRef
from typing import get_args, Union, Optional, Any, ForwardRef

from osbot_utils.type_safe.shared.Type_Safe__Cache import type_safe_cache

EXACT_TYPE_MATCH = (int, float, str, bytes, bool, complex)

Expand All @@ -8,7 +10,7 @@ def is_instance_of_type(self, item, expected_type):
return True
if isinstance(expected_type, ForwardRef): # todo: add support for ForwardRef
return True
origin = get_origin(expected_type)
origin = type_safe_cache.get_origin(expected_type)
args = get_args(expected_type)
if origin is None:
if expected_type in EXACT_TYPE_MATCH:
Expand Down Expand Up @@ -90,7 +92,7 @@ def is_instance_of_type(self, item, expected_type):

# todo: see if we should/can move this to the Objects.py file
def type_str(tp):
origin = get_origin(tp)
origin = type_safe_cache.get_origin(tp)
if origin is None:
if hasattr(tp, '__name__'):
return tp.__name__
Expand Down
2 changes: 1 addition & 1 deletion osbot_utils/type_safe/Type_Safe__Dict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from osbot_utils.type_safe.Type_Safe__Base import type_str, Type_Safe__Base
from osbot_utils.type_safe.Type_Safe__Base import Type_Safe__Base

class Type_Safe__Dict(Type_Safe__Base, dict):
def __init__(self, expected_key_type, expected_value_type, *args, **kwargs):
Expand Down
3 changes: 1 addition & 2 deletions osbot_utils/type_safe/shared/Type_Safe__Annotations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import get_origin
from osbot_utils.type_safe.shared.Type_Safe__Cache import type_safe_cache


Expand All @@ -19,7 +18,7 @@ def obj_is_attribute_annotation_of_type(self, target, attr_name, expected_type):
return True
if expected_type is type(attribute_annotation):
return True
if expected_type is get_origin(attribute_annotation): # todo: use get_origin cache # handle genericAlias
if expected_type is type_safe_cache.get_origin(attribute_annotation): # handle genericAlias
return True
return False

Expand Down
7 changes: 5 additions & 2 deletions osbot_utils/type_safe/shared/Type_Safe__Cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,11 @@ def get_class_mro(self, cls):

def get_origin(self, var_type): # Cache expensive get_origin calls
if self.skip_cache or var_type not in self._type__get_origin_cache:
origin = type_safe_not_cached.get_origin(var_type)
self._type__get_origin_cache[var_type] = origin
origin = type_safe_not_cached.get_origin(var_type)
try: # this is needed for the edge case when we can't create a key from the var_type in WeakKeyDictionary (see test test__regression__type_safe_is_not_enforced_on_dict_and_Dict for an example)
self._type__get_origin_cache[var_type] = origin
except TypeError:
pass
self.cache__miss__type__get_origin += 1
else:
origin = self._type__get_origin_cache[var_type]
Expand Down
6 changes: 3 additions & 3 deletions osbot_utils/type_safe/shared/Type_Safe__Convert.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from osbot_utils.type_safe.steps.Type_Safe__Step__Default_Value import get_origin
from osbot_utils.utils.Objects import base_classes_names
from osbot_utils.type_safe.shared.Type_Safe__Cache import type_safe_cache
from osbot_utils.utils.Objects import base_classes_names


class Type_Safe__Convert:
Expand Down Expand Up @@ -29,7 +29,7 @@ def convert_to_value_from_obj_annotation(self, target, attr_name, value):
if hasattr(obj_annotations,'get'):
attribute_annotation = obj_annotations.get(attr_name)
if attribute_annotation:
origin = get_origin(attribute_annotation) # Add handling for Type[T] annotations
origin = type_safe_cache.get_origin(attribute_annotation) # Add handling for Type[T] annotations
if origin is type and isinstance(value, str):
try: # Convert string path to actual type
if len(value.rsplit('.', 1)) > 1:
Expand Down
3 changes: 3 additions & 0 deletions osbot_utils/type_safe/steps/Type_Safe__Step__Class_Kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def get_cls_kwargs(self, cls : Type )\

if self.is_kwargs_cacheable(cls, kwargs): # if we can cache it (i.e. only IMMUTABLE_TYPES vars)
type_safe_cache.set_cache__cls_kwargs(cls, kwargs) # cache it
# else:
# pass # todo:: see how we can cache more the cases when the data is clean (i.e. default values)
return kwargs

def is_kwargs_cacheable(self, cls, kwargs: Dict[str, Any]) -> bool:
Expand All @@ -42,6 +44,7 @@ def is_kwargs_cacheable(self, cls, kwargs: Dict[str, Any]) -> bool:

if match: # check for special cases that we can't cache (like Random_Guid)
if Random_Guid in list(dict(annotations).values()): # todo: need to add the other special cases (like Timestamp_Now)

return False
return match

Expand Down
28 changes: 11 additions & 17 deletions osbot_utils/type_safe/steps/Type_Safe__Step__Default_Value.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,30 @@
import inspect
import typing

from osbot_utils.utils.Objects import default_value
from osbot_utils.type_safe.Type_Safe__List import Type_Safe__List
from osbot_utils.type_safe.Type_Safe__Dict import Type_Safe__Dict
from osbot_utils.type_safe.shared.Type_Safe__Cache import type_safe_cache
from osbot_utils.utils.Objects import default_value
from osbot_utils.type_safe.Type_Safe__List import Type_Safe__List
from osbot_utils.type_safe.Type_Safe__Dict import Type_Safe__Dict


# Backport implementations of get_origin and get_args for Python 3.7 # todo: refactor into separate class (focused on past python version compatibility)
# Backport implementations of get_args for Python 3.7 # todo: refactor into separate class (focused on past python version compatibility)
if sys.version_info < (3, 8): # pragma: no cover
def get_origin(tp):
if isinstance(tp, typing._GenericAlias):
return tp.__origin__
elif tp is typing.Generic:
return typing.Generic
else:
return None

def get_args(tp):
if isinstance(tp, typing._GenericAlias):
return tp.__args__
else:
return ()
else:
from typing import get_origin, get_args, ForwardRef, Any
from osbot_utils.helpers.python_compatibility.python_3_8 import Annotated
from typing import get_args, ForwardRef


class Type_Safe__Step__Default_Value:

def default_value(self, _cls, var_type):

if get_origin(var_type) is type: # Special handling for Type[T] # todo: reuse the get_origin value
origin = type_safe_cache.get_origin(var_type) # todo: refactor this to use the get_origin method
if origin is type: # Special handling for Type[T] # todo: reuse the get_origin value
type_args = get_args(var_type)
if type_args:
if isinstance(type_args[0], ForwardRef):
Expand All @@ -45,13 +39,13 @@ def default_value(self, _cls, var_type):
if var_type is typing.Set: # todo: refactor the dict, set and list logic, since they are 90% the same
return set()

if get_origin(var_type) is set:
if origin is set:
return set() # todo: add Type_Safe__Set

if var_type is typing.Dict:
return {}

if get_origin(var_type) is dict: # e.g. Dict[key_type, value_type]
if origin is dict: # e.g. Dict[key_type, value_type]
key_type, value_type = get_args(var_type)
if isinstance(key_type, ForwardRef): # Handle forward references on key_type ---
forward_name = key_type.__forward_arg__
Expand All @@ -66,7 +60,7 @@ def default_value(self, _cls, var_type):
if var_type is typing.List:
return [] # handle case when List was used with no type information provided

if get_origin(var_type) is list: # if we have list defined as list[type]
if origin is list: # if we have list defined as list[type]
item_type = get_args(var_type)[0] # get the type that was defined
if isinstance(item_type, ForwardRef): # handle the case when the type is a forward reference
forward_name = item_type.__forward_arg__
Expand Down
10 changes: 6 additions & 4 deletions osbot_utils/type_safe/steps/Type_Safe__Step__Set_Attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def resolve_value__int_str(self, _self, name, value):
return type_safe_convert.convert_to_value_from_obj_annotation(_self, name, value)

def resolve_value__from_origin(self, value):
#origin = type_safe_cache.get_origin(value) # todo: figure out why this is the only place that the type_safe_cache.get_origin doesn't work (due to WeakKeyDictionary key error on value)
origin = get_origin(value)

if origin is not None:
value = origin
return value
Expand All @@ -42,16 +44,16 @@ def handle_get_class__annotated(self, annotation, name, value):
attribute.validate(value=value, field_name=name, target_type=target_type)

def handle_get_class__dict(self, _self, name, value):
# todo: refactor how this actually works since it is not good to having to use the deserialize_dict__using_key_value_annotations from here
from osbot_utils.type_safe.steps.Type_Safe__Step__From_Json import Type_Safe__Step__From_Json # here because of circular dependencies
value = Type_Safe__Step__From_Json().deserialize_dict__using_key_value_annotations(_self, name, value)
if value: # todo: see side effects of doing this here (since going into deserialize_dict__using_key_value_annotations has performance hit)
from osbot_utils.type_safe.steps.Type_Safe__Step__From_Json import Type_Safe__Step__From_Json # here because of circular dependencies
value = Type_Safe__Step__From_Json().deserialize_dict__using_key_value_annotations(_self, name, value) # todo: refactor how this actually works since it is not good to having to use the deserialize_dict__using_key_value_annotations from here
return value

def handle_get_class(self, _self, annotations, name, value):
if hasattr(annotations, 'get'):
annotation = annotations.get(name)
if annotation:
annotation_origin = get_origin(annotation)
annotation_origin = type_safe_cache.get_origin(annotation)
if annotation_origin is Annotated:
self.handle_get_class__annotated(annotation, name, value)
elif annotation_origin is dict:
Expand Down
21 changes: 0 additions & 21 deletions osbot_utils/utils/Objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,6 @@
class __(SimpleNamespace):
pass

# Backport implementations of get_origin and get_args for Python 3.7
if sys.version_info < (3, 8):
def get_origin(tp):
import typing
if isinstance(tp, typing._GenericAlias):
return tp.__origin__
elif tp is typing.Generic:
return typing.Generic
else:
return None

def get_args(tp):
import typing
if isinstance(tp, typing._GenericAlias):
return tp.__args__
else:
return ()
else:
from typing import get_origin, get_args, List, Tuple, Dict, Type, _GenericAlias, ForwardRef


def base_classes(cls):
if type(cls) is type:
target = cls
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class An_Class_6(Type_Safe):
_.measure(Random_Guid).print().assert_time(self.time_3_kns , self.time_5_kns, self.time_6_kns , self.time_7_kns , self.time_8_kns )
_.measure(An_Class_1 ).print().assert_time(self.time_100_ns , self.time_200_ns )
_.measure(An_Class_2 ).print().assert_time(self.time_600_ns , self.time_700_ns , self.time_800_ns, self.time_1_kns , self.time_2_kns , self.time_3_kns , self.time_4_kns , self.time_5_kns , self.time_6_kns, self.time_7_kns )
_.measure(An_Class_3 ).print().assert_time(self.time_2_kns , self.time_3_kns , self.time_4_kns , self.time_5_kns , self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
_.measure(An_Class_3 ).print().assert_time(self.time_1_kns , self.time_2_kns , self.time_3_kns , self.time_4_kns , self.time_5_kns , self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
_.measure(An_Class_4 ).print().assert_time(self.time_2_kns , self.time_3_kns , self.time_4_kns , self.time_5_kns , self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
_.measure(An_Class_5 ).print().assert_time(self.time_1_kns , self.time_2_kns , self.time_3_kns , self.time_4_kns , self.time_5_kns , self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
_.measure(An_Class_6 ).print().assert_time(self.time_2_kns , self.time_3_kns , self.time_4_kns , self.time_5_kns , self.time_7_kns , self.time_8_kns ,self.time_9_kns , self.time_10_kns, self.time_20_kns )
Expand Down
39 changes: 21 additions & 18 deletions tests/unit/type_safe/_performance/test__perf__Type_Safe__ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@
from unittest import TestCase
from typing import Optional, List, Dict, Union, Any
from enum import Enum, auto

from osbot_utils.testing.performance.Performance_Measure__Session import Performance_Measure__Session
from osbot_utils.type_safe.Type_Safe import Type_Safe
from osbot_utils.type_safe.shared.Type_Safe__Cache import type_safe_cache


class test__perf__Type_Safe__ctor(TestCase):

@classmethod
def setUpClass(cls): # Set up timing thresholds
import pytest
pytest.skip("skipping until refactoring of Type_Safe is complete")
# import pytest
# pytest.skip("skipping until refactoring of Type_Safe is complete")
cls.time_200_ns = 200
cls.time_300_ns = 300
cls.time_700_ns = 700
Expand Down Expand Up @@ -65,22 +65,25 @@ class SingleDefault(Type_Safe): # Test w

def test_complex_types(self): # Test complex type variations
class ComplexTypes(Type_Safe): # Multiple complex types
optional_str : Optional[str]
str_list : List[str]
int_dict : Dict[str, int]
union_field : Union[str, int]

class NestedType(Type_Safe): # Basic nested type
value: str

class WithNested(Type_Safe): # Complex nesting
nested : NestedType
items : List[NestedType]

optional_str : Optional[str ]
str_list : List [str ]
int_dict : Dict [str, int]
union_field : Union [str, int]

# class NestedType(Type_Safe): # Basic nested type
# value: str
#
# class WithNested(Type_Safe): # Complex nesting
# nested : NestedType
# items : List[NestedType]

print()
with self.session as session:
session.measure(ComplexTypes ).assert_time(self.time_20_kns, self.time_30_kns, self.time_40_kns)
session.measure(NestedType ).assert_time(self.time_2_kns , self.time_3_kns , self.time_20_kns)
session.measure(WithNested ).assert_time(self.time_20_kns, self.time_40_kns)
session.measure(ComplexTypes ).print().assert_time(self.time_20_kns, self.time_30_kns, self.time_40_kns)
# session.measure(NestedType ).print().assert_time(self.time_2_kns , self.time_3_kns , self.time_20_kns)
# session.measure(WithNested ).print().assert_time(self.time_20_kns, self.time_40_kns)

type_safe_cache.print_cache_hits()

def test_inheritance_depth(self): # Test inheritance impact
class Base(Type_Safe): # Base class
Expand Down
34 changes: 18 additions & 16 deletions tests/unit/type_safe/_performance/test__perf__Type_Safe__tracing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Optional, List, Dict, Union
from unittest import TestCase
from osbot_utils.utils.Objects import __

from osbot_utils.helpers.trace.Trace_Call import trace_calls
from osbot_utils.type_safe.Type_Safe import Type_Safe
from osbot_utils.type_safe.shared.Type_Safe__Cache import type_safe_cache
Expand Down Expand Up @@ -32,28 +34,28 @@ class test__perf__Type_Safe__tracing(TestCase):
# An_Class()


@trace_calls(include = ['osbot' ],
ignore = ['typing' ],
show_internals = False ,
show_lines = False ,
show_types = False ,
show_class = True ,
show_duration = True ,
duration_padding = 140 ,
#duration_bigger_than = 0.001
)
# @trace_calls(include = ['osbot' ],
# ignore = ['typing' ],
# show_internals = False ,
# show_lines = False ,
# show_types = False ,
# show_class = True ,
# show_duration = True ,
# duration_padding = 140 ,
# #duration_bigger_than = 0.001
# )
def test_complex_types(self):

class ComplexTypes(Type_Safe): # Multiple complex types
#an_int : int
#optional_str : Optional [str]
#str_list : List [str]
an_int : int
optional_str : Optional [str]
str_list : List [str]
int_dict : Dict [str, int]
#union_field : Union [str, int]
union_field : Union [str, int]

ComplexTypes()
assert ComplexTypes().obj() == __(an_int=0, optional_str=None, str_list=[], int_dict=__(), union_field=None)

type_safe_cache.print_cache_hits()
#type_safe_cache.print_cache_hits()



Expand Down

0 comments on commit 7038b43

Please sign in to comment.