Skip to content

Commit

Permalink
fixed couple nasty bugs in the way the attributes are calculated and …
Browse files Browse the repository at this point in the history
…cached

removed the funcionality to include_base_classes in Type_Safe __cls__kwargs__ since that was not being used in practice
  • Loading branch information
DinisCruz committed Jan 20, 2025
1 parent 4dd6d70 commit 09d02d0
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 42 deletions.
4 changes: 2 additions & 2 deletions osbot_utils/type_safe/Type_Safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def __attr_names__(self):
return list_set(self.__locals__())

@classmethod
def __cls_kwargs__(cls, include_base_classes=True): # Return current class dictionary of class level variables and their values
return type_safe_step_class_kwargs.get_cls_kwargs(cls, include_base_classes)
def __cls_kwargs__(cls): # Return current class dictionary of class level variables and their values
return type_safe_step_class_kwargs.get_cls_kwargs(cls)

@classmethod
def __default__value__(cls, var_type):
Expand Down
3 changes: 2 additions & 1 deletion osbot_utils/type_safe/shared/Type_Safe__Cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import get_origin
from weakref import WeakKeyDictionary
from osbot_utils.type_safe.shared.Type_Safe__Shared__Variables import IMMUTABLE_TYPES
from osbot_utils.utils.Objects import all_annotations__in_class


class Type_Safe__Cache:
Expand Down Expand Up @@ -42,7 +43,7 @@ def get_class_annotations(self, cls):
annotations = self._cls__annotations_cache.get(cls) # this is a more efficient cache retrieval pattern (we only get the data from the dict once)
if not annotations: # todo: apply this to the other cache getters
if self.skip_cache or cls not in self._cls__annotations_cache:
annotations = cls.__annotations__.items()
annotations = all_annotations__in_class(cls).items()
self._cls__annotations_cache[cls] = annotations
else:
self.cache_hit__cls__annotations += 1
Expand Down
35 changes: 16 additions & 19 deletions osbot_utils/type_safe/steps/Type_Safe__Step__Class_Kwargs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Dict, Any, Type

from osbot_utils.helpers.Random_Guid import Random_Guid

from osbot_utils.type_safe.shared.Type_Safe__Cache import Type_Safe__Cache, type_safe_cache
from osbot_utils.type_safe.shared.Type_Safe__Shared__Variables import IMMUTABLE_TYPES
from osbot_utils.type_safe.shared.Type_Safe__Validation import type_safe_validation
Expand All @@ -13,27 +16,14 @@ class Type_Safe__Step__Class_Kwargs:
def __init__(self):
self.type_safe_cache = type_safe_cache # Initialize with singleton cache

def get_cls_kwargs(self, cls : Type , # Main entry point for getting class kwargs
include_base_classes : bool = True )\
-> Dict[str, Any]: # Returns dict of class kwargs
def get_cls_kwargs(self, cls : Type )\
-> Dict[str, Any]: # Main entry point for getting class kwargs, returns dict of class kwargs

if not hasattr(cls, '__mro__'): # Handle non-class inputs
return {}

if include_base_classes:
return self.get_cls_kwargs__with_inheritance(cls)
return self.get_cls_kwargs__no_inheritance(cls)

def get_cls_kwargs__no_inheritance(self, cls : Type )\
-> Dict[str, Any]: # note: this method is mainly used by tests, so we don't need to add cache support here
kwargs = {} # Process current class only
self.process_mro_class (cls, kwargs) # Handle class variables
self.process_annotations(cls, cls, kwargs) # Process type annotations
return kwargs

def get_cls_kwargs__with_inheritance(self, cls : Type )\
-> Dict[str, Any]: # Get class kwargs with inheritance
kwargs = type_safe_cache.get_cls_kwargs(cls) # see if we have cached data for this class

if kwargs is not None:
return kwargs
else:
Expand All @@ -44,12 +34,19 @@ def get_cls_kwargs__with_inheritance(self, cls : Type )\
self.process_mro_class (base_cls, kwargs) # Handle each class in MRO
self.process_annotations(cls, base_cls, kwargs) # Process its annotations

if self.is_kwargs_cacheable(kwargs): # if we can cache it (i.e. only IMMUTABLE_TYPES vars)
if self.is_kwargs_cacheable(cls, kwargs): # if we can cache it (i.e. only IMMUTABLE_TYPES vars)
type_safe_cache.set_cache__cls_kwargs(cls, kwargs) # cache it
return kwargs

def is_kwargs_cacheable(self, kwargs: Dict[str, Any]) -> bool:
return all(isinstance(value, IMMUTABLE_TYPES) for value in kwargs.values())
def is_kwargs_cacheable(self, cls, kwargs: Dict[str, Any]) -> bool:
annotations = type_safe_cache.get_class_annotations(cls)
match = all(isinstance(value, IMMUTABLE_TYPES) for value in kwargs.values())

if match: # check for special cases that we can't cache (like Random_Guid)
if Random_Guid in list(dict(annotations).values()): # todo: need to add the other special cases (like Timestamp_Now)
return False
return match


def handle_undefined_var(self, cls : Type , # Handle undefined class variables
kwargs : Dict[str, Any] ,
Expand Down
8 changes: 8 additions & 0 deletions osbot_utils/utils/Objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,14 @@ def all_annotations(target):
annotations.update(base.__annotations__)
return annotations

def all_annotations__in_class(target):
annotations = {}
if hasattr(target, '__mro__'):
for base in reversed(target.__mro__):
if hasattr(base, '__annotations__'):
annotations.update(base.__annotations__)
return annotations

def value_type_matches_obj_annotation_for_attr(target, attr_name, value):
import typing
annotations = all_annotations(target)
Expand Down
6 changes: 5 additions & 1 deletion tests/unit/helpers/trace/test_Trace_Files.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ def test___default_kwargs__(self):
assert trace_files.stack[0] == Trace_Call__Stack_Node(name=DEFAULT_ROOT_NODE_NODE_TITLE)

def test___init__(self):
assert Trace_Files.__cls_kwargs__(include_base_classes=False) == {'files': []}
trace_files__cls_kwargs = Trace_Files.__cls_kwargs__()
assert trace_files__cls_kwargs == {'config' : trace_files__cls_kwargs.get('config'),
'files' : [] ,
'prev_trace_function': None ,
'started' : False }
assert Trace_Files().files == []

assert Trace_Files(files=[] ).files == []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def setUpClass(cls):
cls.time_20_kns = 20_000

def test_measure(self):
class An_Class_1():
class An_Class_1:
pass

class An_Class_2(Type_Safe):
Expand All @@ -52,13 +52,13 @@ class An_Class_6(Type_Safe):
print()
with Performance_Measure__Session(assert_enabled=True) as _:
_.measure(str ).print().assert_time(self.time_100_ns, self.time_0_ns )
_.measure(Random_Guid).print().assert_time(self.time_3_kns , self.time_5_kns, self.time_6_kns , self.time_7_kns )
_.measure(Random_Guid).print().assert_time(self.time_3_kns , self.time_5_kns, self.time_6_kns , self.time_7_kns , self.time_8_kns )
_.measure(An_Class_1 ).print().assert_time(self.time_100_ns, self.time_200_ns )
_.measure(An_Class_2 ).print().assert_time(self.time_1_kns , self.time_2_kns , self.time_3_kns , self.time_4_kns , self.time_5_kns , self.time_6_kns, self.time_7_kns )
_.measure(An_Class_3 ).print().assert_time(self.time_4_kns , self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
_.measure(An_Class_4 ).print().assert_time(self.time_4_kns , self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
_.measure(An_Class_5 ).print().assert_time(self.time_4_kns , self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
_.measure(An_Class_6 ).print().assert_time(self.time_4_kns , self.time_7_kns , self.time_8_kns ,self.time_9_kns , self.time_10_kns, self.time_20_kns )
_.measure(An_Class_3 ).print().assert_time(self.time_4_kns , self.time_5_kns , self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
_.measure(An_Class_4 ).print().assert_time(self.time_4_kns , self.time_5_kns , self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
_.measure(An_Class_5 ).print().assert_time(self.time_4_kns , self.time_5_kns , self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
_.measure(An_Class_6 ).print().assert_time(self.time_4_kns , self.time_5_kns , self.time_7_kns , self.time_8_kns ,self.time_9_kns , self.time_10_kns, self.time_20_kns )

type_safe_cache.print_cache_hits()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,9 @@ class ChildClass(BaseClass):
child_int : int = 84

def get_inherited_kwargs_with_base(): # Get kwargs including base class
return type_safe_step_class_kwargs.get_cls_kwargs(ChildClass, include_base_classes=True)

def get_inherited_kwargs_no_base(): # Get kwargs excluding base class
return type_safe_step_class_kwargs.get_cls_kwargs(ChildClass, include_base_classes=False)
return type_safe_step_class_kwargs.get_cls_kwargs(ChildClass)

self.session.measure(get_inherited_kwargs_with_base).assert_time(self.time_6_kns)
self.session.measure(get_inherited_kwargs_no_base ).assert_time(self.time_3_kns)

def test_with_methods(self): # Test with instance and class methods
class MethodClass:
Expand Down
14 changes: 6 additions & 8 deletions tests/unit/type_safe/test_Type_Safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from osbot_utils.type_safe.steps.Type_Safe__Step__From_Json import type_safe_step_from_json
from osbot_utils.utils.Json import json_dumps
from osbot_utils.utils.Misc import random_string, list_set
from osbot_utils.utils.Objects import obj_data, __, default_value, serialize_to_dict
from osbot_utils.utils.Objects import obj_data, __, default_value, serialize_to_dict


class test_Type_Safe(TestCase):
Expand Down Expand Up @@ -44,14 +44,12 @@ def test___cls_kwargs__(self):
if sys.version_info < (3, 9):
pytest.skip("Skipping test that doesn't work on 3.8 or lower")

assert self.Config_Class.__cls_kwargs__(include_base_classes=False) == {'attribute1': 'default_value', 'attribute2': True, 'callable_attr_1': print }
assert self.Config_Class.__cls_kwargs__(include_base_classes=True ) == {'attribute1': 'default_value', 'attribute2': True, 'callable_attr_1': print }
assert self.Extra_Config.__cls_kwargs__(include_base_classes=False) == {'attribute3': 'another_value', 'callable_attr_2': print }
assert self.Extra_Config.__cls_kwargs__(include_base_classes=True ) == {'attribute1': 'default_value', 'attribute2': True, 'callable_attr_1': print ,
'attribute3': 'another_value', 'callable_attr_2': print , }
assert self.Config_Class.__cls_kwargs__(include_base_classes=True) == self.Config_Class.__cls_kwargs__()
assert self.Extra_Config.__cls_kwargs__(include_base_classes=True) == self.Extra_Config.__cls_kwargs__()

assert self.Config_Class.__cls_kwargs__( ) == {'attribute1': 'default_value', 'attribute2': True, 'callable_attr_1': print }
assert self.Extra_Config.__cls_kwargs__( ) == {'attribute1': 'default_value', 'attribute2': True, 'callable_attr_1': print ,
'attribute3': 'another_value', 'callable_attr_2': print , }
assert self.Config_Class.__cls_kwargs__() == self.Config_Class.__cls_kwargs__()
assert self.Extra_Config.__cls_kwargs__() == self.Extra_Config.__cls_kwargs__()
assert self.Config_Class.__cls_kwargs__() == self.Config_Class().__cls_kwargs__()
assert self.Extra_Config.__cls_kwargs__() == self.Extra_Config().__cls_kwargs__()

Expand Down

0 comments on commit 09d02d0

Please sign in to comment.