diff --git a/osbot_utils/type_safe/Type_Safe.py b/osbot_utils/type_safe/Type_Safe.py index 28a90c87..82acf9b7 100644 --- a/osbot_utils/type_safe/Type_Safe.py +++ b/osbot_utils/type_safe/Type_Safe.py @@ -27,8 +27,8 @@ def __attr_names__(self): return list_set(self.__locals__()) @classmethod - def __cls_kwargs__(cls, include_base_classes=True): # Return current class dictionary of class level variables and their values - return type_safe_step_class_kwargs.get_cls_kwargs(cls, include_base_classes) + def __cls_kwargs__(cls): # Return current class dictionary of class level variables and their values + return type_safe_step_class_kwargs.get_cls_kwargs(cls) @classmethod def __default__value__(cls, var_type): diff --git a/osbot_utils/type_safe/shared/Type_Safe__Cache.py b/osbot_utils/type_safe/shared/Type_Safe__Cache.py index 5637f308..ded9b86f 100644 --- a/osbot_utils/type_safe/shared/Type_Safe__Cache.py +++ b/osbot_utils/type_safe/shared/Type_Safe__Cache.py @@ -2,6 +2,7 @@ from typing import get_origin from weakref import WeakKeyDictionary from osbot_utils.type_safe.shared.Type_Safe__Shared__Variables import IMMUTABLE_TYPES +from osbot_utils.utils.Objects import all_annotations__in_class class Type_Safe__Cache: @@ -42,7 +43,7 @@ def get_class_annotations(self, cls): annotations = self._cls__annotations_cache.get(cls) # this is a more efficient cache retrieval pattern (we only get the data from the dict once) if not annotations: # todo: apply this to the other cache getters if self.skip_cache or cls not in self._cls__annotations_cache: - annotations = cls.__annotations__.items() + annotations = all_annotations__in_class(cls).items() self._cls__annotations_cache[cls] = annotations else: self.cache_hit__cls__annotations += 1 diff --git a/osbot_utils/type_safe/steps/Type_Safe__Step__Class_Kwargs.py b/osbot_utils/type_safe/steps/Type_Safe__Step__Class_Kwargs.py index 502ce7d9..70df20de 100644 --- a/osbot_utils/type_safe/steps/Type_Safe__Step__Class_Kwargs.py +++ b/osbot_utils/type_safe/steps/Type_Safe__Step__Class_Kwargs.py @@ -1,4 +1,7 @@ from typing import Dict, Any, Type + +from osbot_utils.helpers.Random_Guid import Random_Guid + from osbot_utils.type_safe.shared.Type_Safe__Cache import Type_Safe__Cache, type_safe_cache from osbot_utils.type_safe.shared.Type_Safe__Shared__Variables import IMMUTABLE_TYPES from osbot_utils.type_safe.shared.Type_Safe__Validation import type_safe_validation @@ -13,27 +16,14 @@ class Type_Safe__Step__Class_Kwargs: def __init__(self): self.type_safe_cache = type_safe_cache # Initialize with singleton cache - def get_cls_kwargs(self, cls : Type , # Main entry point for getting class kwargs - include_base_classes : bool = True )\ - -> Dict[str, Any]: # Returns dict of class kwargs + def get_cls_kwargs(self, cls : Type )\ + -> Dict[str, Any]: # Main entry point for getting class kwargs, returns dict of class kwargs if not hasattr(cls, '__mro__'): # Handle non-class inputs return {} - if include_base_classes: - return self.get_cls_kwargs__with_inheritance(cls) - return self.get_cls_kwargs__no_inheritance(cls) - - def get_cls_kwargs__no_inheritance(self, cls : Type )\ - -> Dict[str, Any]: # note: this method is mainly used by tests, so we don't need to add cache support here - kwargs = {} # Process current class only - self.process_mro_class (cls, kwargs) # Handle class variables - self.process_annotations(cls, cls, kwargs) # Process type annotations - return kwargs - - def get_cls_kwargs__with_inheritance(self, cls : Type )\ - -> Dict[str, Any]: # Get class kwargs with inheritance kwargs = type_safe_cache.get_cls_kwargs(cls) # see if we have cached data for this class + if kwargs is not None: return kwargs else: @@ -44,12 +34,19 @@ def get_cls_kwargs__with_inheritance(self, cls : Type )\ self.process_mro_class (base_cls, kwargs) # Handle each class in MRO self.process_annotations(cls, base_cls, kwargs) # Process its annotations - if self.is_kwargs_cacheable(kwargs): # if we can cache it (i.e. only IMMUTABLE_TYPES vars) + if self.is_kwargs_cacheable(cls, kwargs): # if we can cache it (i.e. only IMMUTABLE_TYPES vars) type_safe_cache.set_cache__cls_kwargs(cls, kwargs) # cache it return kwargs - def is_kwargs_cacheable(self, kwargs: Dict[str, Any]) -> bool: - return all(isinstance(value, IMMUTABLE_TYPES) for value in kwargs.values()) + def is_kwargs_cacheable(self, cls, kwargs: Dict[str, Any]) -> bool: + annotations = type_safe_cache.get_class_annotations(cls) + match = all(isinstance(value, IMMUTABLE_TYPES) for value in kwargs.values()) + + if match: # check for special cases that we can't cache (like Random_Guid) + if Random_Guid in list(dict(annotations).values()): # todo: need to add the other special cases (like Timestamp_Now) + return False + return match + def handle_undefined_var(self, cls : Type , # Handle undefined class variables kwargs : Dict[str, Any] , diff --git a/osbot_utils/utils/Objects.py b/osbot_utils/utils/Objects.py index 70175800..51ada239 100644 --- a/osbot_utils/utils/Objects.py +++ b/osbot_utils/utils/Objects.py @@ -490,6 +490,14 @@ def all_annotations(target): annotations.update(base.__annotations__) return annotations +def all_annotations__in_class(target): + annotations = {} + if hasattr(target, '__mro__'): + for base in reversed(target.__mro__): + if hasattr(base, '__annotations__'): + annotations.update(base.__annotations__) + return annotations + def value_type_matches_obj_annotation_for_attr(target, attr_name, value): import typing annotations = all_annotations(target) diff --git a/tests/unit/helpers/trace/test_Trace_Files.py b/tests/unit/helpers/trace/test_Trace_Files.py index 341ab958..d22d59c8 100644 --- a/tests/unit/helpers/trace/test_Trace_Files.py +++ b/tests/unit/helpers/trace/test_Trace_Files.py @@ -29,7 +29,11 @@ def test___default_kwargs__(self): assert trace_files.stack[0] == Trace_Call__Stack_Node(name=DEFAULT_ROOT_NODE_NODE_TITLE) def test___init__(self): - assert Trace_Files.__cls_kwargs__(include_base_classes=False) == {'files': []} + trace_files__cls_kwargs = Trace_Files.__cls_kwargs__() + assert trace_files__cls_kwargs == {'config' : trace_files__cls_kwargs.get('config'), + 'files' : [] , + 'prev_trace_function': None , + 'started' : False } assert Trace_Files().files == [] assert Trace_Files(files=[] ).files == [] diff --git a/tests/unit/testing/performance/test_Performance_Measure__Session.py b/tests/unit/testing/performance/test_Performance_Measure__Session.py index 1bf97b71..70d7e816 100644 --- a/tests/unit/testing/performance/test_Performance_Measure__Session.py +++ b/tests/unit/testing/performance/test_Performance_Measure__Session.py @@ -29,7 +29,7 @@ def setUpClass(cls): cls.time_20_kns = 20_000 def test_measure(self): - class An_Class_1(): + class An_Class_1: pass class An_Class_2(Type_Safe): @@ -52,13 +52,13 @@ class An_Class_6(Type_Safe): print() with Performance_Measure__Session(assert_enabled=True) as _: _.measure(str ).print().assert_time(self.time_100_ns, self.time_0_ns ) - _.measure(Random_Guid).print().assert_time(self.time_3_kns , self.time_5_kns, self.time_6_kns , self.time_7_kns ) + _.measure(Random_Guid).print().assert_time(self.time_3_kns , self.time_5_kns, self.time_6_kns , self.time_7_kns , self.time_8_kns ) _.measure(An_Class_1 ).print().assert_time(self.time_100_ns, self.time_200_ns ) _.measure(An_Class_2 ).print().assert_time(self.time_1_kns , self.time_2_kns , self.time_3_kns , self.time_4_kns , self.time_5_kns , self.time_6_kns, self.time_7_kns ) - _.measure(An_Class_3 ).print().assert_time(self.time_4_kns , self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns ) - _.measure(An_Class_4 ).print().assert_time(self.time_4_kns , self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns ) - _.measure(An_Class_5 ).print().assert_time(self.time_4_kns , self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns ) - _.measure(An_Class_6 ).print().assert_time(self.time_4_kns , self.time_7_kns , self.time_8_kns ,self.time_9_kns , self.time_10_kns, self.time_20_kns ) + _.measure(An_Class_3 ).print().assert_time(self.time_4_kns , self.time_5_kns , self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns ) + _.measure(An_Class_4 ).print().assert_time(self.time_4_kns , self.time_5_kns , self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns ) + _.measure(An_Class_5 ).print().assert_time(self.time_4_kns , self.time_5_kns , self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns ) + _.measure(An_Class_6 ).print().assert_time(self.time_4_kns , self.time_5_kns , self.time_7_kns , self.time_8_kns ,self.time_9_kns , self.time_10_kns, self.time_20_kns ) type_safe_cache.print_cache_hits() diff --git a/tests/unit/type_safe/steps/test_perf__Type_Safe__Step__Class_Kwargs.py b/tests/unit/type_safe/steps/test_perf__Type_Safe__Step__Class_Kwargs.py index 04c941c0..afd91357 100644 --- a/tests/unit/type_safe/steps/test_perf__Type_Safe__Step__Class_Kwargs.py +++ b/tests/unit/type_safe/steps/test_perf__Type_Safe__Step__Class_Kwargs.py @@ -73,13 +73,9 @@ class ChildClass(BaseClass): child_int : int = 84 def get_inherited_kwargs_with_base(): # Get kwargs including base class - return type_safe_step_class_kwargs.get_cls_kwargs(ChildClass, include_base_classes=True) - - def get_inherited_kwargs_no_base(): # Get kwargs excluding base class - return type_safe_step_class_kwargs.get_cls_kwargs(ChildClass, include_base_classes=False) + return type_safe_step_class_kwargs.get_cls_kwargs(ChildClass) self.session.measure(get_inherited_kwargs_with_base).assert_time(self.time_6_kns) - self.session.measure(get_inherited_kwargs_no_base ).assert_time(self.time_3_kns) def test_with_methods(self): # Test with instance and class methods class MethodClass: diff --git a/tests/unit/type_safe/test_Type_Safe.py b/tests/unit/type_safe/test_Type_Safe.py index 02aaddcb..8967fcf2 100644 --- a/tests/unit/type_safe/test_Type_Safe.py +++ b/tests/unit/type_safe/test_Type_Safe.py @@ -15,7 +15,7 @@ from osbot_utils.type_safe.steps.Type_Safe__Step__From_Json import type_safe_step_from_json from osbot_utils.utils.Json import json_dumps from osbot_utils.utils.Misc import random_string, list_set -from osbot_utils.utils.Objects import obj_data, __, default_value, serialize_to_dict +from osbot_utils.utils.Objects import obj_data, __, default_value, serialize_to_dict class test_Type_Safe(TestCase): @@ -44,14 +44,12 @@ def test___cls_kwargs__(self): if sys.version_info < (3, 9): pytest.skip("Skipping test that doesn't work on 3.8 or lower") - assert self.Config_Class.__cls_kwargs__(include_base_classes=False) == {'attribute1': 'default_value', 'attribute2': True, 'callable_attr_1': print } - assert self.Config_Class.__cls_kwargs__(include_base_classes=True ) == {'attribute1': 'default_value', 'attribute2': True, 'callable_attr_1': print } - assert self.Extra_Config.__cls_kwargs__(include_base_classes=False) == {'attribute3': 'another_value', 'callable_attr_2': print } - assert self.Extra_Config.__cls_kwargs__(include_base_classes=True ) == {'attribute1': 'default_value', 'attribute2': True, 'callable_attr_1': print , - 'attribute3': 'another_value', 'callable_attr_2': print , } - assert self.Config_Class.__cls_kwargs__(include_base_classes=True) == self.Config_Class.__cls_kwargs__() - assert self.Extra_Config.__cls_kwargs__(include_base_classes=True) == self.Extra_Config.__cls_kwargs__() + assert self.Config_Class.__cls_kwargs__( ) == {'attribute1': 'default_value', 'attribute2': True, 'callable_attr_1': print } + assert self.Extra_Config.__cls_kwargs__( ) == {'attribute1': 'default_value', 'attribute2': True, 'callable_attr_1': print , + 'attribute3': 'another_value', 'callable_attr_2': print , } + assert self.Config_Class.__cls_kwargs__() == self.Config_Class.__cls_kwargs__() + assert self.Extra_Config.__cls_kwargs__() == self.Extra_Config.__cls_kwargs__() assert self.Config_Class.__cls_kwargs__() == self.Config_Class().__cls_kwargs__() assert self.Extra_Config.__cls_kwargs__() == self.Extra_Config().__cls_kwargs__()