Skip to content

Commit

Permalink
Improved numpy type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
bccheung committed May 29, 2023
1 parent 90695dc commit 84af3c6
Showing 1 changed file with 35 additions and 31 deletions.
66 changes: 35 additions & 31 deletions balsa/routines/matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,24 @@ def decorator(func):
return decorator
prange = range

try:
from numpy.typing import NDArray
except ImportError:
NDArray = np.ndarray

EPS = 1.0e-7


def matrix_balancing_1d(m: np.ndarray, a: np.ndarray, axis: int) -> np.ndarray:
def matrix_balancing_1d(m: NDArray, a: NDArray, axis: int) -> NDArray:
"""Balances a matrix using a single constraint.
Args:
m (numpy.ndarray): The matrix (a 2-dimensional ndarray) to be balanced
a (numpy.ndarray): The totals vector (a 1-dimensional ndarray) constraint
m (NDArray): The matrix (a 2-dimensional ndarray) to be balanced
a (NDArray): The totals vector (a 1-dimensional ndarray) constraint
axis (int): Direction to constrain (0 = along columns, 1 = along rows)
Return:
numpy.ndarray: A balanced matrix
NDArray: A balanced matrix
"""

assert axis in [0, 1], "axis must be either 0 or 1"
Expand All @@ -42,16 +47,16 @@ def matrix_balancing_1d(m: np.ndarray, a: np.ndarray, axis: int) -> np.ndarray:
return _balance(m, a, axis)


def matrix_balancing_2d(m: Union[np.ndarray, pd.DataFrame], a: np.ndarray, b: np.ndarray, *,
totals_to_use: str = 'raise', max_iterations: int = 1000, rel_error: float = 0.0001,
n_procs: int = 1) -> Tuple[Union[np.ndarray, pd.DataFrame], float, int]:
def matrix_balancing_2d(m: Union[NDArray, pd.DataFrame], a: NDArray, b: NDArray, *, totals_to_use: str = 'raise',
max_iterations: int = 1000, rel_error: float = 0.0001,
n_threads: int = 1) -> Tuple[Union[NDArray, pd.DataFrame], float, int]:
"""Balances a two-dimensional matrix using iterative proportional fitting.
Args:
m (numpy.ndarray | pandas.DataFrame): The matrix (a 2-dimensional ndarray) to be balanced. If a DataFrame
m (NDArray | pandas.DataFrame): The matrix (a 2-dimensional ndarray) to be balanced. If a DataFrame
is supplied, the output will be returned as a DataFrame.
a (numpy.ndarray): The row totals (a 1-dimensional ndarray) to use for balancing
b (numpy.ndarray): The column totals (a 1-dimensional ndarray) to use for balancing
a (NDArray): The row totals (a 1-dimensional ndarray) to use for balancing
b (NDArray): The column totals (a 1-dimensional ndarray) to use for balancing
totals_to_use (str, optional): Defaults to ``'raise'``. Describes how to scale the row and column totals if
their sums do not match. Must be one of ['rows', 'columns', 'average', 'raise'].
- rows: scales the column totals so that their sum matches the sum of the row totals
Expand All @@ -60,13 +65,13 @@ def matrix_balancing_2d(m: Union[np.ndarray, pd.DataFrame], a: np.ndarray, b: np
- raise: raises an Exception if the sums of the row and column totals do not match
max_iterations (int, optional): Defaults to ``1000``. Maximum number of iterations
rel_error (float, optional): Defaults to ``1.0E-4``. Relative error stopping criterion
n_procs (int, optional): Defaults to ``1``. Number of processors for parallel computation. (Not used)
n_threads (int, optional): Defaults to ``1``. Number of threads for parallel computation. (Not used; values greater than 1 raise NotImplementedError)
Return:
Tuple[numpy.ndarray | pandas.DataFrame, float, int]: The balanced matrix, residual, and n_iterations
Tuple[NDArray | pandas.DataFrame, float, int]: The balanced matrix, residual, and n_iterations
"""
max_iterations = int(max_iterations)
n_procs = int(n_procs)
n_threads = int(n_threads)

# Test if matrix is Pandas DataFrame
data_type = ''
Expand All @@ -87,7 +92,7 @@ def matrix_balancing_2d(m: Union[np.ndarray, pd.DataFrame], a: np.ndarray, b: np
# - totals_to_use is one of ['rows', 'columns', 'average']
# - the max_iterations is a +'ve integer
# - rel_error is a +'ve float between 0 and 1
# - the n_procs is a +'ve integer between 1 and the number of available processors
# - the n_threads is a +'ve integer between 1 and the number of available processors
# ##################################################################################
valid_totals_to_use = ['rows', 'columns', 'average', 'raise']
assert m.ndim == 2 and m.shape[0] == m.shape[1], "m must be a two-dimensional square matrix"
Expand All @@ -98,9 +103,9 @@ def matrix_balancing_2d(m: Union[np.ndarray, pd.DataFrame], a: np.ndarray, b: np
assert totals_to_use in valid_totals_to_use, "totals_to_use must be one of %s" % valid_totals_to_use
assert max_iterations >= 1, "max_iterations must be integer >= 1"
assert 0 < rel_error < 1.0, "rel_error must be float between 0.0 and 1.0"
assert 1 <= n_procs <= cpu_count(), \
"n_procs must be integer between 1 and the number of processors (%d) " % cpu_count()
if n_procs > 1:
assert 1 <= n_threads <= cpu_count(), \
"n_threads must be integer between 1 and the number of processors (%d) " % cpu_count()
if n_threads > 1:
raise NotImplementedError("Multiprocessing capability is not implemented yet.")

# Scale row and column totals, if required
Expand Down Expand Up @@ -137,16 +142,16 @@ def matrix_balancing_2d(m: Union[np.ndarray, pd.DataFrame], a: np.ndarray, b: np
return m, err, i


def _balance(matrix: np.ndarray, tot: np.ndarray, axis: int) -> np.ndarray:
def _balance(matrix: NDArray, tot: NDArray, axis: int) -> NDArray:
"""Balances a matrix using a single constraint.
Args:
matrix (numpy.ndarray): The matrix to be balanced
tot (numpy.ndarray): The totals constraint
matrix (NDArray): The matrix to be balanced
tot (NDArray): The totals constraint
axis (int): Direction to constrain (0 = along columns, 1 = along rows)
Return:
numpy.ndarray: The balanced matrix
NDArray: The balanced matrix
"""
sc = tot / (matrix.sum(axis) + EPS)
sc = np.nan_to_num(sc) # replace divide by 0 errors from the prev. line
Expand Down Expand Up @@ -176,17 +181,16 @@ def _nbf_bucket_round(a_, decimals=0):
return b.reshape(a_.shape)


def matrix_bucket_rounding(m: Union[np.ndarray, pd.DataFrame], *,
decimals: int = 0) -> Union[np.ndarray, pd.DataFrame]:
def matrix_bucket_rounding(m: Union[NDArray, pd.DataFrame], *, decimals: int = 0) -> Union[NDArray, pd.DataFrame]:
"""Bucket rounds to the given number of decimals.
Args:
m (numpy.ndarray | pandas.DataFrame): The matrix to be rounded
m (NDArray | pandas.DataFrame): The matrix to be rounded
decimals (int, optional): Defaults to ``0``. Number of decimal places to round to. If decimals is negative, it
specifies the number of positions to the left of the decimal point.
Return:
numpy.ndarray | pandas.DataFrame: The rounded matrix
NDArray | pandas.DataFrame: The rounded matrix
"""

# Test if matrix is Pandas DataFrame
Expand Down Expand Up @@ -284,23 +288,23 @@ def split_zone_in_matrix(base_matrix: pd.DataFrame, old_zone: int, new_zones: Li
return new_matrix


def aggregate_matrix(matrix: Union[pd.DataFrame, pd.Series], *, groups: Union[pd.Series, np.ndarray] = None,
row_groups: Union[pd.Series, np.ndarray] = None, col_groups: Union[pd.Series, np.ndarray] = None,
def aggregate_matrix(matrix: Union[pd.DataFrame, pd.Series], *, groups: Union[pd.Series, NDArray] = None,
row_groups: Union[pd.Series, NDArray] = None, col_groups: Union[pd.Series, NDArray] = None,
aggfunc: Callable[[Iterable[Union[int, float]]], Union[int, float]] = np.sum
) -> Union[pd.DataFrame, pd.Series]:
"""Aggregates a matrix based on mappings provided for each axis, using a specified aggregation function.
Args:
matrix (pandas.DataFrame | pandas.Series): Matrix data to aggregate. DataFrames and Series with 2-level
indices are supported
groups (pandas.Series | numpy.ndarray, optional): Syntactic sugar to specify both row_groups and
groups (pandas.Series | NDArray, optional): Syntactic sugar to specify both row_groups and
col_groups to use the same grouping series.
row_groups (pandas.Series | numpy.ndarray, optional): Groups for the rows. If aggregating a DataFrame,
row_groups (pandas.Series | NDArray, optional): Groups for the rows. If aggregating a DataFrame,
this must match the index of the matrix. For a "tall" matrix, this series can match either the "full" index
of the series, or it can match the first level of the matrix (it would be the same as if aggregating a
DataFrame). Alternatively, an array can be provided, but it must be the same length as the DataFrame's
index, or the full length of the Series.
col_groups (pandas.Series | numpy.ndarray, optional): Groups for the columns. If aggregating a DataFrame,
col_groups (pandas.Series | NDArray, optional): Groups for the columns. If aggregating a DataFrame,
this must match the columns of the matrix. For a "tall" matrix, this series can match either the "full"
index of the series, or it can match the second level of the matrix (it would be the same as if aggregating
a DataFrame). Alternatively, an array can be provided, but it must be the same length as the DataFrame's
Expand Down Expand Up @@ -497,7 +501,7 @@ def fast_unstack(series: pd.Series, index: pd.Index, columns: pd.Index, *, deep_
return pd.DataFrame(array, index=index, columns=columns)


def _check_disaggregation_input(mapping: pd.Series, proportions: pd.Series) -> np.ndarray:
def _check_disaggregation_input(mapping: pd.Series, proportions: pd.Series) -> NDArray:
assert mapping is not None
assert proportions is not None
assert mapping.index.equals(proportions.index)
Expand Down

0 comments on commit 84af3c6

Please sign in to comment.