From adb73c1d3c22103fc98864f924482ff9264005d5 Mon Sep 17 00:00:00 2001 From: Aadyot Bhatnagar Date: Wed, 2 Mar 2022 16:38:33 -0800 Subject: [PATCH] Implement reconciliation for hierarchical time series. (#72) * Fix some AutoSarima bugs. * Harden models to granularities like MS * Add RMSPE forecasting eval metric. * Implement min-trace reconciliation. * Fix bug for seasonality models on multivar data. * Add test for minT reconciliation. * Update docs. * Fix computation of covariance matrix. * Update version. * Add data I/O utils for hierarchical time series. * Add merlion.utils.data_io to docs. * Add data I/O test. --- docs/source/merlion.rst | 2 + docs/source/merlion.utils.rst | 57 ++++-- .../misc/generate_synthetic_tsad_dataset.py | 7 +- merlion/evaluate/forecast.py | 23 +++ merlion/models/automl/autosarima.py | 2 +- merlion/models/automl/seasonality.py | 4 +- merlion/models/ensemble/base.py | 5 +- merlion/models/forecast/trees.py | 2 +- merlion/models/layers.py | 2 +- merlion/utils/autosarima_utils.py | 5 +- merlion/utils/data_io.py | 175 ++++++++++++++++++ merlion/utils/hts.py | 91 +++++++++ merlion/utils/istat.py | 5 +- merlion/utils/misc.py | 12 +- merlion/utils/resample.py | 6 +- merlion/utils/time_series.py | 61 ++---- merlion/utils/ts_generator.py | 22 ++- setup.py | 2 +- tests/anomaly/forecast_based/test_arima.py | 7 +- tests/anomaly/forecast_based/test_lstm.py | 7 +- tests/anomaly/forecast_based/test_mses.py | 7 +- tests/anomaly/forecast_based/test_prophet.py | 5 +- tests/anomaly/forecast_based/test_sarima.py | 7 +- tests/anomaly/test_anom_ensemble.py | 9 +- tests/anomaly/test_isolation_forest.py | 6 +- tests/anomaly/test_random_cut_forest.py | 5 +- tests/anomaly/test_spectral_residual.py | 6 +- tests/anomaly/test_stat_threshold.py | 6 +- tests/anomaly/test_windstats.py | 6 +- tests/anomaly/test_zms.py | 6 +- tests/evaluate/test_eval_forecast.py | 7 +- tests/forecast/test_ets.py | 6 +- tests/forecast/test_forecast_ensemble.py | 6 +- tests/forecast/test_smoother.py | 9 +- tests/test_hts.py | 118 ++++++++++++ tests/test_plot.py | 6 +- 36 files changed, 572 insertions(+), 140 deletions(-) create mode 100644 merlion/utils/data_io.py create mode 100644 merlion/utils/hts.py create mode 100644 tests/test_hts.py diff --git a/docs/source/merlion.rst b/docs/source/merlion.rst index 70ef5aa1e..a95346c8d 100644 --- a/docs/source/merlion.rst +++ b/docs/source/merlion.rst @@ -31,6 +31,8 @@ each associated with its own sub-package: - :py:mod:`merlion.evaluate`: Evaluation metrics & pipelines to simulate the live deployment of a time series model for any task. - :py:mod:`merlion.plot`: Automated visualization of model outputs for univariate time series +- :py:mod:`merlion.utils`: Various utilities, including the `TimeSeries` class, resampling functions, + Bayesian conjugate priors, reconciliation for hierarchical time series, and more. The key classes for input and output are `merlion.utils.time_series.TimeSeries` and `merlion.utils.time_series.UnivariateTimeSeries`. Notably, these classes have transparent inter-operability diff --git a/docs/source/merlion.utils.rst b/docs/source/merlion.utils.rst index d574e411c..015c01ead 100644 --- a/docs/source/merlion.utils.rst +++ b/docs/source/merlion.utils.rst @@ -1,3 +1,4 @@ + merlion.utils package ===================== This package contains various utilities, including the `TimeSeries` class and @@ -8,44 +9,70 @@ utilities for resampling time series. :undoc-members: :show-inheritance: +.. autosummary:: + time_series + resample + data_io + hts + ts_generator + conj_priors + istat + Submodules ---------- -merlion.utils.conj_priors module --------------------------------- -.. automodule:: merlion.utils.conj_priors +merlion.utils.time\_series module +--------------------------------- + +.. automodule:: merlion.utils.time_series :members: :undoc-members: :show-inheritance: -merlion.utils.istat module --------------------------- +merlion.utils.resample module +----------------------------- -.. automodule:: merlion.utils.istat +.. automodule:: merlion.utils.resample :members: :undoc-members: :show-inheritance: -merlion.utils.misc module -------------------------- +merlion.utils.data\_io module +----------------------------- -.. automodule:: merlion.utils.misc +.. automodule:: merlion.utils.data_io :members: :undoc-members: :show-inheritance: -merlion.utils.resample module ------------------------------ -.. automodule:: merlion.utils.resample +merlion.utils.hts module +------------------------ + +.. automodule:: merlion.utils.hts :members: :undoc-members: :show-inheritance: -merlion.utils.time\_series module ---------------------------------- +merlion.utils.ts\_generator module +---------------------------------- -.. automodule:: merlion.utils.time_series +.. automodule:: merlion.utils.ts_generator + :members: + :undoc-members: + :show-inheritance: + +merlion.utils.conj_priors module +-------------------------------- +.. automodule:: merlion.utils.conj_priors + :members: + :undoc-members: + :show-inheritance: + +merlion.utils.istat module +-------------------------- + +.. automodule:: merlion.utils.istat :members: :undoc-members: :show-inheritance: diff --git a/examples/misc/generate_synthetic_tsad_dataset.py b/examples/misc/generate_synthetic_tsad_dataset.py index 0869e1868..b0f53ab3f 100644 --- a/examples/misc/generate_synthetic_tsad_dataset.py +++ b/examples/misc/generate_synthetic_tsad_dataset.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -11,7 +11,6 @@ import numpy as np from math import floor, ceil -from merlion.utils.time_series import ts_to_csv from merlion.utils.ts_generator import GeneratorConcatenator, TimeSeriesGenerator from merlion.transform.anomalize import LevelShift, Shock, TrendChange @@ -64,13 +63,13 @@ def main(): for i, ts in enumerate(ts_list): # write original ts csv = join(anom_dir, f"{ts.names[0]}.csv") - ts_to_csv(ts, csv) + ts.to_csv(csv) # anomalize ts with each anomalizer for j, (name, anom) in enumerate(anomalizers.items()): np.random.seed(1000 * i + j) anom_ts = anom(ts) csv = join(anom_dir, f"{anom_ts.names[0]}_{name}_anomaly.csv") - ts_to_csv(anom_ts, csv) + anom_ts.to_csv(csv) if __name__ == "__main__": diff --git a/merlion/evaluate/forecast.py b/merlion/evaluate/forecast.py index 9ede2bc0f..20d98c9ba 100644 --- a/merlion/evaluate/forecast.py +++ b/merlion/evaluate/forecast.py @@ -130,6 +130,23 @@ def smape(self): warnings.warn("Some values very close to 0, sMAPE might not be estimated accurately.") return np.mean(200.0 * errors / (scale + 1e-8)) + def rmspe(self): + """ + Root Mean Squared Percent Error (RMSPE) + + For ground truth time series :math:`y` and predicted time series :math:`\\hat{y}` + of length :math:`T`, it is computed as + + .. math:: 100 \\cdot \\sqrt{\\frac{1}{T}\\sum_{t=1}^T\\frac{(y_t - \\hat{y}_t)}{y_t}^2}. + """ + self.check_before_eval() + predict_values = self.predict.univariates[self.predict.names[0]].np_values + ground_truth_values = self.ground_truth.univariates[self.ground_truth.names[0]].np_values + if (ground_truth_values < 1e-8).any(): + warnings.warn("Some values very close to 0, RMSPE might not be estimated accurately.") + errors = ground_truth_values - predict_values + return 100 * np.sqrt(np.mean(np.square(errors / ground_truth_values))) + def mase(self): """ Mean Absolute Scaled Error (MASE) @@ -240,6 +257,12 @@ class ForecastMetric(Enum): 200 \\cdot \\frac{1}{T}\\sum_{t=1}^{T}{\\frac{\\left| y_t - \\hat{y}_t \\right|}{\\left| y_t \\right| + \\left| \\hat{y}_t \\right|}}. """ + RMSPE = partial(accumulate_forecast_score, metric=ForecastScoreAccumulator.rmspe) + """ + Root Mean Square Percent Error is formulated as: + + .. math:: 100 \\cdot \\sqrt{\\frac{1}{T}\\sum_{t=1}^T\\frac{(y_t - \\hat{y}_t)}{y_t}^2}. + """ MASE = partial(accumulate_forecast_score, metric=ForecastScoreAccumulator.mase) """ Mean Absolute Scaled Error (MASE) is formulated as: diff --git a/merlion/models/automl/autosarima.py b/merlion/models/automl/autosarima.py index 5b1518911..42b3ab938 100644 --- a/merlion/models/automl/autosarima.py +++ b/merlion/models/automl/autosarima.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause diff --git a/merlion/models/automl/seasonality.py b/merlion/models/automl/seasonality.py index 4c74fb3fa..27c52da44 100644 --- a/merlion/models/automl/seasonality.py +++ b/merlion/models/automl/seasonality.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -128,7 +128,7 @@ class SeasonalityLayer(AutoMLMixIn, metaclass=AutodocABCMeta): @property def require_univariate(self): - return getattr(self.config, "target_seq_index", None) is not None + return getattr(self.config, "target_seq_index", None) is None @property def multi_seasonality(self): diff --git a/merlion/models/ensemble/base.py b/merlion/models/ensemble/base.py index 1e58eabe6..f683266dc 100644 --- a/merlion/models/ensemble/base.py +++ b/merlion/models/ensemble/base.py @@ -59,7 +59,10 @@ def to_dict(self, _skipped_keys=None): if self.models is None: models = None else: - models = [None if m is None else dict(name=type(m).__name__, **m.config.to_dict()) for m in self.models] + models = [ + None if m is None else dict(name=type(m).__name__, **m.config.to_dict(_skipped_keys)) + for m in self.models + ] config_dict["models"] = models return config_dict diff --git a/merlion/models/forecast/trees.py b/merlion/models/forecast/trees.py index 92d304544..0df23cdeb 100644 --- a/merlion/models/forecast/trees.py +++ b/merlion/models/forecast/trees.py @@ -106,7 +106,7 @@ def train(self, train_data: TimeSeries, train_config=None): if self.dim == 1: logger.info( f"Model is working on a univariate dataset, " - f"hybrid of sequence and autoregression training strategy will be adopted" + f"hybrid of sequence and autoregression training strategy will be adopted " f"with prediction_stride = {self.prediction_stride} " ) if self.sampling_mode != "normal": diff --git a/merlion/models/layers.py b/merlion/models/layers.py index 3d240cafd..fc421dc72 100644 --- a/merlion/models/layers.py +++ b/merlion/models/layers.py @@ -73,7 +73,7 @@ def to_dict(self, _skipped_keys=None): if self.model is None: config_dict["model"] = None else: - config_dict["model"] = dict(name=type(self.model).__name__, **self.model.config.to_dict()) + config_dict["model"] = dict(name=type(self.model).__name__, **self.model.config.to_dict(_skipped_keys)) return config_dict @classmethod diff --git a/merlion/utils/autosarima_utils.py b/merlion/utils/autosarima_utils.py index c6bc23650..86d37b39b 100644 --- a/merlion/utils/autosarima_utils.py +++ b/merlion/utils/autosarima_utils.py @@ -1,9 +1,12 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause # +""" +Low-level utils for AutoML models. +""" import functools import logging import time diff --git a/merlion/utils/data_io.py b/merlion/utils/data_io.py new file mode 100644 index 000000000..120f90403 --- /dev/null +++ b/merlion/utils/data_io.py @@ -0,0 +1,175 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +""" +Utils for data I/O. +""" +from collections import OrderedDict +import inspect +from typing import Any, Dict, List, Mapping, Union + +import numpy as np +import pandas as pd + +from merlion.utils.misc import combine_signatures, parse_basic_docstring +from merlion.utils.time_series import TimeSeries + + +def df_to_time_series( + df: pd.DataFrame, + time_col: str = None, + timestamp_unit="s", + index_cols: Union[str, List[str]] = None, + data_cols: Union[str, List[str]] = None, + index_conditions: Dict[str, Any] = None, + index_agg_average=False, +) -> TimeSeries: + """ + Converts a general ``pandas.DataFrame`` to a `TimeSeries` object. + + This function allows a user to specify a hierarchical index to be aggregated over time. For example, the + dataframe may contain sales volume for multiple different stores & items, differentiated by columns ``"store_id"`` + and ``"item_id"``. In this case, you should specify ``index_cols=["store_id", "item_id"]``. + + By default, we take a simple sum of all distinct values at each timestamp at each level of the hierarchical + index. However, you may customize this behavior by specifying ``index_conditions``. Here are some examples: + + - ``{"store_id": {"vals": [1, 8]}}`` takes the total of all sales for all items in stores 1 & 8 only. + - ``{"store_id": {"vals": [1, 8], "weights": [0.5, 1.3]}, "item_id": {"weights": "price"}}`` uses the + column ``"price"`` to weight the sales of each item before summing them up (aka revenue). Then, we + weight the price-weighted sales in store 1 by 0.5 & the price-weighted sales in store 8 by 1.3, + before summing them together to obtain a total for each timestamp. + + If ``index_conditions`` is not specified for a particular non-temporal index column, we take a simple sum + of all distinct values. This is also true if ``"weights"`` is not specified for a particular index key. + + :param df: the dataframe to parse + :param time_col: the name of the column specifying time. If none is specified, the existing index is used if it + is a ``DatetimeIndex``. Otherwise, the first column is used.. + :param timestamp_unit: if the time column is in Unix timestamps, this is the unit of the timestamp. + :param index_cols: the columns to be interpreted as a hierarchical index, if desired. + :param data_cols: the columns representing the actual data values of interest. + :param index_conditions: a dict specifying how the hierarchical index should be aggregated. + :param index_agg_average: aggregate with (weighted) average if ``True``, (weighted) sum if ``False``. + """ + # Get the index columns + if index_cols is None: + index_cols = [] + elif not isinstance(index_cols, (list, tuple)): + index_cols = [index_cols] + if not all(c in df.columns for c in index_cols): + raise KeyError(f"Expected each of index_cols to be in {df.columns}. Got {index_cols}.") + + # Set up a hierarchical index for the dataframe, with the timestamp first + if time_col is None and isinstance(df.index, pd.DatetimeIndex): + df = df.set_index([df.index] + index_cols).sort_index() + if df.index.names[0] is None: + df.index.set_names("time", level=0, inplace=True) + time_col = df.index.names[0] + else: + if time_col is None: + time_col = df.columns[0] + elif time_col not in df.columns: + raise KeyError(f"Expected time_col to be in {df.columns}. Got {time_col}.") + df[time_col] = pd.to_datetime(df[time_col], unit=None if df[time_col].dtype == "O" else timestamp_unit) + df = df.set_index([time_col] + index_cols).sort_index() + + # Determine the values & weights used to restrict & aggregate the dataframe + vals_seq = [slice(None)] + weights = pd.Series(1.0, index=df.index) + index_conditions = index_conditions or {} + for c in index_cols: + cond = index_conditions.get(c, {}) + if not isinstance(cond, Mapping): + cond = {"vals": cond} + + # Determine if we're restricting the dataframe + vals = cond.get("vals", None) + if vals is None: + vals_seq.append(slice(None)) + vals = df.index.get_level_values(c).unique() + else: + vals_seq.append(vals) + + # Get the weights for this level of the aggregation + w = cond.get("weights", None) + if w is not None and isinstance(w, str): + weights *= df[w] + elif w is not None: + all_vals = df.index.get_level_values(c) + if len(w) != len(vals): + raise ValueError(f"For index column {c}, expected weights of length {len(vals)}. Got {len(w)}.") + w = pd.concat((pd.Series(w, index=vals), pd.Series(0, index=all_vals.unique().difference(vals)))) + weights *= w.loc[all_vals].values + + # Get only the desired columns from the dataframe + if data_cols is not None: + data_cols = [data_cols] if not isinstance(data_cols, (list, tuple)) else data_cols + if not all(c in df.columns for c in data_cols): + raise KeyError(f"Expected each of target_cols to be in {df.colums}. Got {data_cols}.") + df = df[data_cols] + + # Restrict & aggregate the dataframe + if len(index_cols) > 0: + ilocs = df.index.get_locs(vals_seq) + df = df.iloc[ilocs] + weights = weights.iloc[ilocs] + if index_agg_average: + df = df.groupby(time_col).agg(lambda x: np.average(x, weights=weights.loc[x.index])) + else: + df = (df * weights.values.reshape(-1, 1)).groupby(time_col).sum() + + # Convert the dataframe to a time series & return it + return TimeSeries.from_pd(df) + + +def data_io_decorator(func): + """ + Decorator to standardize docstrings for data I/O functions. + """ + + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + # Parse the docstrings of the base df_to_time_series function & decorated function. + prefix, suffix, params = parse_basic_docstring(func.__doc__) + base_prefix, base_suffix, base_params = parse_basic_docstring(df_to_time_series.__doc__) + + # Combine the prefixes. Base prefix starts after the first line break. + i_lb = [i for i, line in enumerate(base_prefix) if line == ""][1] + prefix = ("\n".join(prefix) if any([line != "" for line in prefix]) else "") + "\n".join(base_prefix[i_lb:]) + + # The base docstring has no suffix, so just use the function's + suffix = "\n".join(suffix) if any([line != "" for line in suffix]) else "" + + # Combine the parameter lists + for param, docstring_lines in base_params.items(): + if param not in params: + params[param] = "\n".join(docstring_lines).rstrip("\n") + + # Combine the signatures, but remove some parameters that are specific to the original (as well as kwargs). + new_sig_params = [] + sig = combine_signatures(inspect.signature(func), inspect.signature(df_to_time_series)) + for param in sig.parameters.values(): + if param.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}: + break + if param.name not in ["df"]: + new_sig_params.append(param) + sig = sig.replace(parameters=new_sig_params) + + # Update the signature and docstring of the wrapper we are returning. Use only the params in the new signature. + wrapper.__signature__ = sig + params = OrderedDict((p, params[p]) for p in sig.parameters if p in params) + wrapper.__doc__ = (prefix or "") + "\n" + "\n".join(params.values()) + "\n\n" + (suffix or "") + return wrapper + + +@data_io_decorator +def csv_to_time_series(file_name: str, **kwargs) -> TimeSeries: + """ + Reads a CSV file and converts it to a `TimeSeries` object. + """ + return df_to_time_series(pd.read_csv(file_name), **kwargs) diff --git a/merlion/utils/hts.py b/merlion/utils/hts.py new file mode 100644 index 000000000..3901e0451 --- /dev/null +++ b/merlion/utils/hts.py @@ -0,0 +1,91 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +""" +Aggregation for hierarchical time series. +""" +from collections import OrderedDict +from typing import List + +import numpy as np +import pandas as pd + +from merlion.utils.time_series import TimeSeries, to_pd_datetime + + +def minT_reconciliation( + forecasts: List[TimeSeries], errs: List[TimeSeries], sum_matrix: np.ndarray, n_leaves: int +) -> List[TimeSeries]: + """ + Computes the minimum trace reconciliation for hierarchical time series, as described by + `Wickramasuriya et al. 2018 `__. This algorithm assumes that + we have a number of time series aggregated at various levels (the aggregation tree is described by ``sum_matrix``), + and we obtain independent forecasts at each level of the hierarchy. Minimum trace reconciliation finds the optimal + way to adjust (reconcile) the forecasts to reduce the variance of the estimation. + + :param forecasts: forecast for each aggregation level of the hierarchy + :param errs: standard errors of forecasts for each level of the hierarchy. While not strictly necessary, + reconciliation performs better if all forecasts are accompanied by uncertainty estimates. + :param sum_matrix: matrix describing how the hierarchy is aggregated + :param n_leaves: the number of leaf forecasts (i.e. the number of forecasts at the most dis-aggregated level + of the hierarchy). We assume that the leaf forecasts are last in the lists ``forecasts`` & ``errs``, + and that ``sum_matrix`` reflects this fact. + + :return: reconciled forecasts for each aggregation level of the hierarchy + """ + m = len(forecasts) + n = n_leaves + assert len(errs) == m > n + assert all(yhat.dim == 1 for yhat in forecasts) + assert sum_matrix.shape == (m, n), f"Expected sum_matrix to have shape ({m}, {n}) got {sum_matrix.shape}" + assert (sum_matrix[-n:] == np.eye(n)).all() + + # Convert forecasts to a single aligned multivariate time series + names = [yhat.names[0] for yhat in forecasts] + forecasts = OrderedDict((i, yhat.univariates[yhat.names[0]]) for i, yhat in enumerate(forecasts)) + forecasts = TimeSeries(univariates=forecasts).align() + t_ref = forecasts.time_stamps + H = len(forecasts) + + # Matrix of stderrs (if any) at each prediction horizon. shape is [m, H]. + # If no stderrs are given, we the estimation error is proportional to the number of leaf nodes being combined. + coefs = sum_matrix.sum(axis=1) + if all(e is None for e in errs): + # FIXME: This heuristic can be improved if training errors are given. + # However, the model code should probably be responsible for this, not the reconciliation code. + Wh = [np.diag(coefs) for _ in range(H)] + else: + coefs = coefs.reshape(-1, 1) + errs = np.asarray( + [np.full(H, np.nan) if e is None else e.align(reference=t_ref).to_pd().values.flatten() ** 2 for e in errs] + ) # [m, H] + # Replace NaN's w/ the mean of non-NaN stderrs & create diagonal error matrices + nan_errs = np.isnan(errs[:, 0]) + if nan_errs.any(): + errs[nan_errs] = np.nanmean(errs / coefs, axis=0) * coefs[nan_errs] + Wh = [np.diag(errs[:, h]) for h in range(H)] + + # Create other supplementary matrices + J = np.zeros((n, m)) + J[:, -n:] = np.eye(n) + U = np.zeros((m - n, m)) + U[:, : m - n] = np.eye(m - n) + U[:, m - n :] = -sum_matrix[:-n] + + # Compute projection matrices to compute coherent leaf forecasts + Ph = [] + for W in Wh: + inv = np.linalg.inv(U @ W @ U.T) + P = J - ((J @ W) @ U.T) @ (inv @ U) + Ph.append(P) + + # Compute reconciled forecasts + reconciled = [] + for (t, yhat_h), P in zip(forecasts, Ph): + reconciled.append(sum_matrix @ (P @ yhat_h)) + reconciled = pd.DataFrame(np.asarray(reconciled), index=to_pd_datetime(t_ref)) + + return [u.to_ts(name=name) for u, name in zip(TimeSeries.from_pd(reconciled).univariates, names)] diff --git a/merlion/utils/istat.py b/merlion/utils/istat.py index 6b50bfcc0..cb792e460 100644 --- a/merlion/utils/istat.py +++ b/merlion/utils/istat.py @@ -1,9 +1,12 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause # +""" +Incremental computation of time series statistics. +""" from abc import abstractmethod from typing import List from math import sqrt diff --git a/merlion/utils/misc.py b/merlion/utils/misc.py index 58e8a7899..7829a62ac 100644 --- a/merlion/utils/misc.py +++ b/merlion/utils/misc.py @@ -1,9 +1,12 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause # +""" +Miscellaneous low-level utilities (not for end users). +""" from abc import ABCMeta from collections import OrderedDict from copy import deepcopy @@ -63,7 +66,7 @@ def __new__(mcs, classname, bases, cls_dict): sig = combine_signatures(sig, inspect.signature(cls_.__init__)) # Parse the __init__ docstring. Use the earliest prefix/param docstring in the MRO. - prefix_, suffix_, params_ = parse_init_docstring(cls_.__init__.__doc__) + prefix_, suffix_, params_ = parse_basic_docstring(cls_.__init__.__doc__) if prefix is None and any([line != "" for line in prefix_]): prefix = "\n".join(prefix_) if suffix is None and any([line != "" for line in suffix_]): @@ -108,7 +111,10 @@ def combine_signatures(sig1: Union[inspect.Signature, None], sig2: Union[inspect return sig1.replace(parameters=params) -def parse_init_docstring(docstring): +def parse_basic_docstring(docstring): + """ + Parse the docstring of a model config's ``__init__``, or other basic docstring. + """ docstring_lines = [""] if docstring is None else docstring.split("\n") prefix, suffix, param_dict = [], [], OrderedDict() non_empty_lines = [line for line in docstring_lines if len(line) > 0] diff --git a/merlion/utils/resample.py b/merlion/utils/resample.py index 1e19ec915..fd5a80261 100644 --- a/merlion/utils/resample.py +++ b/merlion/utils/resample.py @@ -4,6 +4,9 @@ # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause # +""" +Code for resampling time series. +""" from enum import Enum from functools import partial import logging @@ -58,8 +61,7 @@ class MissingValuePolicy(Enum): def to_pd_datetime(timestamp): """ - Converts a timestamp (or list/iterable of timestamps) to pandas Datetime, - truncated at the millisecond. + Converts a timestamp (or list/iterable of timestamps) to pandas Datetime, truncated at the millisecond. """ if isinstance(timestamp, pd.DatetimeIndex): return timestamp diff --git a/merlion/utils/time_series.py b/merlion/utils/time_series.py index 35d152cdf..dadb1fd48 100644 --- a/merlion/utils/time_series.py +++ b/merlion/utils/time_series.py @@ -1,9 +1,12 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause # +""" +Implementation of `TimeSeries` class. +""" from bisect import bisect_left, bisect_right import itertools import logging @@ -26,6 +29,7 @@ ) logger = logging.getLogger(__name__) +_time_col_name = "time" class UnivariateTimeSeries(pd.Series): @@ -110,6 +114,7 @@ def __init__( super().__init__(np.asarray(values), index=index, name=name, dtype=float) if len(self) >= 3 and self.index.freq is None: self.index.freq = pd.infer_freq(self.index) + self.index.name = _time_col_name @property def np_time_stamps(self): @@ -293,15 +298,17 @@ def from_pd(cls, series: pd.Series, name=None, freq="1h"): """ return cls(time_stamps=None, values=series.astype(float), name=name, freq=freq) - def to_ts(self): + def to_ts(self, name=None): """ + :name: a name to assign the univariate when converting it to a time series. Can override the existing name. :rtype: TimeSeries :return: A `TimeSeries` representing this univariate time series. """ - if self.name is None: + if self.name is None and name is None: return TimeSeries([self]) else: - return TimeSeries({self.name: self}) + name = name if self.name is None else self.name + return TimeSeries({name: self}) @classmethod def empty(cls, name=None): @@ -709,11 +716,15 @@ def to_pd(self) -> pd.DataFrame: for _, var in univariates: t = t.union(var.index) t = t.sort_values() + t.name = _time_col_name df = pd.DataFrame(np.full((len(t), len(univariates)), np.nan), index=t, columns=self.names) for name, var in univariates: df.loc[var.index, name] = var[~var.index.duplicated()] return df + def to_csv(self, file_name): + self.to_pd().to_csv(file_name) + @classmethod def from_pd(cls, df: Union[pd.Series, pd.DataFrame, np.ndarray], check_times=True, freq="1h"): """ @@ -953,48 +964,6 @@ def align( raise RuntimeError(f"Alignment policy {alignment_policy.name} not supported") -def ts_csv_load(file_name: str, ms=True, n_vars=None) -> TimeSeries: - """ - :param file_name: a csv file starting with the field timestamp followed by - all the all variable names. - :param ms: whether the timestamps are in milliseconds (rather than seconds) - :return: A merlion `TimeSeries` object. - """ - with open(file_name, "r") as f: - header = True - for line in f: - if header: - header = False - names = line.strip().split(",")[1:] - vars = {name: [] for name in names} - stamps = [] - continue - if not line: - continue - words = line.strip().split(",") - stamp, vals = int(words[0]), words[1:] - if ms: - stamp = stamp / 1000 - stamps += [stamp] - for name, val in zip(names, vals): - vars[name] += [float(val)] - - return TimeSeries([UnivariateTimeSeries(stamps, vals, name) for name, vals in vars.items()][:n_vars]) - - -def ts_to_csv(time_series: TimeSeries, file_name: str): - """ - :param time_series: the `TimeSeries` object to write to a csv. - :param file_name: the name to assign the csv file. - """ - with open(file_name, "w") as f: - header = ",".join(["timestamp"] + time_series.names) - f.write(f"{header}\n") - for t, x in time_series: - vals = ",".join([str(v) for v in (int(t),) + x]) - f.write(f"{vals}\n") - - def assert_equal_timedeltas(time_series: UnivariateTimeSeries, timedelta: float = None): """ Checks that all time deltas in the time series are equal, either to each diff --git a/merlion/utils/ts_generator.py b/merlion/utils/ts_generator.py index bc4a10bf8..554186031 100644 --- a/merlion/utils/ts_generator.py +++ b/merlion/utils/ts_generator.py @@ -1,9 +1,12 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause # +""" +Generators for synthetic time series. +""" import numpy as np import pandas as pd @@ -17,6 +20,8 @@ class TimeSeriesGenerator: """ An abstract base class for generating synthetic time series data. + Generates a 1-dimensional grid x(0), x(1), ..., x(n-1), where x(i) = x0 + i * step. + Then generates a time series y(0), y(1), ..., y(n-1), where y(i) = f(x(i)) + noise. """ def __init__( @@ -46,9 +51,6 @@ def __init__( TimeSeries object. :param tdelta: the time delta to use when wrapping the generated values into a TimeSeries object. - - Generates a 1-dimensional grid x(0), x(1), ..., x(n-1), where x(i) = x0 + i * step. - Then generates a time series y(0), y(1), ..., y(n-1), where y(i) = f(x(i)) + noise. """ assert step > 0, f"step must be a postive real number but is {step}." assert scale > 0, f"scale must be a postive real number but is {scale}." @@ -163,6 +165,12 @@ class GeneratorConcatenator(GeneratorComposer): fundamental changes to it's behavior that certain points in time. For example, with this class one could generate a time series that begins as linear and then becomes stationary. + + For example, let f = 0 with for 3 steps 0,1,2 and g = 2 * x for the next three + steps 3,4,5. generate() returns: + + - [0, 0, 0, 6, 8, 10] if string_outputs is False + - [0, 0, 0, 2, 4, 6] if string_outputs is True. """ def __init__(self, string_outputs: bool = True, **kwargs): @@ -172,12 +180,6 @@ def __init__(self, string_outputs: bool = True, **kwargs): two generating functions f, and g belonging to consecutive generators. If True, adjust g by a constant c such that f(x) = g(x) at the last point x that f uses to generate its series. - - For example, let f = 0 with for 3 steps 0,1,2 and g = 2 * x for the next three - steps 3,4,5. generate() returns: - [0, 0, 0, 6, 8, 10] if string_outputs is False - [0, 0, 0, 2, 4, 6] if string_outputs is True. - """ kwargs["f"] = None kwargs["n"] = 1 diff --git a/setup.py b/setup.py index 161fdbe68..844c096c0 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ def read_file(fname): setup( name="salesforce-merlion", - version="1.1.1", + version="1.1.2", author=", ".join(read_file("AUTHORS.md").split("\n")), author_email="abhatnagar@salesforce.com", description="Merlion: A Machine Learning Framework for Time Series Intelligence", diff --git a/tests/anomaly/forecast_based/test_arima.py b/tests/anomaly/forecast_based/test_arima.py index dc76099b8..1b1151d8f 100644 --- a/tests/anomaly/forecast_based/test_arima.py +++ b/tests/anomaly/forecast_based/test_arima.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -14,7 +14,8 @@ from merlion.transform.resample import TemporalResample from merlion.models.anomaly.forecast_based.arima import ArimaDetector, ArimaDetectorConfig -from merlion.utils.time_series import ts_csv_load, TimeSeries +from merlion.utils.time_series import TimeSeries +from merlion.utils.data_io import csv_to_time_series logger = logging.getLogger(__name__) rootdir = dirname(dirname(dirname(dirname(abspath(__file__))))) @@ -26,7 +27,7 @@ def __init__(self, *args, **kwargs): # Re-sample to 15min because the default (1min) takes too long to train self.csv_name = join(rootdir, "data", "example.csv") - data = ts_csv_load(self.csv_name, n_vars=1) + data = csv_to_time_series(self.csv_name, timestamp_unit="ms", data_cols=["kpi"]) logger.info(f"Data looks like:\n{data[:5]}") self.test_len = math.ceil(len(data) / 5) diff --git a/tests/anomaly/forecast_based/test_lstm.py b/tests/anomaly/forecast_based/test_lstm.py index 11aadedb9..b259062e2 100644 --- a/tests/anomaly/forecast_based/test_lstm.py +++ b/tests/anomaly/forecast_based/test_lstm.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -17,7 +17,8 @@ from merlion.models.anomaly.forecast_based.lstm import LSTMDetector, LSTMTrainConfig, LSTMDetectorConfig from merlion.models.forecast.lstm import auto_stride from merlion.post_process.threshold import AggregateAlarms -from merlion.utils.time_series import ts_csv_load, TimeSeries +from merlion.utils.time_series import TimeSeries +from merlion.utils.data_io import csv_to_time_series logger = logging.getLogger(__name__) rootdir = dirname(dirname(dirname(dirname(abspath(__file__))))) @@ -27,7 +28,7 @@ class TestLSTM(unittest.TestCase): def test_full(self): file_name = join(rootdir, "data", "example.csv") - sequence = TemporalResample("15min")(ts_csv_load(file_name, n_vars=1)) + sequence = TemporalResample("15min")(csv_to_time_series(file_name, timestamp_unit="ms", data_cols=["kpi"])) logger.info(f"Data looks like:\n{sequence[:5]}") time_stamps = sequence.univariates[sequence.names[0]].time_stamps diff --git a/tests/anomaly/forecast_based/test_mses.py b/tests/anomaly/forecast_based/test_mses.py index c716d2894..803159db7 100644 --- a/tests/anomaly/forecast_based/test_mses.py +++ b/tests/anomaly/forecast_based/test_mses.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -14,7 +14,8 @@ from merlion.transform.resample import TemporalResample from merlion.models.anomaly.forecast_based.mses import MSESDetector, MSESDetectorConfig -from merlion.utils.time_series import ts_csv_load, TimeSeries +from merlion.utils.data_io import csv_to_time_series +from merlion.utils.time_series import TimeSeries logger = logging.getLogger(__name__) rootdir = dirname(dirname(dirname(dirname(abspath(__file__))))) @@ -26,7 +27,7 @@ def __init__(self, *args, **kwargs): # Re-sample to 1hr because default (1min) takes too long to train self.csv_name = join(rootdir, "data", "example.csv") - self.data = ts_csv_load(self.csv_name, n_vars=1) + self.data = csv_to_time_series(self.csv_name, timestamp_unit="ms", data_cols=["kpi"]) logger.info(f"Data looks like:\n{self.data[:5]}") self.test_len = math.ceil(len(self.data) / 10) diff --git a/tests/anomaly/forecast_based/test_prophet.py b/tests/anomaly/forecast_based/test_prophet.py index 825340b82..a28fb873a 100644 --- a/tests/anomaly/forecast_based/test_prophet.py +++ b/tests/anomaly/forecast_based/test_prophet.py @@ -15,7 +15,8 @@ from merlion.models.automl.autoprophet import AutoProphet from merlion.models.anomaly.forecast_based.prophet import ProphetDetector, ProphetDetectorConfig -from merlion.utils.time_series import ts_csv_load, TimeSeries +from merlion.utils.data_io import csv_to_time_series +from merlion.utils.time_series import TimeSeries from merlion.transform.normalize import PowerTransform from merlion.transform.resample import TemporalResample @@ -28,7 +29,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.csv_name = join(rootdir, "data", "example.csv") - self.data = TemporalResample("15min")(ts_csv_load(self.csv_name, n_vars=1)) + self.data = TemporalResample("15min")(csv_to_time_series(self.csv_name, timestamp_unit="ms", data_cols=["kpi"])) logger.info(f"Data looks like:\n{self.data[:5]}") holidays = pd.DataFrame({"ds": ["03-17-2020"], "holiday": ["St. Patrick's Day"]}) diff --git a/tests/anomaly/forecast_based/test_sarima.py b/tests/anomaly/forecast_based/test_sarima.py index dd09259ee..cc65fae5c 100644 --- a/tests/anomaly/forecast_based/test_sarima.py +++ b/tests/anomaly/forecast_based/test_sarima.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -14,7 +14,8 @@ from merlion.models.anomaly.forecast_based.sarima import SarimaDetector, SarimaDetectorConfig from merlion.transform.resample import TemporalResample -from merlion.utils.time_series import ts_csv_load, TimeSeries +from merlion.utils.data_io import csv_to_time_series +from merlion.utils.time_series import TimeSeries logger = logging.getLogger(__name__) rootdir = dirname(dirname(dirname(dirname(abspath(__file__))))) @@ -26,7 +27,7 @@ def __init__(self, *args, **kwargs): # Re-sample to 1hr because default (1min) takes too long to train self.csv_name = join(rootdir, "data", "example.csv") - self.data = ts_csv_load(self.csv_name, n_vars=1) + self.data = csv_to_time_series(self.csv_name, timestamp_unit="ms", data_cols=["kpi"]) logger.info(f"Data looks like:\n{self.data[:5]}") self.test_len = math.ceil(len(self.data) / 5) diff --git a/tests/anomaly/test_anom_ensemble.py b/tests/anomaly/test_anom_ensemble.py index 9eb1cbaa8..16b2d98cc 100644 --- a/tests/anomaly/test_anom_ensemble.py +++ b/tests/anomaly/test_anom_ensemble.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -15,9 +15,8 @@ from merlion.models.ensemble.anomaly import DetectorEnsemble, DetectorEnsembleConfig from merlion.models.ensemble.combine import Mean, Median from merlion.models.factory import ModelFactory -from merlion.post_process.threshold import AggregateAlarms from merlion.transform.resample import TemporalResample -from merlion.utils.time_series import ts_csv_load +from merlion.utils.data_io import csv_to_time_series logger = logging.getLogger(__name__) rootdir = dirname(dirname(dirname(abspath(__file__)))) @@ -28,7 +27,7 @@ class TestMedianAnomEnsemble(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # load the time series sequence [(t1,v1), (t2, v2),...] - data = ts_csv_load(csv_name, n_vars=1) + data = csv_to_time_series(csv_name, timestamp_unit="ms", data_cols=["kpi"]) # split the sequence into train and test self.vals_train = data[:-32768] @@ -89,7 +88,7 @@ class TestMeanAnomEnsemble(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # load the time series sequence [(t1,v1), (t2, v2),...] - data = ts_csv_load(csv_name, n_vars=1) + data = csv_to_time_series(csv_name, timestamp_unit="ms", data_cols=["kpi"]) # split the sequence into train and test self.vals_train = data[:-32768] diff --git a/tests/anomaly/test_isolation_forest.py b/tests/anomaly/test_isolation_forest.py index 0067bdefc..89a060549 100644 --- a/tests/anomaly/test_isolation_forest.py +++ b/tests/anomaly/test_isolation_forest.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -15,7 +15,7 @@ from merlion.transform.moving_average import MovingAverage, ExponentialMovingAverage from merlion.transform.resample import Shingle from merlion.transform.sequence import TransformSequence -from merlion.utils.time_series import ts_csv_load +from merlion.utils.data_io import csv_to_time_series rootdir = dirname(dirname(dirname(abspath(__file__)))) logger = logging.getLogger(__name__) @@ -26,7 +26,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.csv_name = join(rootdir, "data", "example.csv") self.test_len = 32768 - self.data = ts_csv_load(self.csv_name, n_vars=1) + self.data = csv_to_time_series(self.csv_name, timestamp_unit="ms", data_cols=["kpi"]) logger.info(f"Data looks like:\n{self.data[:5]}") self.vals_train = self.data[: -self.test_len] self.vals_test = self.data[-self.test_len :] diff --git a/tests/anomaly/test_random_cut_forest.py b/tests/anomaly/test_random_cut_forest.py index daf68a5b4..7ad6e0665 100644 --- a/tests/anomaly/test_random_cut_forest.py +++ b/tests/anomaly/test_random_cut_forest.py @@ -7,7 +7,6 @@ import logging import math from os.path import abspath, dirname, join -import pytest import sys import unittest @@ -19,7 +18,7 @@ from merlion.transform.normalize import MeanVarNormalize from merlion.transform.resample import Shingle, TemporalResample from merlion.transform.sequence import TransformSequence -from merlion.utils.time_series import ts_csv_load +from merlion.utils.data_io import csv_to_time_series rootdir = dirname(dirname(dirname(abspath(__file__)))) logger = logging.getLogger(__name__) @@ -29,7 +28,7 @@ class TestRandomCutForest(unittest.TestCase): def run_init(self): # Resample @ 5min granularity b/c default (1min) takes too long to train self.csv_name = join(rootdir, "data", "example.csv") - self.data = ts_csv_load(self.csv_name, n_vars=1) + self.data = csv_to_time_series(self.csv_name, timestamp_unit="ms", data_cols=["kpi"]) self.test_len = math.ceil(len(self.data) / 5) logger.info(f"Data looks like:\n{self.data[:5]}") diff --git a/tests/anomaly/test_spectral_residual.py b/tests/anomaly/test_spectral_residual.py index ff0bb5dce..d6975ac05 100644 --- a/tests/anomaly/test_spectral_residual.py +++ b/tests/anomaly/test_spectral_residual.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -13,7 +13,7 @@ from merlion.models.anomaly.spectral_residual import SpectralResidual, SpectralResidualConfig from merlion.post_process.threshold import AggregateAlarms -from merlion.utils.time_series import ts_csv_load +from merlion.utils.data_io import csv_to_time_series rootdir = dirname(dirname(dirname(abspath(__file__)))) logger = logging.getLogger(__name__) @@ -24,7 +24,7 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.csv_name = join(rootdir, "data", "example.csv") self.test_len = 32768 - self.data = ts_csv_load(self.csv_name) + self.data = csv_to_time_series(self.csv_name, timestamp_unit="ms", data_cols="kpi") logger.info(f"Data looks like:\n{self.data[:5]}") self.vals_train = self.data[: -self.test_len] self.vals_test = self.data[-self.test_len :] diff --git a/tests/anomaly/test_stat_threshold.py b/tests/anomaly/test_stat_threshold.py index a3d3d5ed5..ea8ef74c0 100644 --- a/tests/anomaly/test_stat_threshold.py +++ b/tests/anomaly/test_stat_threshold.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -13,7 +13,7 @@ from merlion.models.anomaly.stat_threshold import StatThreshold, StatThresholdConfig from merlion.post_process.threshold import AggregateAlarms -from merlion.utils.time_series import ts_csv_load +from merlion.utils.data_io import csv_to_time_series rootdir = dirname(dirname(dirname(abspath(__file__)))) logger = logging.getLogger(__name__) @@ -24,7 +24,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.csv_name = join(rootdir, "data", "example.csv") self.test_len = 32768 - self.data = ts_csv_load(self.csv_name, n_vars=1) + self.data = csv_to_time_series(self.csv_name, timestamp_unit="ms", data_cols=["kpi"]) logger.info(f"Data looks like:\n{self.data[:5]}") self.vals_train = self.data[: -self.test_len] self.vals_test = self.data[-self.test_len :] diff --git a/tests/anomaly/test_windstats.py b/tests/anomaly/test_windstats.py index f7a992e8d..f8f6ca930 100644 --- a/tests/anomaly/test_windstats.py +++ b/tests/anomaly/test_windstats.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -13,7 +13,7 @@ from merlion.models.anomaly.windstats import WindStatsConfig, WindStats from merlion.post_process.threshold import AggregateAlarms -from merlion.utils.time_series import ts_csv_load +from merlion.utils.data_io import csv_to_time_series rootdir = dirname(dirname(dirname(abspath(__file__)))) logger = logging.getLogger(__name__) @@ -24,7 +24,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.csv_name = join(rootdir, "data", "example.csv") self.test_len = 32768 - self.data = ts_csv_load(self.csv_name, n_vars=1) + self.data = csv_to_time_series(self.csv_name, timestamp_unit="ms", data_cols=["kpi"]) logger.info(f"Data looks like:\n{self.data[:5]}") self.vals_train = self.data[: -self.test_len] self.vals_test = self.data[-self.test_len :] diff --git a/tests/anomaly/test_zms.py b/tests/anomaly/test_zms.py index 5c223e4fd..39dcf7cb3 100644 --- a/tests/anomaly/test_zms.py +++ b/tests/anomaly/test_zms.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -9,9 +9,9 @@ import sys import unittest -from merlion.utils.time_series import ts_csv_load from merlion.models.anomaly.zms import ZMS, ZMSConfig from merlion.post_process.threshold import AggregateAlarms +from merlion.utils.data_io import csv_to_time_series rootdir = dirname(dirname(dirname(abspath(__file__)))) logger = logging.getLogger(__name__) @@ -22,7 +22,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.csv_name = join(rootdir, "data", "example.csv") self.test_len = 32768 - self.data = ts_csv_load(self.csv_name, n_vars=1) + self.data = csv_to_time_series(self.csv_name, timestamp_unit="ms", data_cols=["kpi"]) logger.info(f"Data looks like: {self.data[:5]}") self.vals_train = self.data[: -self.test_len] self.vals_test = self.data[-self.test_len :] diff --git a/tests/evaluate/test_eval_forecast.py b/tests/evaluate/test_eval_forecast.py index 7a573ef0b..a7e480bf7 100644 --- a/tests/evaluate/test_eval_forecast.py +++ b/tests/evaluate/test_eval_forecast.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -15,7 +15,8 @@ from merlion.models.ensemble.forecast import ForecasterEnsemble, ForecasterEnsembleConfig from merlion.models.forecast.arima import ArimaConfig, Arima from merlion.transform.base import Identity -from merlion.utils.time_series import UnivariateTimeSeries, ts_csv_load +from merlion.utils.data_io import csv_to_time_series +from merlion.utils.time_series import UnivariateTimeSeries logger = logging.getLogger(__name__) @@ -64,7 +65,7 @@ def test_ensemble(self): logger.info("test_ensemble\n" + "-" * 80 + "\n") csv_name = join(rootdir, "data", "example.csv") - ts = ts_csv_load(csv_name, ms=True, n_vars=1).align(granularity="1h") + ts = csv_to_time_series(csv_name, timestamp_unit="ms", data_cols=["kpi"]).align(granularity="1h") n_test = len(ts) // 5 train, test = ts[:-n_test], ts[-n_test:] diff --git a/tests/forecast/test_ets.py b/tests/forecast/test_ets.py index 5d2371e68..99ad3bbdb 100644 --- a/tests/forecast/test_ets.py +++ b/tests/forecast/test_ets.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -108,6 +108,10 @@ def test_forecast(self): rmse = ForecastMetric.RMSE.value(self.test_data, forecast) logger.info(f"RMSE = {rmse:.4f} for {self.max_forecast_steps} step forecasting") self.assertAlmostEqual(rmse, 6.5, delta=1) + rmspe = ForecastMetric.RMSPE.value(self.test_data, forecast) + logger.info(f"RMPSE = {rmspe:.4f} for {self.max_forecast_steps} step forecasting") + smape = ForecastMetric.sMAPE.value(self.test_data, forecast) + logger.info(f"sMAPE = {smape:.4f} for {self.max_forecast_steps} step forecasting") msis = ForecastMetric.MSIS.value( ground_truth=self.test_data, predict=forecast, insample=self.train_data, periodicity=4, ub=ub, lb=lb ) diff --git a/tests/forecast/test_forecast_ensemble.py b/tests/forecast/test_forecast_ensemble.py index 59491f4ec..1d2af1023 100644 --- a/tests/forecast/test_forecast_ensemble.py +++ b/tests/forecast/test_forecast_ensemble.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -19,7 +19,7 @@ from merlion.models.factory import ModelFactory from merlion.transform.base import Identity from merlion.transform.resample import TemporalResample -from merlion.utils.time_series import ts_csv_load +from merlion.utils.data_io import csv_to_time_series logger = logging.getLogger(__name__) rootdir = dirname(dirname(dirname(abspath(__file__)))) @@ -30,7 +30,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.csv_name = join(rootdir, "data", "example.csv") self.test_len = 2048 - data = ts_csv_load(self.csv_name, n_vars=1)[::10] + data = csv_to_time_series(self.csv_name, timestamp_unit="ms", data_cols=["kpi"])[::10] self.vals_train = data[: -self.test_len] self.vals_test = data[-self.test_len :].univariates[data.names[0]] diff --git a/tests/forecast/test_smoother.py b/tests/forecast/test_smoother.py index 4f98a4456..7ba860a69 100644 --- a/tests/forecast/test_smoother.py +++ b/tests/forecast/test_smoother.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -14,9 +14,10 @@ import numpy as np from numpy.core.fromnumeric import mean -from merlion.transform.resample import TemporalResample -from merlion.utils.time_series import UnivariateTimeSeries, ts_csv_load from merlion.models.forecast.smoother import MSES, MSESConfig, MSESTrainConfig +from merlion.transform.resample import TemporalResample +from merlion.utils.data_io import csv_to_time_series +from merlion.utils.time_series import UnivariateTimeSeries logger = logging.getLogger(__name__) rootdir = dirname(dirname(dirname(abspath(__file__)))) @@ -27,7 +28,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) csv_name = join(rootdir, "data", "example.csv") - self.data = TemporalResample("1h")(ts_csv_load(csv_name, n_vars=1)) + self.data = TemporalResample("1h")(csv_to_time_series(csv_name, timestamp_unit="ms", data_cols="kpi")) logger.info(f"Data looks like: {self.data[:5]}") n = math.ceil(len(self.data) / 5) diff --git a/tests/test_hts.py b/tests/test_hts.py new file mode 100644 index 000000000..1eee793bd --- /dev/null +++ b/tests/test_hts.py @@ -0,0 +1,118 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +import logging +from os.path import abspath, dirname, join +import sys +import unittest + +import numpy as np +import pandas as pd + +from merlion.evaluate.forecast import ForecastMetric +from merlion.models.factory import ModelFactory +from merlion.transform.normalize import MinMaxNormalize +from merlion.transform.sequence import TransformSequence +from merlion.transform.resample import TemporalResample +from merlion.transform.bound import LowerUpperClip +from merlion.transform.moving_average import DifferenceTransform +from merlion.utils import TimeSeries, UnivariateTimeSeries +from merlion.utils.data_io import df_to_time_series +from merlion.utils.hts import minT_reconciliation +from ts_datasets.forecast import SeattleTrail + +logger = logging.getLogger(__name__) +rootdir = dirname(dirname(abspath(__file__))) + + +class TestHTS(unittest.TestCase): + """ + we test data loading, model instantiation, forecasting consistency, in particular + (1) load a testing data + (2) transform data + (3) instantiate the model and train + (4) forecast, and the forecasting result agrees with the reference + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.max_forecast_steps = 2 + self.maxlags = 6 + self.i = 0 + # t = int(datetime(2019, 1, 1, 0, 0, 0).timestamp()) + + dataset = "seattle_trail" + d, md = SeattleTrail(rootdir=join(rootdir, "data", "multivariate", dataset))[0] + t = int(d[md["trainval"]].index[-1].to_pydatetime().timestamp()) + data = TimeSeries.from_pd(d) + cleanup_transform = TransformSequence( + [TemporalResample(missing_value_policy="FFill"), LowerUpperClip(upper=300), DifferenceTransform()] + ) + cleanup_transform.train(data) + data = cleanup_transform(data) + + train_data, test_data = data.bisect(t) + + h = 100 + minmax_transform = MinMaxNormalize() + minmax_transform.train(train_data) + self.train_data_norm = minmax_transform(train_data[-2000:]) + self.test_data_norm = minmax_transform(test_data[:h]) + self.train_data_agg = UnivariateTimeSeries.from_pd(self.train_data_norm.to_pd().sum(axis=1), name="val").to_ts() + self.test_data_agg = UnivariateTimeSeries.from_pd(self.test_data_norm.to_pd().sum(axis=1), name="val").to_ts() + + self.models = [ModelFactory.create("AutoETS", target_seq_index=i) for i in range(test_data.dim)] + self.agg_model = ModelFactory.create("LGBMForecaster", max_forecast_steps=h, maxlags=100) + + def test_minT(self): + print("=" * 80) + logger.info("test_minT" + "\n" + "=" * 80) + logger.info("Training models...") + forecasts, errs = [], [] + models = [self.agg_model, *self.models] + train_data = [self.train_data_agg] + [self.train_data_norm] * len(self.models) + test_data = [self.test_data_agg] + [self.test_data_norm] * len(self.models) + for model, train, test in zip(models, train_data, test_data): + model.train(train) + forecast, err = model.forecast(test.time_stamps) + forecasts.append(forecast) + errs.append(None if len(errs) == 1 else err) + + logger.info("Applying reconciliation...") + sum_matrix = np.concatenate([np.ones((1, len(self.models))), np.eye(len(self.models))]) + reconciled = minT_reconciliation(forecasts, errs, sum_matrix=sum_matrix, n_leaves=len(self.models)) + + naive_sum = np.sum([f.to_pd().values.flatten() for f in forecasts[1:]]) + naive_sum = UnivariateTimeSeries(time_stamps=self.test_data_agg.time_stamps, values=naive_sum).to_ts() + naive = ForecastMetric.RMSE.value(predict=naive_sum, ground_truth=self.test_data_agg) + direct = ForecastMetric.RMSE.value(predict=forecasts[0], ground_truth=self.test_data_agg) + minT = ForecastMetric.RMSE.value(predict=reconciled[0], ground_truth=self.test_data_agg) + + logger.info(f"Naive summation RMSE: {naive:.2f}") + logger.info(f"Direct prediction RMSE: {direct:.4f}") + logger.info(f"minT reconciliation RMSE: {minT:.4f}") + self.assertLess(direct, naive) + self.assertLess(minT, direct) + + def test_data_io(self): + print("=" * 80) + logger.info("test_data_io" + "\n" + "=" * 80) + df = self.train_data_norm.to_pd() + df_hierarchical = pd.concat( + [pd.DataFrame({"name": c, "val": df[c].values}, index=df.index) for c in df.columns] + ) + df_hierarchical.index.name = None + ts_agg = df_to_time_series(df_hierarchical, index_cols=["name"]) + max_delta = (ts_agg.to_pd() - self.train_data_agg.to_pd()).abs().max().item() + self.assertLessEqual(max_delta, 1e-8) + + +if __name__ == "__main__": + logging.basicConfig( + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", stream=sys.stdout, level=logging.INFO + ) + unittest.main() diff --git a/tests/test_plot.py b/tests/test_plot.py index a4a499e16..26e24bc30 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -15,7 +15,7 @@ from merlion.transform.resample import TemporalResample from merlion.models.anomaly.forecast_based.prophet import ProphetDetector, ProphetDetectorConfig from merlion.plot import plot_anoms, plot_anoms_plotly -from merlion.utils.time_series import ts_csv_load +from merlion.utils.data_io import csv_to_time_series logger = logging.getLogger(__name__) rootdir = dirname(dirname(abspath(__file__))) @@ -26,7 +26,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.csv_name = join(rootdir, "data", "example.csv") - data = ts_csv_load(self.csv_name) + data = csv_to_time_series(self.csv_name, timestamp_unit="ms") self.data = TemporalResample("15min")(data.univariates[data.names[0]].to_ts()) self.labels = data.univariates[data.names[1]].to_ts() logger.info(f"Data looks like:\n{self.data[:5]}")