Skip to content

Commit

Permalink
added cache to main cls_kwargs output, which had quite a performance …
Browse files Browse the repository at this point in the history
…impact on the Type_Safe classes that have IMMUTABLE_TYPES (see how the performance of An_Class_2 went from 7_kns to 1_kns)
  • Loading branch information
DinisCruz committed Jan 20, 2025
1 parent 8a19002 commit 0f1e401
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 8 deletions.
14 changes: 14 additions & 0 deletions osbot_utils/type_safe/shared/Type_Safe__Cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
class Type_Safe__Cache:

_annotations_cache : WeakKeyDictionary
_cls_kwargs_cache : WeakKeyDictionary
_get_origin_cache : WeakKeyDictionary
_mro_cache : WeakKeyDictionary
_valid_vars_cache : WeakKeyDictionary

cache_hit__annotations : int = 0
cache_hit__cls_kwargs : int = 0
cache_hit__get_origin : int = 0
cache_hit__mro : int = 0
cache_hit__valid_vars : int = 0
Expand All @@ -20,10 +22,18 @@ class Type_Safe__Cache:
# Caching system for Type_Safe methods
def __init__(self):
self._annotations_cache = WeakKeyDictionary() # Cache for class annotations
self._cls_kwargs_cache = WeakKeyDictionary() # Cache for class kwargs
self._get_origin_cache = WeakKeyDictionary() # Cache for get_origin results
self._mro_cache = WeakKeyDictionary() # Cache for Method Resolution Order
self._valid_vars_cache = WeakKeyDictionary()

def get_cls_kwargs(self, cls):
if self.skip_cache or cls not in self._cls_kwargs_cache:
return None
else:
self.cache_hit__cls_kwargs += 1
return self._cls_kwargs_cache.get(cls)

def get_class_annotations(self, cls):
if self.skip_cache or cls not in self._annotations_cache:
self._annotations_cache[cls] = cls.__annotations__.items()
Expand Down Expand Up @@ -58,12 +68,16 @@ def get_valid_class_variables(self, cls, validator):
self.cache_hit__valid_vars += 1
return self._valid_vars_cache[cls]

def set_cache__cls_kwargs(self, cls, kwargs):
self._cls_kwargs_cache[cls] = kwargs
return kwargs

def print_cache_hits(self):
print()
print("###### Type_Safe_Cache Hits ########")
print()
print(f" annotations : {self.cache_hit__annotations}")
print(f" cls_kwargs : {self.cache_hit__cls_kwargs }")
print(f" get_origin : {self.cache_hit__get_origin }")
print(f" mro : {self.cache_hit__mro }")
print(f" valid_vars : {self.cache_hit__valid_vars }")
Expand Down
20 changes: 15 additions & 5 deletions osbot_utils/type_safe/steps/Type_Safe__Step__Class_Kwargs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict, Any, Type
from osbot_utils.type_safe.shared.Type_Safe__Cache import Type_Safe__Cache, type_safe_cache
from osbot_utils.type_safe.shared.Type_Safe__Shared__Variables import IMMUTABLE_TYPES
from osbot_utils.type_safe.shared.Type_Safe__Validation import type_safe_validation
from osbot_utils.type_safe.steps.Type_Safe__Step__Default_Value import type_safe_step_default_value

Expand All @@ -24,23 +25,32 @@ def get_cls_kwargs(self, cls : Type ,
return self.get_cls_kwargs__no_inheritance(cls)

def get_cls_kwargs__no_inheritance(self, cls : Type )\
-> Dict[str, Any]:
-> Dict[str, Any]: # note: this method is mainly used by tests, so we don't need to add cache support here
kwargs = {} # Process current class only
self.process_mro_class (cls, kwargs) # Handle class variables
self.process_annotations(cls, cls, kwargs) # Process type annotations

self.process_mro_class (cls, kwargs) # Handle class variables
self.process_annotations(cls, cls, kwargs) # Process type annotations
return kwargs

def get_cls_kwargs__with_inheritance(self, cls : Type )\
-> Dict[str, Any]: # Get class kwargs with inheritance
kwargs = type_safe_cache.get_cls_kwargs(cls) # see if we have cached data for this class
if kwargs is not None:
return kwargs
else:
kwargs = {}

base_classes = type_safe_cache.get_class_mro(cls)
kwargs = {} # Process inheritance chain
for base_cls in base_classes:
self.process_mro_class (base_cls, kwargs) # Handle each class in MRO
self.process_annotations(cls, base_cls, kwargs) # Process its annotations

if self.is_kwargs_cacheable(kwargs): # if we can cache it (i.e. only IMMUTABLE_TYPES vars)
type_safe_cache.set_cache__cls_kwargs(cls, kwargs) # cache it
return kwargs

def is_kwargs_cacheable(self, kwargs: Dict[str, Any]) -> bool:
return all(isinstance(value, IMMUTABLE_TYPES) for value in kwargs.values())

def handle_undefined_var(self, cls : Type , # Handle undefined class variables
kwargs : Dict[str, Any] ,
var_name : str ,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def setUpClass(cls):
pytest.skip("Skipping tests in Github Actions")
cls.time_0_ns = 0
cls.time_100_ns = 100
cls.time_1_kns = 1_000
cls.time_2_kns = 2_000
cls.time_3_kns = 3_000
cls.time_4_kns = 4_000
cls.time_5_kns = 5_000
Expand Down Expand Up @@ -51,11 +53,11 @@ class An_Class_6(Type_Safe):
_.measure(str ).print().assert_time(self.time_100_ns, self.time_0_ns )
_.measure(Random_Guid).print().assert_time(self.time_3_kns , self.time_5_kns, self.time_6_kns , self.time_7_kns )
_.measure(An_Class_1 ).print().assert_time(self.time_100_ns )
_.measure(An_Class_2 ).print().assert_time(self.time_3_kns , self.time_4_kns , self.time_5_kns , self.time_6_kns, self.time_7_kns )
_.measure(An_Class_2 ).print().assert_time(self.time_1_kns , self.time_2_kns , self.time_3_kns , self.time_4_kns , self.time_5_kns , self.time_6_kns, self.time_7_kns )
_.measure(An_Class_3 ).print().assert_time(self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
_.measure(An_Class_4 ).print().assert_time(self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
_.measure(An_Class_5 ).print().assert_time(self.time_8_kns , self.time_9_kns ,self.time_10_kns, self.time_20_kns )
_.measure(An_Class_6 ).print().assert_time(self.time_7_kns , self.time_8_kns , self.time_10_kns, self.time_20_kns )
_.measure(An_Class_6 ).print().assert_time(self.time_7_kns , self.time_8_kns ,self.time_9_kns , self.time_10_kns, self.time_20_kns )

type_safe_cache.print_cache_hits()

Expand All @@ -75,4 +77,4 @@ class An_Class_6(Type_Safe):
# cache__class_kwargs = Cache__Class_Kwargs()
# import dis
# #dis.dis(Type_Safe.__cls_kwargs__)
# dis.dis(type_safe_step_class_kwargs.get_cls_kwargs)
# dis.dis(type_safe_step_class_kwargs.get_cls_kwargs)

0 comments on commit 0f1e401

Please sign in to comment.