Skip to content

Commit

Permalink
in Type_Safe__Cache added cache support for get_origin and ability to…
Browse files Browse the repository at this point in the history
… capture the cache hits

fixed nasty bug with IMMUTABLE_TYPES where we had tuple and frozenset in there (which are not really immutable)
started to refactor get_cls_kwargs into separate methods (so that we can cache the results better)
  • Loading branch information
DinisCruz committed Jan 20, 2025
1 parent c6cf511 commit 27ed391
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 41 deletions.
46 changes: 40 additions & 6 deletions osbot_utils/type_safe/shared/Type_Safe__Cache.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,71 @@
import inspect
from typing import get_origin
from weakref import WeakKeyDictionary


class Type_Safe__Cache:

_annotations_cache : WeakKeyDictionary
_get_origin_cache : WeakKeyDictionary
_mro_cache : WeakKeyDictionary
_valid_vars_cache : WeakKeyDictionary

cache_hit__annotations : int = 0
cache_hit__get_origin : int = 0
cache_hit__mro : int = 0
cache_hit__valid_vars : int = 0
skip_cache : bool = False


# Caching system for Type_Safe methods
def __init__(self):
self._annotations_cache = WeakKeyDictionary() # Cache for class annotations
self._annotations_cache = WeakKeyDictionary() # Cache for class annotations
self._get_origin_cache = WeakKeyDictionary() # Cache for get_origin results
self._mro_cache = WeakKeyDictionary() # Cache for Method Resolution Order
self._valid_vars_cache = WeakKeyDictionary()

def get_class_annotations(self, cls):
if cls not in self._annotations_cache:
self._annotations_cache[cls] = cls.__annotations__.items()
if self.skip_cache or cls not in self._annotations_cache:
self._annotations_cache[cls] = cls.__annotations__.items()
else:
self.cache_hit__annotations += 1
return self._annotations_cache[cls]

def get_class_mro(self, cls):
if cls not in self._mro_cache:
if self.skip_cache or cls not in self._mro_cache:
self._mro_cache[cls] = inspect.getmro(cls)
else:
self.cache_hit__mro += 1
return self._mro_cache[cls]


def get_origin(self, var_type): # Cache expensive get_origin calls
if self.skip_cache or var_type not in self._get_origin_cache:
self._get_origin_cache[var_type] = get_origin(var_type)
else:
self.cache_hit__get_origin += 1
return self._get_origin_cache[var_type]

# todo: see if we have cache misses and invalid hits based on the validator (we might need more validator specific methods)
def get_valid_class_variables(self, cls, validator): # Returns a dictionary of valid class variables that should be processed. Filters out internal variables, methods, and other non-data attributes.
if cls not in self._valid_vars_cache:
def get_valid_class_variables(self, cls, validator):
if self.skip_cache or cls not in self._valid_vars_cache:
valid_variables = {}
for name, value in vars(cls).items():
if not validator(name, value):
valid_variables[name] = value
self._valid_vars_cache[cls] = valid_variables
else:
self.cache_hit__valid_vars += 1
return self._valid_vars_cache[cls]


def print_cache_hits(self):
print()
print("###### Type_Safe_Cache Hits ########")
print()
print(f" annotations : {self.cache_hit__annotations}")
print(f" get_origin : {self.cache_hit__get_origin }")
print(f" mro : {self.cache_hit__mro }")
print(f" valid_vars : {self.cache_hit__valid_vars }")

type_safe_cache = Type_Safe__Cache()
3 changes: 2 additions & 1 deletion osbot_utils/type_safe/shared/Type_Safe__Shared__Variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
from enum import EnumMeta


IMMUTABLE_TYPES = (bool, int, float, complex, str, tuple, frozenset, bytes, types.NoneType, EnumMeta, type)
#IMMUTABLE_TYPES = (bool, int, float, complex, str, tuple, frozenset, bytes, types.NoneType, EnumMeta, type)
IMMUTABLE_TYPES = (bool, int, float, complex, str, bytes, types.NoneType, EnumMeta, type)
26 changes: 14 additions & 12 deletions osbot_utils/type_safe/shared/Type_Safe__Validation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import types
from enum import EnumMeta
from typing import Any
from typing import Any, Annotated

from osbot_utils.type_safe.shared.Type_Safe__Cache import type_safe_cache
from osbot_utils.type_safe.shared.Type_Safe__Shared__Variables import IMMUTABLE_TYPES
from osbot_utils.utils.Objects import obj_is_type_union_compatible
from osbot_utils.type_safe.shared.Type_Safe__Raise_Exception import type_safe_raise_exception
Expand All @@ -9,25 +11,25 @@
class Type_Safe__Validation:

# todo: add cache support to this method
def should_skip_type_check(self, var_type): # Determine if type checking should be skipped
from typing import get_origin, Annotated
return (get_origin(var_type) is Annotated or
get_origin(var_type) is type)
def should_skip_type_check(self, var_type): # Determine if type checking should be skipped
origin = type_safe_cache.get_origin(var_type) # Use cached get_origin
return (origin is Annotated or
origin is type )

def should_skip_var(self, var_name: str, var_value: Any) -> bool: # Determines if variable should be skipped during MRO processing
if var_name.startswith('__'): # skip internal variables
def should_skip_var(self, var_name: str, var_value: Any) -> bool: # Determines if variable should be skipped during MRO processing
if var_name.startswith('__'): # skip internal variables
return True
if isinstance(var_value, types.FunctionType): # skip instance functions
if isinstance(var_value, types.FunctionType): # skip instance functions
return True
if isinstance(var_value, classmethod): # skip class methods
if isinstance(var_value, classmethod): # skip class methods
return True
if isinstance(var_value, property): # skip property descriptors
if isinstance(var_value, property): # skip property descriptors
return True
return False

def validate_type_immutability(self, var_name: str, var_type: Any) -> None: # Validates that type is immutable or in supported format
if var_type not in IMMUTABLE_TYPES and var_name.startswith('__') is False: # if var_type is not one of the IMMUTABLE_TYPES or is an __ internal
if obj_is_type_union_compatible(var_type, IMMUTABLE_TYPES) is False: # if var_type is not something like Optional[Union[int, str]]
if var_type not in IMMUTABLE_TYPES and var_name.startswith('__') is False: # if var_type is not one of the IMMUTABLE_TYPES or is an __ internal
if obj_is_type_union_compatible(var_type, IMMUTABLE_TYPES) is False: # if var_type is not something like Optional[Union[int, str]]
if var_type not in IMMUTABLE_TYPES or type(var_type) not in IMMUTABLE_TYPES:
if not isinstance(var_type, EnumMeta):
type_safe_raise_exception.immutable_type_error(var_name, var_type)
Expand Down
29 changes: 21 additions & 8 deletions osbot_utils/type_safe/steps/Type_Safe__Step__Class_Kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,33 @@ class Type_Safe__Step__Class_Kwargs:
def __init__(self):
self.type_safe_cache = type_safe_cache # Initialize with singleton cache

def get_cls_kwargs(self, cls : Type , # Get class keyword arguments
include_base_classes : bool = True)\
-> Dict[str, Any]:
def get_cls_kwargs(self, cls : Type , # Main entry point for getting class kwargs
include_base_classes : bool = True )\
-> Dict[str, Any]: # Returns dict of class kwargs

if not hasattr(cls, '__mro__'): # Handle non-class inputs
return {}

base_classes = type_safe_cache.get_class_mro(cls) # Get class hierarchy
if not include_base_classes: # Limit to current class if needed
base_classes = base_classes[:1]
if include_base_classes:
return self.get_cls_kwargs__with_inheritance(cls)
return self.get_cls_kwargs__no_inheritance(cls)

def get_cls_kwargs__no_inheritance(self, cls : Type )\
-> Dict[str, Any]:
kwargs = {} # Process current class only
self.process_mro_class (cls, kwargs) # Handle class variables
self.process_annotations(cls, cls, kwargs) # Process type annotations

return kwargs

def get_cls_kwargs__with_inheritance(self, cls : Type )\
-> Dict[str, Any]: # Get class kwargs with inheritance
base_classes = type_safe_cache.get_class_mro(cls)
kwargs = {} # Process inheritance chain
for base_cls in base_classes:
self.process_mro_class (base_cls, kwargs) # Handle MRO class
self.process_annotations(cls, base_cls, kwargs) # Process annotations
self.process_mro_class (base_cls, kwargs) # Handle each class in MRO
self.process_annotations(cls, base_cls, kwargs) # Process its annotations

return kwargs

def handle_undefined_var(self, cls : Type , # Handle undefined class variables
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ class Type_Safe__Step__Default_Value:

def default_value(self, _cls, var_type):


if get_origin(var_type) is type: # Special handling for Type[T] # todo: reuse the get_origin value
type_args = get_args(var_type)
if type_args:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from unittest import TestCase
from osbot_utils.type_safe.shared.Type_Safe__Cache import type_safe_cache
from osbot_utils.utils.Env import in_github_action
from osbot_utils.helpers.Random_Guid import Random_Guid
from osbot_utils.testing.performance.Performance_Measure__Session import Performance_Measure__Session
Expand Down Expand Up @@ -43,15 +44,31 @@ class An_Class_5(Type_Safe):
class An_Class_6(Type_Safe):
an_str: str = '42'

type_safe_cache.skip_cache = False

Performance_Measure__Session().measure(str ).print().assert_time(self.time_100_ns, self.time_0_ns )
Performance_Measure__Session().measure(Random_Guid).print().assert_time(self.time_3_kns , self.time_5_kns, self.time_6_kns , self.time_7_kns )
Performance_Measure__Session().measure(An_Class_1 ).print().assert_time(self.time_100_ns )
Performance_Measure__Session().measure(An_Class_2 ).print().assert_time(self.time_3_kns , self.time_4_kns , self.time_5_kns , self.time_6_kns, self.time_7_kns )
Performance_Measure__Session().measure(An_Class_3 ).print().assert_time(self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
Performance_Measure__Session().measure(An_Class_4 ).print().assert_time(self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
Performance_Measure__Session().measure(An_Class_5 ).print().assert_time(self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
Performance_Measure__Session().measure(An_Class_6 ).print().assert_time(self.time_7_kns , self.time_8_kns , self.time_10_kns, self.time_20_kns )
print()
with Performance_Measure__Session(assert_enabled=True) as _:
_.measure(str ).print().assert_time(self.time_100_ns, self.time_0_ns )
_.measure(Random_Guid).print().assert_time(self.time_3_kns , self.time_5_kns, self.time_6_kns , self.time_7_kns )
_.measure(An_Class_1 ).print().assert_time(self.time_100_ns )
_.measure(An_Class_2 ).print().assert_time(self.time_3_kns , self.time_4_kns , self.time_5_kns , self.time_6_kns, self.time_7_kns )
_.measure(An_Class_3 ).print().assert_time(self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
_.measure(An_Class_4 ).print().assert_time(self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
_.measure(An_Class_5 ).print().assert_time(self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
_.measure(An_Class_6 ).print().assert_time(self.time_7_kns , self.time_8_kns , self.time_10_kns, self.time_20_kns )

type_safe_cache.print_cache_hits()

# with no cache (and other changes) the results were:
#
# str | score: 100 ns | raw: 78 ns
# Random_Guid | score: 6,000 ns | raw: 5,554 ns
# An_Class_1 | score: 100 ns | raw: 131 ns
# An_Class_2 | score: 9,000 ns | raw: 8,882 ns
# An_Class_3 | score: 20,000 ns | raw: 18,633 ns
# An_Class_4 | score: 20,000 ns | raw: 18,801 ns
# An_Class_5 | score: 20,000 ns | raw: 18,588 ns
# An_Class_6 | score: 20,000 ns | raw: 18,965 ns

# def test_dissaembly_both_paths(self):
# from osbot_utils.type_safe.Cache__Class_Kwargs import Cache__Class_Kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ class DefaultsClass:
int_val : int = 42
float_val : float = 3.14
bool_val : bool = True
tuple_val : tuple = (1, 2, 3)
frozenset_val: frozenset = frozenset([1, 2, 3])
#tuple_val : tuple = (1, 2, 3)
#frozenset_val: frozenset = frozenset([1, 2, 3])
bytes_val : bytes = b"bytes"
enum_val : SimpleEnum = SimpleEnum.A

Expand Down
6 changes: 3 additions & 3 deletions tests/unit/type_safe/test_Type_Safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ class Immutable_Types_Class(Type_Safe):
a_int : int = 1
a_float : float = 1.0
a_str : str = "string"
a_tuple : tuple = (1, 2)
a_frozenset : frozenset = frozenset([1, 2])
#a_tuple : tuple = (1, 2)
#a_frozenset : frozenset = frozenset([1, 2])
a_bytes : bytes = b"byte"

class With_Optional_And_Union(Type_Safe):
Expand All @@ -99,7 +99,7 @@ class With_Optional_And_Union(Type_Safe):

immutable_types_class = Immutable_Types_Class()
with_optional_and_union = With_Optional_And_Union()
assert immutable_types_class .__locals__() == {'a_int': 1, 'a_float': 1.0, 'a_str': 'string', 'a_tuple': (1, 2), 'a_frozenset': frozenset({1, 2}), 'a_bytes': b'byte'}
assert immutable_types_class .__locals__() == {'a_int': 1, 'a_float': 1.0, 'a_str': 'string', 'a_bytes': b'byte'}
assert with_optional_and_union.__locals__() == {'optional_int': None, 'union_str_float': 'string_or_float', 'union_with_none': None}

def test___default_kwargs__(self):
Expand Down

0 comments on commit 27ed391

Please sign in to comment.