From add05d902b2718561388f568b661b28d2416bfd9 Mon Sep 17 00:00:00 2001 From: manuelhsantana Date: Tue, 2 Jan 2024 19:08:48 -0800 Subject: [PATCH] Added TensorDB module docstrings for API documentation --- openfl/databases/tensor_db.py | 74 ++++++++++++++++++++++++----------- 1 file changed, 52 insertions(+), 22 deletions(-) diff --git a/openfl/databases/tensor_db.py b/openfl/databases/tensor_db.py index 0045569d6a..79f8ac6d27 100644 --- a/openfl/databases/tensor_db.py +++ b/openfl/databases/tensor_db.py @@ -29,7 +29,7 @@ class TensorDB: """ def __init__(self) -> None: - """Initialize.""" + """Initializes a new instance of the TensorDB class.""" types_dict = { 'tensor_name': 'string', 'origin': 'string', @@ -46,8 +46,10 @@ def __init__(self) -> None: self.mutex = Lock() def _bind_convenience_methods(self): - # Bind convenience methods for TensorDB dataframe - # to make storage, retrieval, and search easier + """ + Bind convenience methods for the TensorDB dataframe + to make storage, retrieval, and search easier. + """ if not hasattr(self.tensor_db, 'store'): self.tensor_db.store = MethodType(_store, self.tensor_db) if not hasattr(self.tensor_db, 'retrieve'): @@ -56,17 +58,32 @@ def _bind_convenience_methods(self): self.tensor_db.search = MethodType(_search, self.tensor_db) def __repr__(self) -> str: - """Representation of the object.""" + """ + Returns the string representation of the TensorDB object. + + Returns: + content (str): The string representation of the TensorDB object. + """ with pd.option_context('display.max_rows', None): content = self.tensor_db[['tensor_name', 'origin', 'round', 'report', 'tags']] return f'TensorDB contents:\n{content}' def __str__(self) -> str: - """Printable string representation.""" + """ + Returns the string representation of the TensorDB object. + + Returns: + __repr__ (str): The string representation of the TensorDB object. 
+ """ return self.__repr__() def clean_up(self, remove_older_than: int = 1) -> None: - """Remove old entries from database preventing the db from becoming too large and slow.""" + """ + Removes old entries from the database to prevent it from becoming too large and slow. + + Args: + remove_older_than (int, optional): Entries older than this number of rounds are removed. Defaults to 1. + """ if remove_older_than < 0: # Getting a negative argument calls off cleaning return @@ -79,10 +96,11 @@ def clean_up(self, remove_older_than: int = 1) -> None: ].reset_index(drop=True) def cache_tensor(self, tensor_key_dict: Dict[TensorKey, np.ndarray]) -> None: - """Insert tensor into TensorDB (dataframe). + """ + Insert a tensor into TensorDB (dataframe). Args: - tensor_key_dict: The Tensor Key + tensor_key_dict (Dict[TensorKey, np.ndarray]): A dictionary where the key is a TensorKey and the value is a numpy array. Returns: None @@ -107,8 +125,11 @@ def get_tensor_from_cache(self, tensor_key: TensorKey) -> Optional[np.ndarray]: """ Perform a lookup of the tensor_key in the TensorDB. - Returns the nparray if it is available - Otherwise, it returns 'None' + Args: + tensor_key (TensorKey): The key of the tensor to look up. + + Returns: + Optional[np.ndarray]: The numpy array if it is available. Otherwise, returns None. """ tensor_name, origin, fl_round, report, tags = tensor_key @@ -129,20 +150,19 @@ def get_aggregated_tensor(self, tensor_key: TensorKey, collaborator_weight_dict: """ Determine whether all of the collaborator tensors are present for a given tensor key. - Returns their weighted average. - Args: - tensor_key: The tensor key to be resolved. If origin 'agg_uuid' is - present, can be returned directly. Otherwise must - compute weighted average of all collaborators - collaborator_weight_dict: List of collaborator names in federation - and their respective weights - aggregation_function: Call the underlying numpy aggregation - function. 
Default is just the weighted - average. + tensor_key (TensorKey): The tensor key to be resolved. If origin 'agg_uuid' is + present, can be returned directly. Otherwise must compute weighted + average of all collaborators. + collaborator_weight_dict (dict): A dictionary where the keys are collaborator + names and the values are their respective weights. + aggregation_function (AggregationFunction): The aggregation function used to + combine the collaborator tensors. Default is just the + weighted average. Returns: - weighted_nparray if all collaborator values are present - None if not all values are present + Optional[np.ndarray]: The weighted average if all + collaborator values are present. Returns None if not all + values are present. """ if len(collaborator_weight_dict) != 0: @@ -208,6 +228,16 @@ def get_aggregated_tensor(self, tensor_key: TensorKey, collaborator_weight_dict: return np.array(agg_nparray) def _iterate(self, order_by: str = 'round', ascending: bool = False) -> Iterator[pd.Series]: + """ + Returns an iterator over the rows of the TensorDB, sorted by a specified column. + + Args: + order_by (str, optional): The column to sort by. Defaults to 'round'. + ascending (bool, optional): Whether to sort in ascending order. Defaults to False. + + Returns: + Iterator[pd.Series]: An iterator over the rows of the TensorDB. + """ columns = ['round', 'nparray', 'tensor_name', 'tags'] rows = self.tensor_db[columns].sort_values(by=order_by, ascending=ascending).iterrows() for _, row in rows: