Skip to content

Commit

Permalink
Added TensorDB module docstrings for API documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelhsantana committed Jan 3, 2024
1 parent d5aa87b commit add05d9
Showing 1 changed file with 52 additions and 22 deletions.
74 changes: 52 additions & 22 deletions openfl/databases/tensor_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class TensorDB:
"""

def __init__(self) -> None:
"""Initialize."""
"""Initializes a new instance of the TensorDB class."""
types_dict = {
'tensor_name': 'string',
'origin': 'string',
Expand All @@ -46,8 +46,10 @@ def __init__(self) -> None:
self.mutex = Lock()

def _bind_convenience_methods(self):
# Bind convenience methods for TensorDB dataframe
# to make storage, retrieval, and search easier
"""
Bind convenience methods for the TensorDB dataframe
to make storage, retrieval, and search easier.
"""
if not hasattr(self.tensor_db, 'store'):
self.tensor_db.store = MethodType(_store, self.tensor_db)
if not hasattr(self.tensor_db, 'retrieve'):
Expand All @@ -56,17 +58,32 @@ def _bind_convenience_methods(self):
self.tensor_db.search = MethodType(_search, self.tensor_db)

def __repr__(self) -> str:
"""Representation of the object."""
"""
Returns the string representation of the TensorDB object.
Returns:
content (str): The string representation of the TensorDB object.
"""
with pd.option_context('display.max_rows', None):
content = self.tensor_db[['tensor_name', 'origin', 'round', 'report', 'tags']]
return f'TensorDB contents:\n{content}'

def __str__(self) -> str:
"""Printable string representation."""
"""
Returns the string representation of the TensorDB object.
Returns:
__repr__ (str): The string representation of the TensorDB object.
"""
return self.__repr__()

def clean_up(self, remove_older_than: int = 1) -> None:
"""Remove old entries from database preventing the db from becoming too large and slow."""
"""
Removes old entries from the database to prevent it from becoming too large and slow.
Args:
remove_older_than (int, optional): Entries older than this number of rounds are removed. Defaults to 1.
"""
if remove_older_than < 0:
# Getting a negative argument calls off cleaning
return
Expand All @@ -79,10 +96,11 @@ def clean_up(self, remove_older_than: int = 1) -> None:
].reset_index(drop=True)

def cache_tensor(self, tensor_key_dict: Dict[TensorKey, np.ndarray]) -> None:
"""Insert tensor into TensorDB (dataframe).
"""
Insert a tensor into TensorDB (dataframe).
Args:
tensor_key_dict: The Tensor Key
tensor_key_dict (Dict[TensorKey, np.ndarray]): A dictionary where the key is a TensorKey and the value is a numpy array.
Returns:
None
Expand All @@ -107,8 +125,11 @@ def get_tensor_from_cache(self, tensor_key: TensorKey) -> Optional[np.ndarray]:
"""
Perform a lookup of the tensor_key in the TensorDB.
Returns the nparray if it is available
Otherwise, it returns 'None'
Args:
tensor_key (TensorKey): The key of the tensor to look up.
Returns:
Optional[np.ndarray]: The numpy array if it is available. Otherwise, returns None.
"""
tensor_name, origin, fl_round, report, tags = tensor_key

Expand All @@ -129,20 +150,19 @@ def get_aggregated_tensor(self, tensor_key: TensorKey, collaborator_weight_dict:
"""
Determine whether all of the collaborator tensors are present for a given tensor key.
Returns their weighted average.
Args:
tensor_key: The tensor key to be resolved. If origin 'agg_uuid' is
present, can be returned directly. Otherwise must
compute weighted average of all collaborators
collaborator_weight_dict: List of collaborator names in federation
and their respective weights
aggregation_function: Call the underlying numpy aggregation
function. Default is just the weighted
average.
tensor_key (TensorKey): The tensor key to be resolved. If origin 'agg_uuid' is
present, can be returned directly. Otherwise must compute weighted
average of all collaborators.
collaborator_weight_dict (dict): A dictionary where the keys are collaborator
names and the values are their respective weights.
aggregation_function (AggregationFunction): Call the underlying numpy aggregation
function to use to compute the weighted average. Default is just the
weighted average.
Returns:
weighted_nparray if all collaborator values are present
None if not all values are present
agg_nparray Optional[np.ndarray]: weighted_nparray The weighted average if all
collaborator values are present. Otherwise, returns None.
None: if not all values are present.
"""
if len(collaborator_weight_dict) != 0:
Expand Down Expand Up @@ -208,6 +228,16 @@ def get_aggregated_tensor(self, tensor_key: TensorKey, collaborator_weight_dict:
return np.array(agg_nparray)

def _iterate(self, order_by: str = 'round', ascending: bool = False) -> Iterator[pd.Series]:
"""
Returns an iterator over the rows of the TensorDB, sorted by a specified column.
Args:
order_by (str, optional): The column to sort by. Defaults to 'round'.
ascending (bool, optional): Whether to sort in ascending order. Defaults to False.
Returns:
Iterator[pd.Series]: An iterator over the rows of the TensorDB.
"""
columns = ['round', 'nparray', 'tensor_name', 'tags']
rows = self.tensor_db[columns].sort_values(by=order_by, ascending=ascending).iterrows()
for _, row in rows:
Expand Down

0 comments on commit add05d9

Please sign in to comment.