Skip to content

Commit

Permalink
feat(cache): ✨ Add cache_group that can segment an instance cache int…
Browse files Browse the repository at this point in the history
…o different isolated parts. (#66)

It is useful for limiting the number of cache entries while also allowing
unique operations which perform different actions to be cached at the
same time.
  • Loading branch information
ErikBavenstrand committed Jun 26, 2023
2 parents fa60b3e + ea63943 commit 5fa8c9c
Show file tree
Hide file tree
Showing 35 changed files with 294 additions and 338 deletions.
260 changes: 37 additions & 223 deletions examples/Experiment.ipynb

Large diffs are not rendered by default.

44 changes: 36 additions & 8 deletions mleko/cache/cache_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,24 @@ def _cached_execute(
self,
lambda_func: Callable[[], Any],
cache_keys: list[Hashable | tuple[Any, BaseFingerprinter]],
cache_group: str | None = None,
force_recompute: bool = False,
) -> Any:
"""Executes the given function, caching the results based on the provided cache keys and fingerprints.
Warning:
The cache group is used to group related cache keys together to prevent collisions between cache keys
originating from the same method. For example, if a method is called during the training and testing
phases of a machine learning pipeline, the cache keys for the training and testing phases should be
using different cache groups to prevent collisions between the cache keys for the two phases. Otherwise,
the later cache keys might overwrite the earlier cache entries.
Args:
lambda_func: A lambda function to execute.
cache_keys: A list of cache keys that can be a mix of hashable values and tuples containing a value and a
BaseFingerprinter instance for generating fingerprints.
cache_group: A string representing the cache group, used to group related cache keys together when methods
are called independently.
force_recompute: A boolean indicating whether to force recompute the result and update the cache, even if a
cached result is available.
Expand All @@ -116,7 +126,7 @@ def _cached_execute(
"""
frame_qualname = get_frame_qualname(inspect.stack()[1])
class_method_name = ".".join(frame_qualname.split(".")[-2:])
cache_key = self._compute_cache_key(cache_keys, frame_qualname)
cache_key = self._compute_cache_key(cache_keys, class_method_name, cache_group)

if not force_recompute:
output = self._load_from_cache(cache_key)
Expand All @@ -139,14 +149,22 @@ def _cached_execute(
return self._load_from_cache(cache_key)

def _compute_cache_key(
self, cache_keys: list[Hashable | tuple[Any, BaseFingerprinter]], frame_qualname: str
self,
cache_keys: list[Hashable | tuple[Any, BaseFingerprinter]],
class_method_name: str,
cache_group: str | None = None,
) -> str:
"""Computes the cache key based on the provided cache keys and the calling function's fully qualified name.
Args:
cache_keys: A list of cache keys that can be a mix of hashable values and tuples containing a value and a
BaseFingerprinter instance for generating fingerprints.
frame_qualname: The fully qualified name of the cached function stack frame.
class_method_name: A string of format "class.method" for class methods or "module.function" for
functions, representing the fully qualified name of the calling function or method.
cache_group: A string representing the cache group.
Raises:
ValueError: If the computed cache key is too long.
Returns:
A string representing the computed cache key, which is the MD5 hash of the fully qualified name of the
Expand All @@ -161,12 +179,22 @@ def _compute_cache_key(
else:
values_to_hash.append(key)

data = pickle.dumps((frame_qualname, values_to_hash))

class_method_name = ".".join(frame_qualname.split(".")[-2:])
cache_key = f"{class_method_name}.{hashlib.md5(data).hexdigest()}"
data = pickle.dumps(values_to_hash)
cache_key_prefix = class_method_name
if cache_group is not None:
cache_key_prefix = f"{cache_key_prefix}.{cache_group}"

cache_key = f"{cache_key_prefix}.{hashlib.md5(data).hexdigest()}"
if len(cache_key) + 1 + len(self._cache_file_suffix) > 255:
raise ValueError(
f"The computed cache key is too long ({len(cache_key) + len(self._cache_file_suffix)} chars)."
"The maximum length of a cache key is 255 chars, and given the current class, the maximum "
"length of the provided cache_group is "
f"{255 - len(cache_key_prefix) - 32 - 1 - len(self._cache_file_suffix)} chars."
"Please reduce the length of the cache_group."
)

return cache_key
return f"{cache_key_prefix}.{hashlib.md5(data).hexdigest()}"

def _read_cache_file(self, cache_file_path: Path) -> Any:
"""Reads the cache file from the specified path and returns the deserialized data.
Expand Down
66 changes: 42 additions & 24 deletions mleko/cache/lru_cache_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import inspect
import re
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(self, cache_directory: str | Path, cache_file_suffix: str, cache_si
"""
super().__init__(cache_directory, cache_file_suffix)
self._cache_size = cache_size
self._cache: OrderedDict[str, bool] = OrderedDict()
self._cache: dict[str, OrderedDict[str, bool]] = defaultdict(OrderedDict)
self._load_cache_from_disk()

def _load_cache_from_disk(self) -> None:
Expand All @@ -79,23 +79,30 @@ def _load_cache_from_disk(self) -> None:
"""
frame_qualname = get_frame_qualname(inspect.stack()[2])
class_name = frame_qualname.split(".")[-2]
file_name_pattern = rf"{class_name}\.[a-zA-Z_][a-zA-Z0-9_]*\.[a-fA-F\d]{{32}}"
file_name_pattern = rf"{class_name}\.([a-zA-Z_][a-zA-Z0-9_]*)(\.[a-zA-Z_][a-zA-Z0-9_]*)?\.[a-fA-F\d]{{32}}"

cache_files = [
f
for f in self._cache_directory.glob(f"*.{self._cache_file_suffix}")
if re.search(file_name_pattern, str(f.stem))
]
ordered_cache_files = sorted(cache_files, key=lambda x: x.stat().st_mtime)

for cache_file in ordered_cache_files:
cache_key_match = re.search(file_name_pattern, cache_file.stem)
cache_key = cache_key_match.group(0) # type: ignore
if cache_key not in self._cache:
if len(self._cache) >= self._cache_size:
oldest_key = next(iter(self._cache))
del self._cache[oldest_key]
for file in self._cache_directory.glob(f"{oldest_key}*.{self._cache_file_suffix}"):
file.unlink()
self._cache[cache_key] = True
if cache_key_match:
method_name, cache_group = cache_key_match.groups()
group_identifier = method_name + cache_group if cache_group else method_name
cache_key = cache_key_match.group(0)

if cache_key not in self._cache[group_identifier]:
if len(self._cache[group_identifier]) >= self._cache_size:
oldest_key = next(iter(self._cache[group_identifier]))
del self._cache[group_identifier][oldest_key]
for file in self._cache_directory.glob(f"{oldest_key}*.{self._cache_file_suffix}"):
file.unlink()

self._cache[group_identifier][cache_key] = True

def _load_from_cache(self, cache_key: str) -> Any | None:
"""Loads data from the cache based on the provided cache key and updates the LRU cache.
Expand All @@ -106,9 +113,11 @@ def _load_from_cache(self, cache_key: str) -> Any | None:
Returns:
The cached data if it exists, or None if there is no data for the given cache key.
"""
if cache_key in self._cache:
self._cache.move_to_end(cache_key)
return super()._load_from_cache(cache_key)
for group_identifier in self._cache.keys():
if cache_key in self._cache[group_identifier]:
self._cache[group_identifier].move_to_end(cache_key)
return super()._load_from_cache(cache_key)
return None

def _save_to_cache(self, cache_key: str, output: Any) -> None:
"""Saves the given data to the cache using the provided cache key, updating the LRU cache accordingly.
Expand All @@ -119,13 +128,22 @@ def _save_to_cache(self, cache_key: str, output: Any) -> None:
cache_key: A string representing the cache key.
output: The data to be saved to the cache.
"""
if cache_key not in self._cache:
if len(self._cache) >= self._cache_size:
oldest_key = next(iter(self._cache))
del self._cache[oldest_key]
for file in self._cache_directory.glob(f"{oldest_key}*.{self._cache_file_suffix}"):
file.unlink()
self._cache[cache_key] = True
else:
self._cache.move_to_end(cache_key)
super()._save_to_cache(cache_key, output)
cache_key_match = re.match(
r"[a-zA-Z_][a-zA-Z0-9_]*\.([a-zA-Z_][a-zA-Z0-9_]*)(\.[a-zA-Z_][a-zA-Z0-9_]*)?\.[a-fA-F\d]{32}", cache_key
)
if cache_key_match:
method_name, cache_group = cache_key_match.groups()
group_identifier = method_name + cache_group if cache_group else method_name

if cache_key not in self._cache[group_identifier]:
if len(self._cache[group_identifier]) >= self._cache_size:
oldest_key = next(iter(self._cache[group_identifier]))
del self._cache[group_identifier][oldest_key]
for file in self._cache_directory.glob(f"{oldest_key}*.{self._cache_file_suffix}"):
file.unlink()

self._cache[group_identifier][cache_key] = True
else:
self._cache[group_identifier].move_to_end(cache_key)

super()._save_to_cache(cache_key, output)
5 changes: 4 additions & 1 deletion mleko/dataset/convert/base_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@ def __init__(self, cache_directory: str | Path, cache_size: int):
LRUCacheMixin.__init__(self, cache_directory, self._cache_file_suffix, cache_size)

@abstractmethod
def convert(self, file_paths: list[Path] | list[str], force_recompute: bool = False) -> vaex.DataFrame:
def convert(
self, file_paths: list[Path] | list[str], cache_group: str | None = None, force_recompute: bool = False
) -> vaex.DataFrame:
"""Abstract method to convert the input file paths to the desired output format.
Args:
file_paths: A list of input file paths to be converted.
cache_group: The cache group to use.
force_recompute: If set to True, forces recomputation and ignores the cache.
Returns:
Expand Down
6 changes: 5 additions & 1 deletion mleko/dataset/convert/csv_to_vaex_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ def __init__(
self._num_workers = num_workers
self._random_state = random_state

def convert(self, file_paths: list[Path] | list[str], force_recompute: bool = False) -> vaex.DataFrame:
def convert(
self, file_paths: list[Path] | list[str], cache_group: str | None = None, force_recompute: bool = False
) -> vaex.DataFrame:
"""Converts a list of CSV files to Arrow format and returns a `vaex` dataframe joined from the converted data.
The method takes care of caching, and results will be reused accordingly unless `force_recompute`
Expand All @@ -131,6 +133,7 @@ def convert(self, file_paths: list[Path] | list[str], force_recompute: bool = Fa
Args:
file_paths: A list of file paths to be converted.
cache_group: The cache group to use.
force_recompute: If set to True, forces recomputation and ignores the cache.
Returns:
Expand All @@ -149,6 +152,7 @@ def convert(self, file_paths: list[Path] | list[str], force_recompute: bool = Fa
self._downcast_float,
(file_paths, CSVFingerprinter(n_rows=100_000 // len(file_paths))),
],
cache_group=cache_group,
force_recompute=force_recompute,
)

Expand Down
5 changes: 4 additions & 1 deletion mleko/dataset/feature_select/base_feature_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,14 @@ def __init__(
self._ignore_features: tuple[str, ...] = tuple(ignore_features) if ignore_features is not None else tuple()

@abstractmethod
def select_features(self, dataframe: vaex.DataFrame, force_recompute: bool = False) -> vaex.DataFrame:
def select_features(
self, dataframe: vaex.DataFrame, cache_group: str | None = None, force_recompute: bool = False
) -> vaex.DataFrame:
"""Selects features from the given DataFrame.
Args:
dataframe: DataFrame from which to select features.
cache_group: The cache group to use.
force_recompute: Whether to force the feature selector to recompute its output, even if it already exists.
Raises:
Expand Down
6 changes: 5 additions & 1 deletion mleko/dataset/feature_select/composite_feature_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,14 @@ def __init__(
super().__init__(cache_directory, None, None, cache_size)
self._feature_selectors = tuple(feature_selectors)

def select_features(self, dataframe: vaex.DataFrame, force_recompute: bool = False) -> vaex.DataFrame:
def select_features(
self, dataframe: vaex.DataFrame, cache_group: str | None = None, force_recompute: bool = False
) -> vaex.DataFrame:
"""Selects the features from the DataFrame.
Args:
dataframe: DataFrame from which the features will be selected.
cache_group: The cache group to use for caching.
force_recompute: If True, the features will be recomputed even if they are cached.
Returns:
Expand All @@ -91,6 +94,7 @@ def select_features(self, dataframe: vaex.DataFrame, force_recompute: bool = Fal
return self._cached_execute(
lambda_func=lambda: self._select_features(dataframe),
cache_keys=[self._fingerprint(), (dataframe, VaexFingerprinter())],
cache_group=cache_group,
force_recompute=force_recompute,
)

Expand Down
6 changes: 5 additions & 1 deletion mleko/dataset/feature_select/invariance_feature_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,14 @@ def __init__(
"""
super().__init__(cache_directory, features, ignore_features, cache_size)

def select_features(self, dataframe: vaex.DataFrame, force_recompute: bool = False) -> vaex.DataFrame:
def select_features(
self, dataframe: vaex.DataFrame, cache_group: str | None = None, force_recompute: bool = False
) -> vaex.DataFrame:
"""Selects features based on invariance.
Args:
dataframe: The DataFrame to select features from.
cache_group: The cache group to use for caching.
force_recompute: Whether to force recompute the selected features.
Returns:
Expand All @@ -82,6 +85,7 @@ def select_features(self, dataframe: vaex.DataFrame, force_recompute: bool = Fal
self._fingerprint(),
(dataframe, VaexFingerprinter()),
],
cache_group=cache_group,
force_recompute=force_recompute,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,16 @@ def __init__(
super().__init__(cache_directory, features, ignore_features, cache_size)
self._missing_rate_threshold = missing_rate_threshold

def select_features(self, dataframe: vaex.DataFrame, force_recompute: bool = False) -> vaex.DataFrame:
def select_features(
self, dataframe: vaex.DataFrame, cache_group: str | None = None, force_recompute: bool = False
) -> vaex.DataFrame:
"""Selects features based on the missing rate.
Will cache the result of the feature selection.
Args:
dataframe: The DataFrame to select features from.
cache_group: The cache group to use.
force_recompute: Whether to force recompute the feature selection.
Returns:
Expand All @@ -84,6 +87,7 @@ def select_features(self, dataframe: vaex.DataFrame, force_recompute: bool = Fal
return self._cached_execute(
lambda_func=lambda: self._select_features(dataframe),
cache_keys=[self._fingerprint(), (dataframe, VaexFingerprinter())],
cache_group=cache_group,
force_recompute=force_recompute,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,14 @@ def __init__(
super().__init__(cache_directory, features, ignore_features, cache_size)
self._correlation_threshold = correlation_threshold

def select_features(self, dataframe: vaex.DataFrame, force_recompute: bool = False) -> vaex.DataFrame:
def select_features(
self, dataframe: vaex.DataFrame, cache_group: str | None = None, force_recompute: bool = False
) -> vaex.DataFrame:
"""Selects features based on the Pearson correlation.
Args:
dataframe: The DataFrame to select features from.
cache_group: The cache group to use.
force_recompute: Whether to force recompute the selected features.
Returns:
Expand All @@ -85,6 +88,7 @@ def select_features(self, dataframe: vaex.DataFrame, force_recompute: bool = Fal
return self._cached_execute(
lambda_func=lambda: self._select_features(dataframe),
cache_keys=[self._fingerprint(), (dataframe, VaexFingerprinter())],
cache_group=cache_group,
force_recompute=force_recompute,
)

Expand Down
6 changes: 5 additions & 1 deletion mleko/dataset/feature_select/variance_feature_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,14 @@ def __init__(
super().__init__(cache_directory, features, ignore_features, cache_size)
self._variance_threshold = variance_threshold

def select_features(self, dataframe: vaex.DataFrame, force_recompute: bool = False) -> vaex.DataFrame:
def select_features(
self, dataframe: vaex.DataFrame, cache_group: str | None = None, force_recompute: bool = False
) -> vaex.DataFrame:
"""Selects features based on the variance.
Args:
dataframe: The DataFrame to select features from.
cache_group: The cache group to use.
force_recompute: Whether to force recompute the selected features.
Returns:
Expand All @@ -88,6 +91,7 @@ def select_features(self, dataframe: vaex.DataFrame, force_recompute: bool = Fal
self._fingerprint(),
(dataframe, VaexFingerprinter()),
],
cache_group=cache_group,
force_recompute=force_recompute,
)

Expand Down
5 changes: 4 additions & 1 deletion mleko/dataset/split/base_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@ def __init__(self, cache_directory: str | Path, cache_size: int):
LRUCacheMixin.__init__(self, cache_directory, self._cache_file_suffix, cache_size)

@abstractmethod
def split(self, dataframe: vaex.DataFrame, force_recompute: bool = False) -> tuple[vaex.DataFrame, vaex.DataFrame]:
def split(
self, dataframe: vaex.DataFrame, cache_group: str | None = None, force_recompute: bool = False
) -> tuple[vaex.DataFrame, vaex.DataFrame]:
"""Abstract method to split the given dataframe into two parts.
Args:
dataframe: The dataframe to be split.
cache_group: The cache group to use.
force_recompute: Forces recomputation if True, otherwise reads from the cache if available.
Returns:
Expand Down
Loading

0 comments on commit 5fa8c9c

Please sign in to comment.