diff --git a/benchmark_anomaly.py b/benchmark_anomaly.py index f9ffbed39..dce2524d4 100644 --- a/benchmark_anomaly.py +++ b/benchmark_anomaly.py @@ -61,6 +61,8 @@ def parse_args(): "in ts_datasets/ts_datasets/anomaly/__init__.py for " "valid options.", ) + parser.add_argument("--data_root", default=None, help="Root directory/file of dataset.") + parser.add_argument("--data_kwargs", default="{}", help="JSON of keyword arguemtns for the data loader.") parser.add_argument( "--models", type=str, @@ -148,6 +150,8 @@ def parse_args(): args.metric = TSADMetric[args.metric] args.pointwise_metric = TSADMetric[args.pointwise_metric] args.visualize = args.visualize and not args.eval_only + args.data_kwargs = json.loads(args.data_kwargs) + assert isinstance(args.data_kwargs, dict) if args.retrain_freq.lower() in ["", "none", "null"]: args.retrain_freq = None elif args.retrain_freq != "default": @@ -164,10 +168,14 @@ def parse_args(): return args -def dataset_to_name(dataset: TSADBaseDataset): - if dataset.subset is not None: - return f"{type(dataset).__name__}_{dataset.subset}" - return type(dataset).__name__ +def get_dataset_name(dataset: TSADBaseDataset): + name = type(dataset).__name__ + if hasattr(dataset, "subset") and dataset.subset is not None: + name += "_" + dataset.subset + if isinstance(dataset, CustomAnomalyDataset): + root = dataset.rootdir + name = os.path.join(name, os.path.basename(os.path.dirname(root) if os.path.isfile(root) else root)) + return name def dataset_to_threshold(dataset: TSADBaseDataset, tune_on_test=False): @@ -269,14 +277,13 @@ def train_model( unsupervised=False, tune_on_test=False, ): - """Trains a model on the time series dataset given, and save their predictions - to a dataset.""" + """Trains a model on the time series dataset given, and save their predictions to a dataset.""" resampler = None if isinstance(dataset, IOpsCompetition): resampler = TemporalResample("5min") model_name = resolve_model_name(model_name) - dataset_name = dataset_to_name(dataset) + dataset_name = get_dataset_name(dataset) model_dir = model_name if retrain_freq is None else f"{model_name}_{retrain_freq}" dirname = os.path.join("results", "anomaly", model_dir) csv = os.path.join(dirname, f"pred_{dataset_name}.csv.gz") @@ -337,16 +344,17 @@ def train_model( if not visualize: if i == i0 == 0: os.makedirs(os.path.dirname(csv), exist_ok=True) + os.makedirs(os.path.dirname(checkpoint), exist_ok=True) df = pd.DataFrame({"timestamp": [], "y": [], "trainval": [], "idx": []}) df.to_csv(csv, index=False) df = pd.read_csv(csv) - ts_df = train_scores.to_pd().append(test_scores.to_pd()) + ts_df = pd.concat((train_scores.to_pd(), test_scores.to_pd())) ts_df.columns = ["y"] ts_df.loc[:, "timestamp"] = ts_df.index.view(int) // 1e9 ts_df.loc[:, "trainval"] = [j < len(train_scores) for j in range(len(ts_df))] ts_df.loc[:, "idx"] = i - df = df.append(ts_df, ignore_index=True) + df = pd.concat((df, ts_df), ignore_index=True) df.to_csv(csv, index=False) # Start from time series i+1 if loading a checkpoint. @@ -358,7 +366,7 @@ def train_model( score = test_scores if tune_on_test else train_scores label = test_anom if tune_on_test else train_anom model.train_post_process( - train_vals, train_result=score, anomaly_labels=label, post_rule_train_config=post_rule_train_config + train_result=score, anomaly_labels=label, post_rule_train_config=post_rule_train_config ) # Log (many) evaluation metrics for the time series @@ -433,7 +441,7 @@ def read_model_predictions(dataset: TSADBaseDataset, model_dir: str): Returns a list of lists all_preds, where all_preds[i] is the model's raw anomaly scores for time series i in the dataset. """ - csv = os.path.join("results", "anomaly", model_dir, f"pred_{dataset_to_name(dataset)}.csv.gz") + csv = os.path.join("results", "anomaly", model_dir, f"pred_{get_dataset_name(dataset)}.csv.gz") preds = pd.read_csv(csv, dtype={"trainval": bool, "idx": int}) preds["timestamp"] = to_pd_datetime(preds["timestamp"]) return [preds[preds["idx"] == i].set_index("timestamp") for i in sorted(preds["idx"].unique())] @@ -460,7 +468,6 @@ def evaluate_predictions( for i, (true, md) in enumerate(tqdm(dataset)): # Get time series for the train & test splits of the ground truth idx = ~md.trainval if tune_on_test else md.trainval - train_vals = df_to_merlion(true[idx], md[idx], transform=resampler) true_train = df_to_merlion(true[idx], md[idx], get_ground_truth=True) true_test = df_to_merlion(true[~md.trainval], md[~md.trainval], get_ground_truth=True) @@ -501,9 +508,7 @@ def evaluate_predictions( m.threshold = m.threshold.to_simple_threshold() if tune_on_test and not unsupervised: m.calibrator.train(TimeSeries.from_pd(og_pred["y"][og_pred["trainval"]])) - m.train_post_process( - train_vals, train_result=train, anomaly_labels=true_train, post_rule_train_config=prtc - ) + m.train_post_process(train_result=train, anomaly_labels=true_train, post_rule_train_config=prtc) models.append(m) # Get the lead & lag time for the dataset @@ -538,7 +543,6 @@ def evaluate_predictions( model.threshold = model.threshold.to_simple_threshold() model.threshold.alm_threshold = threshold model.train_post_process( - train_vals, train_result=pred_train, anomaly_labels=true_train, post_rule_train_config=ensemble_threshold_train_config, @@ -653,7 +657,7 @@ def main(): logging.basicConfig( format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", stream=sys.stdout, level=level ) - dataset = get_dataset(args.dataset) + dataset = get_dataset(args.dataset, rootdir=args.data_root, **args.data_kwargs) retrain_freq, train_window = args.retrain_freq, args.train_window univariate = dataset[0][0].shape[1] == 1 if retrain_freq == "default": @@ -697,10 +701,11 @@ def main(): ) model_name = "+".join(sorted(resolve_model_name(m) for m in args.models)) - summary = os.path.join("results", "anomaly", f"{dataset_to_name(dataset)}_summary.csv") + summary = os.path.join("results", "anomaly", f"{get_dataset_name(dataset)}_summary.csv") if os.path.exists(summary): df = pd.read_csv(summary, index_col=0) else: + os.makedirs(os.path.dirname(summary), exist_ok=True) df = pd.DataFrame() if retrain_freq: model_name += f"_{retrain_freq}" diff --git a/benchmark_forecast.py b/benchmark_forecast.py index 2a75a3a02..6070524ba 100644 --- a/benchmark_forecast.py +++ b/benchmark_forecast.py @@ -59,6 +59,8 @@ def parse_args(): "in ts_datasets/ts_datasets/forecast/__init__.py for " "valid options.", ) + parser.add_argument("--data_root", default=None, help="Root directory/file of dataset.") + parser.add_argument("--data_kwargs", default="{}", help="JSON of keyword arguments for the data loader.") parser.add_argument( "--models", type=str, @@ -122,6 +124,8 @@ def parse_args(): ) args = parser.parse_args() + args.data_kwargs = json.loads(args.data_kwargs) + assert isinstance(args.data_kwargs, dict) # If not summarizing all results, we need at least one model to evaluate if args.summarize and args.models is None: @@ -139,6 +143,9 @@ def get_dataset_name(dataset: BaseDataset) -> str: name = type(dataset).__name__ if hasattr(dataset, "subset") and dataset.subset is not None: name += "_" + dataset.subset + if isinstance(dataset, CustomDataset): + root = dataset.rootdir + name = os.path.join(name, os.path.basename(os.path.dirname(root) if os.path.isfile(root) else root)) return name @@ -238,6 +245,7 @@ def train_model( i0 = pd.read_csv(csv).idx.max() else: i0 = -1 + os.makedirs(os.path.dirname(csv), exist_ok=True) with open(csv, "w") as f: f.write("idx,name,horizon,retrain_type,n_retrain,RMSE,sMAPE\n") @@ -422,7 +430,6 @@ def join_dfs(name2df: Dict[str, pd.DataFrame]) -> pd.DataFrame: def summarize_full_df(full_df: pd.DataFrame) -> pd.DataFrame: # Get the names of all algorithms which have full results algs = [col[len("sMAPE") :] for col in full_df.columns if col.startswith("sMAPE") and not full_df[col].isna().any()] - summary_df = pd.DataFrame({alg.lstrip("_"): [] for alg in algs}) # Compute pooled (per time series) mean/median sMAPE, RMSE @@ -454,7 +461,7 @@ def main(): stream=sys.stdout, level=logging.DEBUG if args.debug else logging.INFO, ) - dataset = get_dataset(args.dataset) + dataset = get_dataset(args.dataset, rootdir=args.data_root, **args.data_kwargs) dataset_name = get_dataset_name(dataset) if len(args.models) > 0: @@ -507,25 +514,27 @@ def main(): f"before trying to summarize their results." ) for csv in sorted(csvs): - model_name = os.path.basename(os.path.dirname(csv)) - suffix = re.search(f"(?<={dataset_name}).*(?=\\.csv)", os.path.basename(csv)).group(0) + basename = re.search(f"{dataset_name}.*\\.csv", csv).group(0) + model_name = os.path.basename(os.path.dirname(csv[: -len(basename)])) + suffix = re.search(f"(?<={dataset_name}).*(?=\\.csv)", basename).group(0) try: name2df[model_name + suffix] = pd.read_csv(csv) except Exception as e: logger.warning(f'Caught {type(e).__name__}: "{e}". Skipping csv file {csv}.') continue - # Join all the dataframes into one + # Join all the dataframes into one & summarize the results dirname = os.path.join(MERLION_ROOT, "results", "forecast") full_df = join_dfs(name2df) - full_df.to_csv(os.path.join(dirname, f"{dataset_name}_full.csv"), index=False) - - # Summarize the joined dataframe summary_df = summarize_full_df(full_df) - summary_df.to_csv(os.path.join(dirname, f"{dataset_name}_summary.csv"), index=True) if args.summarize: print(summary_df) + full_fname, summary_fname = [os.path.join(dirname, f"{dataset_name}_{x}.csv") for x in ["full", "summary"]] + os.makedirs(os.path.dirname(full_fname), exist_ok=True) + full_df.to_csv(full_fname, index=False) + summary_df.to_csv(summary_fname, index=True) + if __name__ == "__main__": main() diff --git a/examples/CustomDataset.ipynb b/examples/CustomDataset.ipynb index 6886c159a..33691a873 100644 --- a/examples/CustomDataset.ipynb +++ b/examples/CustomDataset.ipynb @@ -345,7 +345,7 @@ "source": [ "from ts_datasets.anomaly import CustomAnomalyDataset\n", "dataset = CustomAnomalyDataset(\n", - " root=anom_dir, # where the data is stored\n", + " rootdir=anom_dir, # where the data is stored\n", " test_frac=0.75, # use 75% of each time series for testing. \n", " # overridden if the column `trainval` is in the actual CSV.\n", " time_unit=\"s\", # the timestamp column (automatically detected) is in units of seconds\n", @@ -956,7 +956,7 @@ "source": [ "from ts_datasets.forecast import CustomDataset\n", "dataset = CustomDataset(\n", - " root=csv, # where the data is stored\n", + " rootdir=csv, # where the data is stored\n", " index_cols=[\"Store\", \"Dept\"], # Individual time series are indexed by store & department\n", " test_frac=0.75, # use 25% of each time series for testing. \n", " # overridden if the column `trainval` is in the actual CSV.\n", @@ -1409,7 +1409,27 @@ "source": [ "## Broader Takeaways\n", "\n", - "In general, a dataset can contain any number of CSVs stored under a single root directory. Each CSV can contain one or more time series, where the different time series within a single file are indicated by different values of the index column. Note that this works for anomaly detection as well! You just need to make sure that your CSVs all contain the `anomaly` column. In general, all features supported by `CustomDataset` are also supported by `CustomAnomalyDataset`, as long as your CSV files have the `anomaly` column." + "In general, a dataset can contain any number of CSVs stored under a single root directory. Each CSV can contain one or more time series, where the different time series within a single file are indicated by different values of the index column. Note that this works for anomaly detection as well! You just need to make sure that your CSVs all contain the `anomaly` column. In general, all features supported by `CustomDataset` are also supported by `CustomAnomalyDataset`, as long as your CSV files have the `anomaly` column.\n", + "\n", + "If you want to either of the above custom datasets for benchmarking, you can call\n", + "\n", + "```\n", + "python benchmark_anomaly.py --model IsolationForest --retrain_freq 7d \\\n", + " --dataset CustomAnomalyDataset --data_root data/synthetic_anomaly \\\n", + " --data_kwargs '{\"assume_no_anomaly\": true, \"test_frac\": 0.75}'\n", + "```\n", + "\n", + "or \n", + "\n", + "```\n", + "python benchmark_forecast.py --model AutoETS \\\n", + " --dataset CustomDataset --data_root data/walmart/walmart_mini.csv \\\n", + " --data_kwargs '{\"test_frac\": 0.25, \\\n", + " \"index_cols\": [\"Store\", \"Dept\"], \\\n", + " \"data_cols\": [\"Weekly_Sales\"]}'\n", + "```\n", + "\n", + "Note in the example above, we specify \"data_cols\" as \"Weekly_Sales\". This indicates that we want" ] } ], diff --git a/merlion/evaluate/forecast.py b/merlion/evaluate/forecast.py index bb9714824..9cd2c1ca6 100644 --- a/merlion/evaluate/forecast.py +++ b/merlion/evaluate/forecast.py @@ -46,9 +46,9 @@ def __init__( :param lb (optional): lower bound of 95% prediction interval. This value is used for computing MSIS :param target_seq_index (optional): the index of the target sequence, for multivariate. """ - ground_truth = ground_truth.to_ts() if isinstance(ground_truth, UnivariateTimeSeries) else ground_truth - predict = predict.to_ts() if isinstance(predict, UnivariateTimeSeries) else predict - insample = insample.to_ts() if isinstance(insample, UnivariateTimeSeries) else insample + ground_truth = TimeSeries.from_pd(ground_truth) + predict = TimeSeries.from_pd(predict) + insample = TimeSeries.from_pd(insample) t0, tf = predict.t0, predict.tf ground_truth = ground_truth.window(t0, tf, include_tf=True).align() if target_seq_index is not None: diff --git a/merlion/models/anomaly/base.py b/merlion/models/anomaly/base.py index ed649dc6e..f3621e9da 100644 --- a/merlion/models/anomaly/base.py +++ b/merlion/models/anomaly/base.py @@ -199,16 +199,11 @@ def train( ) def train_post_process( - self, - train_data: TimeSeries, - train_result: Union[TimeSeries, pd.DataFrame], - anomaly_labels=None, - post_rule_train_config=None, + self, train_result: Union[TimeSeries, pd.DataFrame], anomaly_labels=None, post_rule_train_config=None ) -> TimeSeries: """ Converts the train result (anom scores on train data) into a TimeSeries object and trains the post-rule. - :param train_data: a `TimeSeries` of metric values to train the model. :param train_result: Raw anomaly scores on the training data. :param anomaly_labels: a `TimeSeries` indicating which timestamps are anomalous. Optional. :param post_rule_train_config: The config to use for training the model's post-rule. The model's default diff --git a/merlion/models/anomaly/change_point/bocpd.py b/merlion/models/anomaly/change_point/bocpd.py index 701f75304..3c9b27556 100644 --- a/merlion/models/anomaly/change_point/bocpd.py +++ b/merlion/models/anomaly/change_point/bocpd.py @@ -114,13 +114,6 @@ def __init__( self.lag = lag super().__init__(max_forecast_steps=max_forecast_steps, **kwargs) - def to_dict(self, _skipped_keys=None): - _skipped_keys = _skipped_keys if _skipped_keys is not None else set() - config_dict = super().to_dict(_skipped_keys.union({"change_kind"})) - if "change_kind" not in _skipped_keys: - config_dict["change_kind"] = self.change_kind.name - return config_dict - @property def change_kind(self) -> ChangeKind: return self._change_kind diff --git a/merlion/models/anomaly/forecast_based/base.py b/merlion/models/anomaly/forecast_based/base.py index 9b003408d..874d9bf0f 100644 --- a/merlion/models/anomaly/forecast_based/base.py +++ b/merlion/models/anomaly/forecast_based/base.py @@ -8,7 +8,7 @@ Base class for anomaly detectors based on forecasting models. """ import logging -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -65,17 +65,16 @@ def forecast_to_anom_score( def train_post_process( self, - train_data: TimeSeries, - train_result: Tuple[pd.DataFrame, pd.DataFrame], + train_result: Tuple[Union[TimeSeries, pd.DataFrame], Optional[Union[TimeSeries, pd.DataFrame]]], anomaly_labels=None, post_rule_train_config=None, ) -> TimeSeries: if isinstance(train_result, tuple) and len(train_result) == 2: - train_pred, train_err = ForecasterBase.train_post_process(self, train_data, train_result) - train_data = train_data if self.invert_transform else self.transform(train_data) + train_pred, train_err = ForecasterBase.train_post_process(self, train_result) + train_data = self.train_data if self.invert_transform else self.transform(self.train_data) train_result = self.forecast_to_anom_score(train_data, train_pred, train_err) return DetectorBase.train_post_process( - self, train_data, train_result, anomaly_labels=anomaly_labels, post_rule_train_config=post_rule_train_config + self, train_result, anomaly_labels=anomaly_labels, post_rule_train_config=post_rule_train_config ) def train( diff --git a/merlion/models/anomaly/forecast_based/prophet.py b/merlion/models/anomaly/forecast_based/prophet.py index 4c4012c15..0488ce530 100644 --- a/merlion/models/anomaly/forecast_based/prophet.py +++ b/merlion/models/anomaly/forecast_based/prophet.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -12,11 +12,9 @@ from merlion.models.anomaly.base import DetectorConfig from merlion.models.forecast.prophet import ProphetConfig, Prophet from merlion.post_process.threshold import AggregateAlarms -from merlion.transform.moving_average import DifferenceTransform class ProphetDetectorConfig(ProphetConfig, DetectorConfig): - _default_transform = DifferenceTransform() _default_threshold = AggregateAlarms(alm_threshold=3) diff --git a/merlion/models/automl/autoets.py b/merlion/models/automl/autoets.py index 2a23db5e8..d63254f7c 100644 --- a/merlion/models/automl/autoets.py +++ b/merlion/models/automl/autoets.py @@ -5,24 +5,27 @@ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause # """ -Automatic seasonality detection for ETS. +Automatic hyperparamter selection for ETS. """ -import warnings -import logging -import time from copy import deepcopy -from typing import Union, Iterator, Any, Optional, Tuple from itertools import product +import logging +from typing import Union, Iterator, Any, Optional, Tuple +import warnings + import numpy as np +import pandas as pd + from statsmodels.tsa.exponential_smoothing.ets import ETSModel -from merlion.models.forecast.ets import ETS +from merlion.models.forecast.ets import ETS, ETSConfig +from merlion.models.automl.base import InformationCriterion, ICConfig, ICAutoMLForecaster from merlion.models.automl.seasonality import PeriodicityStrategy, SeasonalityConfig, SeasonalityLayer from merlion.utils import TimeSeries, UnivariateTimeSeries logger = logging.getLogger(__name__) -class AutoETSConfig(SeasonalityConfig): +class AutoETSConfig(SeasonalityConfig, ICConfig): """ Configuration class for `AutoETS`. Act as a wrapper around a `ETS` model, which automatically detects the seasonal_periods, error, trend, damped_trend and seasonal. @@ -38,7 +41,7 @@ def __init__( auto_seasonal: bool = True, auto_damped: bool = True, periodicity_strategy: PeriodicityStrategy = PeriodicityStrategy.ACF, - information_criterion: str = "aic", + information_criterion: InformationCriterion = InformationCriterion.AIC, additive_only: bool = False, allow_multiplicative_trend: bool = False, restrict: bool = True, @@ -50,28 +53,29 @@ def __init__( :param auto_trend: Whether to automatically detect the trend components. :param auto_seasonal: Whether to automatically detect the seasonal components. :param auto_damped: Whether to automatically detect the damped trend components. - :param information_criterion: informationc_criterion to select the best model. It can be "aic", - "bic", or "aicc". :param additive_only: If True, the search space will only consider additive models. - :param allow_multiplicative_trend: If True, models with multiplicative trend are allowed in the search - space. + :param allow_multiplicative_trend: If True, models with multiplicative trend are allowed in the search space. :param restrict: If True, the models with infinite variance will not be allowed in the search space. """ model = dict(name="ETS") if model is None else model - super().__init__(model=model, periodicity_strategy=periodicity_strategy, **kwargs) + super().__init__( + model=model, + periodicity_strategy=periodicity_strategy, + information_criterion=information_criterion, + **kwargs, + ) self.auto_seasonality = auto_seasonality self.auto_trend = auto_trend self.auto_seasonal = auto_seasonal self.auto_error = auto_error self.auto_damped = auto_damped - self.information_criterion = information_criterion self.additive_only = additive_only self.allow_multiplicative_trend = allow_multiplicative_trend self.restrict = restrict -class AutoETS(SeasonalityLayer): +class AutoETS(ICAutoMLForecaster, SeasonalityLayer): """ ETS with automatic seasonality detection. """ @@ -80,15 +84,6 @@ class AutoETS(SeasonalityLayer): def __init__(self, config: AutoETSConfig): super().__init__(config) - # results stored in dict - # dict[tuple -> ARIMA] - self._results_dict = dict() - - # dict[tuple -> float] - self._ic_dict = dict() - - self._bestfit = None - self._bestfit_key = None def generate_theta(self, train_data: TimeSeries) -> Iterator: """ @@ -99,12 +94,12 @@ def generate_theta(self, train_data: TimeSeries) -> Iterator: # check the size of y n_samples = y.shape[0] if n_samples <= 3: - self.information_criterion = "aic" + self.information_criterion = InformationCriterion.AIC # auto-detect seasonality if desired, otherwise just get it from seasonal order if self.config.auto_seasonality: - candidate_m = super().generate_theta(train_data=train_data) - m, _, _ = super().evaluate_theta(thetas=candidate_m, train_data=train_data) + candidate_m = SeasonalityLayer.generate_theta(self, train_data=train_data) + m, _, _ = SeasonalityLayer.evaluate_theta(self, thetas=candidate_m, train_data=train_data) else: if self.model.config.seasonal_periods is None: m = 1 @@ -153,66 +148,9 @@ def generate_theta(self, train_data: TimeSeries) -> Iterator: if error == "mul" and trend == "mul" and seasonal == "add": continue - thetas.append([error, trend, seasonal, damped, m]) + thetas.append((error, trend, seasonal, damped, m)) return iter(thetas) - def evaluate_theta( - self, thetas: Iterator, train_data: TimeSeries, train_config=None, **kwargs - ) -> Tuple[Any, Optional[ETS], Optional[Tuple[TimeSeries, Optional[TimeSeries]]]]: - model = deepcopy(self.model) - y = train_data.univariates[self.target_name].to_pd() - for error, trend, seasonal, damped, m in thetas: - start = time.time() - _model_fit = _fit_ets_model(m, error, trend, seasonal, damped, y) - fit_time = time.time() - start - ic = getattr(_model_fit, self.config.information_criterion) - logger.debug( - "{model} : {ic_name}={ic:.3f}, Time={time:.2f} sec".format( - model=_model_name(_model_fit.model), - ic_name=self.config.information_criterion.upper(), - ic=ic, - time=fit_time, - ) - ) - self._results_dict[(error, trend, seasonal, damped, m)] = _model_fit - self._ic_dict[(error, trend, seasonal, damped, m)] = ic - if self._bestfit is None: - self._bestfit = _model_fit - self._bestfit_key = (error, trend, seasonal, damped, m) - logger.debug("First best model found (%.3f)" % ic) - current_ic = self._ic_dict[self._bestfit_key] - if ic < current_ic: - logger.debug("New best model found (%.3f < %.3f)" % (ic, current_ic)) - self._bestfit = _model_fit - self._bestfit_key = (error, trend, seasonal, damped, m) - best_model_theta = self._bestfit_key - - # construct ETS model - self.set_theta(model, best_model_theta, train_data) - model.model = self._bestfit - name = model.target_name - times = y.index - - # match the minimum data size requirement when refitting new data for ETS model - last_train_window_size = 10 - if model.seasonal_periods is not None: - last_train_window_size = max(10, 10 + 2 * (model.seasonal_periods // 2), 2 * model.seasonal_periods) - last_train_window_size = min(last_train_window_size, y.shape[0]) - model.last_train_window = y[-last_train_window_size:] - - # FORECASTING: forecast for next n steps using ETS model - model._n_train = y.shape[0] - model._last_val = y[-1] - - yhat = model.model.fittedvalues - err = model.model.standardized_forecasts_error - train_result = ( - UnivariateTimeSeries(times, yhat, name).to_ts(), - UnivariateTimeSeries(times, err, f"{name}_err").to_ts(), - ) - - return best_model_theta, model, train_result - def set_theta(self, model, theta, train_data: TimeSeries = None): error, trend, seasonal, damped_trend, seasonal_periods = theta model.config.error = error @@ -221,30 +159,10 @@ def set_theta(self, model, theta, train_data: TimeSeries = None): model.config.seasonal = seasonal model.config.seasonal_periods = seasonal_periods + @staticmethod + def _model_name(theta): + error, trend, seasonal, damped, seasonal_periods = theta + return f"ETS(err={error},trend={trend},seas={seasonal},damped={damped})" -def _fit_ets_model(seasonal_periods, error, trend, seasonal, damped, train_data): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - _model_fit = ETSModel( - train_data, - error=error, - trend=trend, - seasonal=seasonal, - damped_trend=damped, - seasonal_periods=seasonal_periods, - ).fit(disp=False) - return _model_fit - - -def _model_name(model_spec): - """ - Return model name - """ - error = model_spec.error if model_spec.error is not None else "None" - trend = model_spec.trend if model_spec.trend is not None else "None" - seasonal = model_spec.seasonal if model_spec.seasonal is not None else "None" - damped_trend = "damped" if model_spec.damped_trend else "no damped" - - return " ETS({error},{trend},{seasonal},{damped_trend})".format( - error=error, trend=trend, seasonal=seasonal, damped_trend=damped_trend - ) + def get_ic(self, model, train_data: pd.DataFrame, train_result: Tuple[pd.DataFrame, pd.DataFrame]) -> float: + return getattr(model.base_model.model, self.information_criterion.name.lower()) diff --git a/merlion/models/automl/autoprophet.py b/merlion/models/automl/autoprophet.py index 9ea298cbc..d8152a82d 100644 --- a/merlion/models/automl/autoprophet.py +++ b/merlion/models/automl/autoprophet.py @@ -5,15 +5,25 @@ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause # """ -Automatic (multi)-seasonality detection for Facebook's Prophet. +Automatic hyperparameter selection for Facebook's Prophet. """ -from typing import Union +import copy +import logging +from typing import Any, Iterator, Optional, Tuple, Union +import numpy as np +import pandas as pd +from scipy.stats import norm + +from merlion.models.automl.base import InformationCriterion, ICConfig, ICAutoMLForecaster from merlion.models.automl.seasonality import PeriodicityStrategy, SeasonalityConfig, SeasonalityLayer from merlion.models.forecast.prophet import Prophet +from merlion.utils import TimeSeries + +logger = logging.getLogger(__name__) -class AutoProphetConfig(SeasonalityConfig): +class AutoProphetConfig(SeasonalityConfig, ICConfig): """ Config class for `Prophet` with automatic seasonality detection. """ @@ -22,10 +32,16 @@ def __init__( self, model: Union[Prophet, dict] = None, periodicity_strategy: Union[PeriodicityStrategy, str] = PeriodicityStrategy.All, + information_criterion: InformationCriterion = InformationCriterion.AIC, **kwargs, ): model = dict(name="Prophet") if model is None else model - super().__init__(model=model, periodicity_strategy=periodicity_strategy, **kwargs) + super().__init__( + model=model, + periodicity_strategy=periodicity_strategy, + information_criterion=information_criterion, + **kwargs, + ) @property def multi_seasonality(self): @@ -35,10 +51,41 @@ def multi_seasonality(self): return True -class AutoProphet(SeasonalityLayer): +class AutoProphet(ICAutoMLForecaster, SeasonalityLayer): """ `Prophet` with automatic seasonality detection. Automatically detects and adds additional seasonalities that the existing Prophet may not detect (e.g. hourly). """ config_class = AutoProphetConfig + + def generate_theta(self, train_data: TimeSeries) -> Iterator: + seasonalities = list(super().generate_theta(train_data)) + seasonality_modes = ["additive", "multiplicative"] + return ((seasonalities, mode) for mode in seasonality_modes) + + def set_theta(self, model, theta, train_data: TimeSeries = None): + seasonalities, seasonality_mode = theta + seasonalities, _, _ = SeasonalityLayer.evaluate_theta(self, thetas=iter(seasonalities), train_data=train_data) + SeasonalityLayer.set_theta(self, model=model, theta=seasonalities, train_data=train_data) + model.base_model.config.seasonality_mode = seasonality_mode + model.base_model.model.seasonality_mode = seasonality_mode + + def _model_name(self, theta) -> str: + seas, mode = theta + return f"Prophet(seasonalities={seas}, seasonality_mode={mode})" + + def get_ic(self, model, train_data: pd.DataFrame, train_result: Tuple[pd.DataFrame, pd.DataFrame]) -> float: + pred, stderr = train_result + log_like = norm.logpdf((pred.values - train_data.values) / stderr.values).sum() + n_params = sum(len(v.flatten()) for k, v in model.base_model.model.params.items() if k != "trend") + ic_id = self.config.information_criterion + if ic_id is InformationCriterion.AIC: + ic = 2 * n_params - 2 * log_like.sum() + elif ic_id is InformationCriterion.BIC: + ic = n_params * np.log(len(train_data)) - 2 * log_like + elif ic_id is InformationCriterion.AICc: + ic = 2 * n_params - 2 * log_like + (2 * n_params * (n_params + 1)) / max(1, len(train_data) - n_params - 1) + else: + raise ValueError(f"{type(self.model).__name__} doesn't support information criterion {ic_id.name}") + return ic diff --git a/merlion/models/automl/base.py b/merlion/models/automl/base.py index 9fa4561ef..9b6feeead 100644 --- a/merlion/models/automl/base.py +++ b/merlion/models/automl/base.py @@ -9,12 +9,19 @@ """ from abc import abstractmethod from copy import deepcopy -from typing import Any, Iterator, Optional, Tuple +from enum import Enum, auto +import logging +from typing import Any, Iterator, Optional, Tuple, Union +import time -from merlion.models.layers import ModelBase, LayeredModel, ForecastingDetectorBase +import pandas as pd + +from merlion.models.layers import Config, ModelBase, LayeredModel, LayeredModelConfig, ForecasterBase from merlion.utils import TimeSeries from merlion.utils.misc import AutodocABCMeta +logger = logging.getLogger(__name__) + class AutoMLMixIn(LayeredModel, metaclass=AutodocABCMeta): """ @@ -28,24 +35,24 @@ def train_model(self, train_data: TimeSeries, train_config=None, **kwargs): :param train_data: the data to train on. :param train_config: the train config of the underlying model (optional). """ - processed_train_data = self.model.train_pre_process(train_data) # no need to call in generate/evaluate theta + # don't call train_pre_process() in generate/evaluate theta. get model.train_data for the original train data. + processed_train_data = self.model.train_pre_process(train_data) candidate_thetas = self.generate_theta(processed_train_data) theta, model, train_result = self.evaluate_theta(candidate_thetas, processed_train_data, **kwargs) if model is not None: - train_result = model.train_post_process(train_data, train_result, **kwargs) self.model = model - return train_result + return model.train_post_process(train_result, **kwargs) else: model = deepcopy(self.model) model.reset() - self.set_theta(model, theta, train_data) + self.set_theta(model, theta, processed_train_data) self.model = model return super().train_model(train_data, **kwargs) @abstractmethod def generate_theta(self, train_data: TimeSeries) -> Iterator: r""" - :param train_data: Training data to use for generation of hyperparameters :math:`\theta` + :param train_data: Pre-processed training data to use for generation of hyperparameters :math:`\theta` Returns an iterator of hyperparameter candidates for consideration with th underlying model. """ @@ -57,7 +64,7 @@ def evaluate_theta( ) -> Tuple[Any, Optional[ModelBase], Optional[Tuple[TimeSeries, Optional[TimeSeries]]]]: r""" :param thetas: Iterator of the hyperparameter candidates - :param train_data: Training data + :param train_data: Pre-processed training data :param train_config: Training configuration Return the optimal hyperparameter, as well as optionally a model and result of the training procedure. @@ -69,9 +76,128 @@ def set_theta(self, model, theta, train_data: TimeSeries = None): r""" :param model: Underlying base model to which the new theta is applied :param theta: Hyperparameter to apply - :param train_data: Training data (Optional) + :param train_data: Pre-processed training data (Optional) Sets the hyperparameter to the provided ``model``. This is used to apply the :math:`\theta` to the model, since this behavior is custom to every model. Oftentimes in internal implementations, ``model`` is the optimal model. """ raise NotImplementedError + + +class InformationCriterion(Enum): + AIC = auto() + r""" + Akaike information criterion. Computed as + + .. math:: + \mathrm{AIC} = 2k - 2\mathrm{ln}(L) + + where k is the number of parameters, and L is the model's likelihood. + """ + + BIC = auto() + r""" + Bayesian information criterion. Computed as + + .. math:: + k \mathrm{ln}(n) - 2 \mathrm{ln}(L) + + where n is the sample size, k is the number of parameters, and L is the model's likelihood. + """ + + AICc = auto() + r""" + Akaike information criterion with correction for small sample size. Computed as + + .. math:: + \mathrm{AICc} = \mathrm{AIC} + \frac{2k^2 + 2k}{n - k - 1} + + where n is the sample size, and k is the number of paramters. + """ + + +class ICConfig(Config): + """ + Mix-in to add an information criterion parameter to a model config. + """ + + def __init__(self, information_criterion: InformationCriterion = InformationCriterion.AIC, **kwargs): + """ + :param information_criterion: information criterion to select the best model. + """ + super().__init__(**kwargs) + self.information_criterion = information_criterion + + @property + def information_criterion(self): + return self._information_criterion + + @information_criterion.setter + def information_criterion(self, ic: Union[InformationCriterion, str]): + if not isinstance(ic, InformationCriterion): + valid = {k.lower(): k for k in InformationCriterion.__members__} + assert ic.lower() in valid, f"Unsupported InformationCriterion {ic}. Supported values: {valid.values()}" + ic = InformationCriterion[valid[ic.lower()]] + self._information_criterion = ic + + +class ICAutoMLForecaster(AutoMLMixIn, ForecasterBase, metaclass=AutodocABCMeta): + """ + AutoML model which uses an information criterion to determine which model paramters are best. + """ + + config_class = ICConfig + + @property + def information_criterion(self): + return self.config.information_criterion + + @abstractmethod + def get_ic( + self, model, train_data: pd.DataFrame, train_result: Tuple[pd.DataFrame, Optional[pd.DataFrame]] + ) -> float: + """ + Returns the information criterion of the model based on the given training data & the model's train result. + + :param model: One of the models being tried. Must be trained. + :param train_data: The target sequence of the training data as a ``pandas.DataFrame``. + :param train_result: The result of calling ``model._train()``. + :return: The information criterion evaluating the model's goodness of fit. + """ + raise NotImplementedError + + @abstractmethod + def _model_name(self, theta) -> str: + """ + :return: a string describing the current model. + """ + + def evaluate_theta( + self, thetas: Iterator, train_data: TimeSeries, train_config=None, **kwargs + ) -> Tuple[Any, ModelBase, Tuple[TimeSeries, Optional[TimeSeries]]]: + best = None + y = train_data.to_pd() + y_target = pd.DataFrame(y[self.model.target_name]) + for theta in thetas: + # Start timer & fit model using the current theta + start = time.time() + model = deepcopy(self.model) + self.set_theta(model, theta, train_data) + train_result = model._train(y, train_config=train_config) + fit_time = time.time() - start + ic = float(self.get_ic(model=model, train_data=y_target, train_result=train_result)) + logger.debug(f"{self._model_name(theta)}: {self.information_criterion.name}={ic:.3f}, Time={fit_time:.2f}s") + + # Determine if current model is better than the best seen yet + curr = {"theta": theta, "model": model, "train_result": train_result, "ic": ic} + if best is None: + best = curr + logger.debug("First best model found (%.3f)" % ic) + current_ic = best["ic"] + if ic < current_ic: + logger.debug("New best model found (%.3f < %.3f)" % (ic, current_ic)) + best = curr + + # Return best model after post-processing its train result + theta, model, train_result = best["theta"], best["model"], best["train_result"] + return theta, model, model.train_post_process(train_result, **kwargs) diff --git a/merlion/models/automl/seasonality.py b/merlion/models/automl/seasonality.py index b4779206f..3c3c95a39 100644 --- a/merlion/models/automl/seasonality.py +++ b/merlion/models/automl/seasonality.py @@ -12,6 +12,11 @@ import logging from typing import Any, Iterator, Optional, Tuple, Union +import numpy as np +from scipy.signal import argrelmax +from scipy.stats import norm +import statsmodels.api as sm + from merlion.models.automl.base import AutoMLMixIn from merlion.models.base import ModelBase from merlion.models.layers import LayeredModelConfig @@ -71,7 +76,7 @@ class SeasonalityConfig(LayeredModelConfig): _default_transform = TemporalResample() - def __init__(self, model, periodicity_strategy=PeriodicityStrategy.ACF, pval: float = 0.05, max_lag = None, **kwargs): + def __init__(self, model, periodicity_strategy=PeriodicityStrategy.ACF, pval: float = 0.05, max_lag=None, **kwargs): """ :param periodicity_strategy: Strategy to choose the seasonality if multiple candidates are detected. :param pval: p-value for deciding whether a detected seasonality is statistically significant. @@ -101,7 +106,7 @@ def periodicity_strategy(self) -> PeriodicityStrategy: def periodicity_strategy(self, p: Union[PeriodicityStrategy, str]): if not isinstance(p, PeriodicityStrategy): valid = {k.lower(): k for k in PeriodicityStrategy.__members__} - assert p.lower() in valid, f"Unsupported PeriodicityStrategy {p}. Supported strategies are: {valid.keys()}" + assert p.lower() in valid, f"Unsupported PeriodicityStrategy {p}. Supported values: {valid.values()}" p = PeriodicityStrategy[valid[p.lower()]] if p is PeriodicityStrategy.All and not self.multi_seasonality: @@ -111,18 +116,13 @@ def periodicity_strategy(self, p: Union[PeriodicityStrategy, str]): self._periodicity_strategy = p - def to_dict(self, _skipped_keys=None): - _skipped_keys = _skipped_keys if _skipped_keys is not None else set() - config_dict = super().to_dict(_skipped_keys.union({"periodicity_strategy"})) - if "periodicity_strategy" not in _skipped_keys: - config_dict["periodicity_strategy"] = self.periodicity_strategy.name - return config_dict - class SeasonalityLayer(AutoMLMixIn, metaclass=AutodocABCMeta): """ - Seasonality Layer that uses AutoSARIMA-like methods to determine seasonality of your data. Can be used directly on - any model that implements `SeasonalityModel` class. + Seasonality Layer that uses automatically determines the seasonality of your data. Can be used directly on + any model that implements `SeasonalityModel` class. The algorithmic idea is from the + `theta method `__. We find a set of + multiple candidate seasonalites, and we return the best one(s) based on the `PeriodicityStrategy`. """ config_class = SeasonalityConfig @@ -170,25 +170,60 @@ def evaluate_theta( self, thetas: Iterator, train_data: TimeSeries, train_config=None, **kwargs ) -> Tuple[Any, Optional[ModelBase], Optional[Tuple[TimeSeries, Optional[TimeSeries]]]]: # If multiple seasonalities are supported, return a list of all detected seasonalities - thetas = list(thetas) + return list(thetas) if self.config.multi_seasonality else next(thetas), None, None + + def generate_theta(self, train_data: TimeSeries) -> Iterator: + # compute max lag & acf function + x = train_data.univariates[self.target_name].np_values + if self.max_lag is None: + max_lag = max(min(int(10 * np.log10(x.shape[0])), x.shape[0] - 1), 40) + else: + max_lag = self.max_lag + xacf = sm.tsa.acf(x, nlags=max_lag, fft=False) + xacf[np.isnan(xacf)] = 0 + + # select the local maximum points with acf > 0 + candidates = np.intersect1d(np.where(xacf > 0), argrelmax(xacf)[0]) + + # the periods should be smaller than one half of the length of time series + candidates = candidates[candidates < int(x.shape[0] / 2)] + if candidates.shape[0] == 0: + return [] + else: + candidates_idx = [] + if candidates.shape[0] == 1: + candidates_idx += [0] + else: + if xacf[candidates[0]] > xacf[candidates[1]]: + candidates_idx += [0] + if xacf[candidates[-1]] > xacf[candidates[-2]]: + candidates_idx += [-1] + candidates_idx += argrelmax(xacf[candidates])[0].tolist() + candidates = candidates[candidates_idx] + + # statistical test if acf is significant w.r.t a normal distribution + xacf = xacf[1:] + tcrit = norm.ppf(1 - self.pval / 2) + clim = tcrit / np.sqrt(x.shape[0]) * np.sqrt(np.cumsum(np.insert(np.square(xacf) * 2, 0, 1))) + candidates = candidates[xacf[candidates - 1] > clim[candidates - 1]] + + # sort candidates by ACF value + candidates = sorted(candidates.tolist(), key=lambda c: xacf[c - 1], reverse=True) + if len(candidates) == 0: + candidates = [1] + + # choose the desired candidates based on periodicity strategy if self.periodicity_strategy is PeriodicityStrategy.ACF: - thetas = [thetas[0]] + candidates = [candidates[0]] elif self.periodicity_strategy is PeriodicityStrategy.Min: - thetas = [min(thetas)] + candidates = [min(candidates)] elif self.periodicity_strategy is PeriodicityStrategy.Max: - thetas = [max(thetas)] + candidates = [max(candidates)] elif self.periodicity_strategy is PeriodicityStrategy.All: - thetas = thetas + candidates = candidates else: raise ValueError(f"Periodicity strategy {self.periodicity_strategy} not supported.") - theta = thetas if self.config.multi_seasonality else thetas[0] - if thetas != [1]: - logger.info(f"Automatically detect the periodicity is {str(thetas)}") - return theta, None, None - def generate_theta(self, train_data: TimeSeries) -> Iterator: - y = train_data.univariates[self.target_name] - periods = autosarima_utils.multiperiodicity_detection(y, pval=self.pval, max_lag=self.max_lag) - if len(periods) == 0: - periods = [1] - return iter(periods) + if candidates[: None if self.config.multi_seasonality else 1] != [1]: + logger.info(f"Automatically detect the periodicity is {candidates}") + return iter(candidates) diff --git a/merlion/models/base.py b/merlion/models/base.py index 551cbad74..a442c2f55 100644 --- a/merlion/models/base.py +++ b/merlion/models/base.py @@ -9,6 +9,7 @@ """ from abc import abstractmethod import copy +from enum import Enum import json import logging import os @@ -53,13 +54,6 @@ def __init__(self, transform: TransformBase = None, **kwargs): self.transform = transform self.dim = None - @property - def base_model(self): - """ - The base model of a base model is itself. - """ - return self - def to_dict(self, _skipped_keys=None): """ :return: dict with keyword arguments used to initialize the config class. @@ -71,6 +65,8 @@ def to_dict(self, _skipped_keys=None): key = k_strip if hasattr(self, k_strip) else key if hasattr(value, "to_dict"): value = value.to_dict() + elif isinstance(value, Enum): + value = value.name # Relies on there being an appropriate getter/setter! if key not in skipped_keys: config_dict[key] = copy.deepcopy(value) return config_dict @@ -182,6 +178,13 @@ def reset(self): """ self.__init__(self.config) + @property + def base_model(self): + """ + The base model of a base model is itself. + """ + return self + @property @abstractmethod def require_even_sampling(self) -> bool: @@ -327,16 +330,16 @@ def train(self, train_data: TimeSeries, train_config=None, *args, **kwargs): """ if train_config is None: train_config = copy.deepcopy(self._default_train_config) - train_data_processed = self.train_pre_process(train_data).to_pd() - train_result = self._train(train_data=train_data_processed, train_config=train_config) - return self.train_post_process(train_data, train_result, *args, **kwargs) + train_data = self.train_pre_process(train_data).to_pd() + train_result = self._train(train_data=train_data, train_config=train_config) + return self.train_post_process(train_result, *args, **kwargs) @abstractmethod def _train(self, train_data: pd.DataFrame, train_config=None): raise NotImplementedError @abstractmethod - def train_post_process(self, train_data, train_result, *args, **kwargs): + def train_post_process(self, train_result, *args, **kwargs): raise NotImplementedError def _save_state(self, state_dict: Dict[str, Any], filename: str = None, **save_config) -> Dict[str, Any]: diff --git a/merlion/models/ensemble/anomaly.py b/merlion/models/ensemble/anomaly.py index 54d051e1a..1db1ed981 100644 --- a/merlion/models/ensemble/anomaly.py +++ b/merlion/models/ensemble/anomaly.py @@ -199,10 +199,7 @@ def train( # Train the model-level post-rule self.train_post_process( - train_data=train_data, - train_result=combined, - anomaly_labels=anomaly_labels, - post_rule_train_config=post_rule_train_config, + train_result=combined, anomaly_labels=anomaly_labels, post_rule_train_config=post_rule_train_config ) return combined diff --git a/merlion/models/forecast/base.py b/merlion/models/forecast/base.py index d7946e938..d7d322f39 100644 --- a/merlion/models/forecast/base.py +++ b/merlion/models/forecast/base.py @@ -176,7 +176,7 @@ def train(self, train_data: TimeSeries, train_config=None) -> Tuple[TimeSeries, return super().train(train_data=train_data, train_config=train_config) def train_post_process( - self, train_data: TimeSeries, train_result: Tuple[pd.DataFrame, pd.DataFrame] + self, train_result: Tuple[Union[TimeSeries, pd.DataFrame], Optional[Union[TimeSeries, pd.DataFrame]]] ) -> Tuple[TimeSeries, TimeSeries]: """ Converts the train result (forecast & stderr for training data) into TimeSeries objects, and inverts the diff --git a/merlion/models/forecast/prophet.py b/merlion/models/forecast/prophet.py index 45a8bb7b2..62b8b0b65 100644 --- a/merlion/models/forecast/prophet.py +++ b/merlion/models/forecast/prophet.py @@ -186,7 +186,7 @@ def set_seasonality(self, theta, train_data: UnivariateTimeSeries): for p in theta: if p > 1: period = p * dt.total_seconds() / 86400 - logger.info(f"Add seasonality {str(p)} ({p * dt})") + logger.debug(f"Add seasonality {str(p)} ({p * dt})") self.model.add_seasonality(name=f"extra_season_{p}", period=period, fourier_order=p) def _train(self, train_data: pd.DataFrame, train_config=None): @@ -196,15 +196,14 @@ def _train(self, train_data: pd.DataFrame, train_config=None): with _suppress_stdout_stderr(): self.model.fit(df) - # Get & return prediction & errors for train data + # Get & return prediction & errors for train data. + # sigma computation based on https://github.com/facebook/prophet/issues/549#issuecomment-435482584 self.model.uncertainty_samples = 0 forecast = self.model.predict(df)["yhat"].values.tolist() + sigma = (self.model.params["sigma_obs"] * self.model.y_scale).item() self.model.uncertainty_samples = self.uncertainty_samples - samples = self.model.predictive_samples(df)["yhat"] - samples = samples - np.expand_dims(forecast, -1) - yhat = pd.DataFrame(forecast, index=df.ds, columns=[self.target_name]) - err = pd.DataFrame(np.std(samples, axis=-1), index=df.ds, columns=[f"{self.target_name}_err"]) + err = pd.DataFrame(sigma, index=df.ds, columns=[f"{self.target_name}_err"]) return yhat, err def resample_time_stamps(self, time_stamps: Union[int, List[int]], time_series_prev: TimeSeries = None): diff --git a/merlion/transform/factory.py b/merlion/transform/factory.py index 45fad8db7..51e23d9e4 100644 --- a/merlion/transform/factory.py +++ b/merlion/transform/factory.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -23,7 +23,7 @@ AbsVal="merlion.transform.normalize:AbsVal", MeanVarNormalize="merlion.transform.normalize:MeanVarNormalize", MinMaxNormalize="merlion.transform.normalize:MinMaxNormalize", - PowerTransform="merlion.transform.normalize:PowerTransform", + BoxCoxTransform="merlion.transform.normalize:BoxCoxTransform", TemporalResample="merlion.transform.resample:TemporalResample", Shingle="merlion.transform.resample:Shingle", TransformSequence="merlion.transform.sequence:TransformSequence", diff --git a/merlion/transform/normalize.py b/merlion/transform/normalize.py index 96bbe59c9..498f42bcc 100644 --- a/merlion/transform/normalize.py +++ b/merlion/transform/normalize.py @@ -8,16 +8,20 @@ Transforms that rescale the input or otherwise normalize it. """ from collections import OrderedDict +import logging from typing import Iterable import numpy as np import pandas as pd import scipy.special +import scipy.stats from sklearn.preprocessing import StandardScaler from merlion.transform.base import InvertibleTransformBase, TransformBase from merlion.utils import UnivariateTimeSeries, TimeSeries +logger = logging.getLogger(__name__) + class AbsVal(TransformBase): """ @@ -155,16 +159,21 @@ def train(self, time_series: TimeSeries): self.scale = scale -class PowerTransform(InvertibleTransformBase): +class BoxCoxTransform(InvertibleTransformBase): """ Applies the Box-Cox power transform to the time series, with power lmbda. + When lmbda is None, we When lmbda > 0, it is ((x + offset) ** lmbda - 1) / lmbda. When lmbda == 0, it is ln(lmbda + offset). """ - def __init__(self, lmbda=0.0, offset=0.0): + def __init__(self, lmbda=None, offset=0.0): super().__init__() - assert lmbda >= 0 + if lmbda is not None: + if isinstance(lmbda, list): + assert all(isinstance(x, (int, float)) for x in lmbda) + else: + assert isinstance(lmbda, (int, float)) self.lmbda = lmbda self.offset = offset @@ -176,12 +185,17 @@ def requires_inversion_state(self): return False def train(self, time_series: TimeSeries): - pass + if self.lmbda is None: + self.lmbda = [scipy.stats.boxcox(var.np_values + self.offset)[1] for var in time_series.univariates] + logger.info(f"Chose Box-Cox lambda = {self.lmbda}") + elif not isinstance(self.lmbda, list): + self.lmbda = [self.lmbda] * time_series.dim + assert len(self.lmbda) == time_series.dim def __call__(self, time_series: TimeSeries) -> TimeSeries: new_vars = [] - for var in time_series.univariates: - y = scipy.special.boxcox(var + self.offset, self.lmbda) + for lmbda, var in zip(self.lmbda, time_series.univariates): + y = scipy.special.boxcox(var + self.offset, lmbda) var = pd.Series(y, index=var.index, name=var.name) new_vars.append(UnivariateTimeSeries.from_pd(var)) @@ -189,9 +203,14 @@ def __call__(self, time_series: TimeSeries) -> TimeSeries: def _invert(self, time_series: TimeSeries) -> TimeSeries: new_vars = [] - for var in time_series.univariates: - if self.lmbda > 0: - var = (self.lmbda * var + 1).log() / self.lmbda - new_vars.append(UnivariateTimeSeries.from_pd(var.apply(np.exp))) + for lmbda, var in zip(self.lmbda, time_series.univariates): + if lmbda > 0: + var = (lmbda * var + 1) ** (1 / lmbda) + nanvals = var.isna() + if nanvals.any(): + var[nanvals] = 0 + else: + var = var.apply(np.exp) + new_vars.append(UnivariateTimeSeries.from_pd(var - self.offset)) return TimeSeries(new_vars) diff --git a/merlion/utils/autosarima_utils.py b/merlion/utils/autosarima_utils.py index 8781211cb..ca15bd1b2 100644 --- a/merlion/utils/autosarima_utils.py +++ b/merlion/utils/autosarima_utils.py @@ -14,8 +14,6 @@ import numpy as np from numpy.linalg import LinAlgError -from scipy.signal import argrelmax -from scipy.stats import norm import statsmodels.api as sm logger = logging.getLogger(__name__) @@ -230,49 +228,6 @@ def detect_maxiter_sarima_model(y, X, d, D, m, method, information_criterion, ** return maxiter -def multiperiodicity_detection(x, pval=0.05, max_lag=None): - """ - Detect multiple periodicity of a time series - The idea can be found in theta method - (https://github.com/Mcompetitions/M4-methods/blob/master/4Theta%20method.R). - Returns a list of periods, which indicates the seasonal periods of the - time series - """ - tcrit = norm.ppf(1 - pval / 2) - if max_lag is None: - max_lag = max(min(int(10 * np.log10(x.shape[0])), x.shape[0] - 1), 40) - xacf = sm.tsa.acf(x, nlags=max_lag, fft=False) - xacf[np.isnan(xacf)] = 0 - - # select the local maximum points with acf > 0 - candidates = np.intersect1d(np.where(xacf > 0), argrelmax(xacf)[0]) - - # the periods should be smaller than one half of the length of time series - candidates = candidates[candidates < int(x.shape[0] / 2)] - if candidates.shape[0] == 0: - return [] - else: - candidates_idx = [] - if candidates.shape[0] == 1: - candidates_idx += [0] - else: - if xacf[candidates[0]] > xacf[candidates[1]]: - candidates_idx += [0] - if xacf[candidates[-1]] > xacf[candidates[-2]]: - candidates_idx += [-1] - candidates_idx += argrelmax(xacf[candidates])[0].tolist() - candidates = candidates[candidates_idx] - - xacf = xacf[1:] - clim = tcrit / np.sqrt(x.shape[0]) * np.sqrt(np.cumsum(np.insert(np.square(xacf) * 2, 0, 1))) - - # statistical test if acf is significant w.r.t a normal distribution - candidate_filter = candidates[xacf[candidates - 1] > clim[candidates - 1]] - # return candidate seasonalities, sorted by ACF value - candidate_filter = sorted(candidate_filter.tolist(), key=lambda c: xacf[c - 1], reverse=True) - return candidate_filter - - def seas_seasonalstationaritytest(x, m): """ Estimate the strength of seasonal component. The idea can be found in diff --git a/merlion/utils/misc.py b/merlion/utils/misc.py index 7829a62ac..89ccd08c7 100644 --- a/merlion/utils/misc.py +++ b/merlion/utils/misc.py @@ -148,7 +148,7 @@ def dynamic_import(import_path: str, alias: dict = None): Dynamically import a member from the specified module. :param import_path: syntax 'module_name:member_name', - e.g. 'merlion.transform.normalize:PowerTransform' + e.g. 'merlion.transform.normalize:BoxCoxTransform' :param alias: dict which maps shortcuts for the registered classes, to their full import paths. :return: imported class diff --git a/setup.py b/setup.py index 9bc927af4..e0d9631a8 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ def read_file(fname): setup( name="salesforce-merlion", - version="1.2.5", + version="1.3.0", author=", ".join(read_file("AUTHORS.md").split("\n")), author_email="abhatnagar@salesforce.com", description="Merlion: A Machine Learning Framework for Time Series Intelligence", @@ -48,7 +48,7 @@ def read_file(fname): "packaging", "pandas>=1.1.0", # >=1.1.0 for origin kwarg to df.resample() "prophet>=1.1; python_version >= '3.7'", # 1.1 removes dependency on pystan - "prophet==1.0.1; python_version < '3.7'", # however, prophet 1.1 requires python 3.7+ + "prophet>=1.0; python_version < '3.7'", # however, prophet 1.1 requires python 3.7+ "scikit-learn>=0.22", # >=0.22 for changes to isolation forest algorithm "scipy>=1.6.0; python_version >= '3.7'", # 1.6.0 adds multivariate_t density to scipy.stats "scipy>=1.5.0; python_version < '3.7'", # however, scipy 1.6.0 requires python 3.7+ diff --git a/tests/anomaly/forecast_based/test_prophet.py b/tests/anomaly/forecast_based/test_prophet.py index 7401fee3e..59d53b93c 100644 --- a/tests/anomaly/forecast_based/test_prophet.py +++ b/tests/anomaly/forecast_based/test_prophet.py @@ -17,7 +17,7 @@ from merlion.models.anomaly.forecast_based.prophet import ProphetDetector, ProphetDetectorConfig from merlion.utils.data_io import csv_to_time_series from merlion.utils.time_series import TimeSeries -from merlion.transform.normalize import PowerTransform +from merlion.transform.normalize import BoxCoxTransform from merlion.transform.resample import TemporalResample logger = logging.getLogger(__name__) @@ -33,14 +33,14 @@ def __init__(self, *args, **kwargs): logger.info(f"Data looks like:\n{self.data[:5]}") holidays = pd.DataFrame({"ds": ["03-17-2020"], "holiday": ["St. Patrick's Day"]}) - # Test Prophet with a log transform (Box-Cox with lmbda=0) + # Test Prophet with a Box-Cox transform self.test_len = math.ceil(len(self.data) / 5) self.vals_train = self.data[: -self.test_len] self.vals_test = self.data[-self.test_len :] self.model = AutoProphet( model=ProphetDetector( ProphetDetectorConfig( - transform=PowerTransform(lmbda=0.0), + transform=BoxCoxTransform(lmbda=0.5), uncertainty_samples=1000, holidays=holidays, invert_transform=True, @@ -76,12 +76,10 @@ def test_full(self): ) # score function returns the raw anomaly scores - scores = self.model.get_anomaly_score(self.vals_test) - self.assertEqual(len(scores), len(self.vals_test)) + scores = self.model.get_anomaly_score(self.vals_test).to_pd() logger.info(f"Scores look like:\n{scores[:5]}") - scores = scores.to_pd().values.flatten() - logger.info("max score = " + str(max(scores))) - logger.info("min score = " + str(min(scores)) + "\n") + logger.info("max score = " + str(np.max(scores.values.flatten()))) + logger.info("min score = " + str(np.min(scores.values.flatten())) + "\n") # alarm function returns the post-rule processed anomaly scores alarms = self.model.get_anomaly_label(self.vals_test) @@ -91,22 +89,20 @@ def test_full(self): self.assertLessEqual(n_alarms, 15) logger.info("Verifying that scores don't change much on re-evaluation...\n") - scoresv2 = self.model.get_anomaly_score(self.vals_test, self.vals_train) - scoresv2 = scoresv2.to_pd().values.flatten() - self.assertAlmostEqual(np.mean(np.abs(scores - scoresv2)), 0, delta=0.05) + scoresv2 = self.model.get_anomaly_score(self.vals_test, self.vals_train).to_pd().loc[scores.index] + self.assertAlmostEqual(np.mean(np.abs(scores - scoresv2)).item(), 0, delta=0.05) # We test save/load AFTER our first prediction because we need the old # posterior samples for reproducibility logger.info("Verifying that scores don't change much after save/load...\n") self.model.save(dirname=join(rootdir, "tmp", "prophet")) loaded_model = AutoProphet.load(dirname=join(rootdir, "tmp", "prophet")) - scoresv3 = loaded_model.get_anomaly_score(self.vals_test) - scoresv3 = scoresv3.to_pd().values.flatten() - self.assertAlmostEqual(np.mean(np.abs(scores - scoresv3)), 0, delta=0.05) + scoresv3 = loaded_model.get_anomaly_score(self.vals_test).to_pd() + self.assertAlmostEqual(np.mean(np.abs(scores - scoresv3)).item(), 0, delta=0.05) if __name__ == "__main__": logging.basicConfig( - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", stream=sys.stdout, level=logging.INFO + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", stream=sys.stdout, level=logging.DEBUG ) unittest.main() diff --git a/tests/forecast/test_autoets.py b/tests/forecast/test_autoets.py index 21ca3d0ba..b5cf09dc3 100644 --- a/tests/forecast/test_autoets.py +++ b/tests/forecast/test_autoets.py @@ -24,39 +24,164 @@ class TestETS(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - train_data = np.array([49749475.08, 48334704.82, 48275157.57, 43969281.56, 46870666.51, - 45924937.95, 44988678.08, 44133701.02, 50423887.67, 47365181.62, - 45183562.86, 44733997.96, 43705038.61, 48503231.02, 45329947.95, - 45119948.58, 47757383.6 , 50188444.77, 47826347.76, 47621799.23, - 46608866.38, 48917390.8 , 47899499.45, 46243813.08, 44888625.35, - 44630086.18, 48204925.76, 46464361.88, 47060765.68, 45909681.68, - 47194223.17, 45633673.84, 43081657.51, 41358352.79, 42239826.37, - 45102996.78, 43149465.36, 43066645.59, 43602653.58, 45782124.16, - 46124776.79, 45125584.18, 65821010.24, 49909055.87, 55666423.99, - 61820270.05, 80931028.42, 40432678.46, 42775937.09, 40673653.3 , - 40654715.99, 39599811.1 , 46153156.48, 47336192.79, 48716033.27, - 44125860.84, 46980667.15, 44627510.46, 44872297.41, 42876467.62, - 43459171.87, 45887401.96, 44973284.11, 48676717.62, 43529990.87, - 46861899.62, 45446059.32, 44046586.63, 45293460.06, 48772039.02, - 47669719.65, 47447526.22, 45884025.5 , 47578624.19, 47859258.4 , - 45515914.52, 45274339.01, 43683243.92, 48015073.36, 46249598.39, - 46917127.1 , 47416905.45, 45376432.48, 46763000.79, 43792604.96, - 42716522.23, 42195007.11, 47211092.36, 44374292.45, 45819133.11, - 45855520.28, 48654605.42, 48474212.41, 46438802.12, 66565730.02, - 49384577.42, 55558067.21, 60082608.34, 76994571.24, 46041807.5 , - 44955394.34, 42022604.47, 42080268.78, 39834180.38, 46085308.65, - 50007132.23, 50195096.99, 45770682.88, 46860478.29, 47479695.77, - 46900921.15, 44993618.49, 45272511.84, 53502282.02, 46629165.55, - 45072409.05, 43716752.65, 47124072.67, 46925658.31, 46823296.04, - 47892700.9 , 48281635.03, 49651162.2 , 48412125.56, 47668135.12, - 46597053.1 , 51253099.81, 46099747.96, 46059299.15, 44097050.8 ]) - test_data = np.array([47485712.35, 47403448.22, 47355041.71, 47447284.43, 47159585.95, - 48329845.27, 44225749.26, 44354360.72, 43734839.24, 47566467.14, - 46128371.47, 45122294.9 , 45544036.17]) + train_data = np.array( + [ + 49749475.08, + 48334704.82, + 48275157.57, + 43969281.56, + 46870666.51, + 45924937.95, + 44988678.08, + 44133701.02, + 50423887.67, + 47365181.62, + 45183562.86, + 44733997.96, + 43705038.61, + 48503231.02, + 45329947.95, + 45119948.58, + 47757383.6, + 50188444.77, + 47826347.76, + 47621799.23, + 46608866.38, + 48917390.8, + 47899499.45, + 46243813.08, + 44888625.35, + 44630086.18, + 48204925.76, + 46464361.88, + 47060765.68, + 45909681.68, + 47194223.17, + 45633673.84, + 43081657.51, + 41358352.79, + 42239826.37, + 45102996.78, + 43149465.36, + 43066645.59, + 43602653.58, + 45782124.16, + 46124776.79, + 45125584.18, + 65821010.24, + 49909055.87, + 55666423.99, + 61820270.05, + 80931028.42, + 40432678.46, + 42775937.09, + 40673653.3, + 40654715.99, + 39599811.1, + 46153156.48, + 47336192.79, + 48716033.27, + 44125860.84, + 46980667.15, + 44627510.46, + 44872297.41, + 42876467.62, + 43459171.87, + 45887401.96, + 44973284.11, + 48676717.62, + 43529990.87, + 46861899.62, + 45446059.32, + 44046586.63, + 45293460.06, + 48772039.02, + 47669719.65, + 47447526.22, + 45884025.5, + 47578624.19, + 47859258.4, + 45515914.52, + 45274339.01, + 43683243.92, + 48015073.36, + 46249598.39, + 46917127.1, + 47416905.45, + 45376432.48, + 46763000.79, + 43792604.96, + 42716522.23, + 42195007.11, + 47211092.36, + 44374292.45, + 45819133.11, + 45855520.28, + 48654605.42, + 48474212.41, + 46438802.12, + 66565730.02, + 49384577.42, + 55558067.21, + 60082608.34, + 76994571.24, + 46041807.5, + 44955394.34, + 42022604.47, + 42080268.78, + 39834180.38, + 46085308.65, + 50007132.23, + 50195096.99, + 45770682.88, + 46860478.29, + 47479695.77, + 46900921.15, + 44993618.49, + 45272511.84, + 53502282.02, + 46629165.55, + 45072409.05, + 43716752.65, + 47124072.67, + 46925658.31, + 46823296.04, + 47892700.9, + 48281635.03, + 49651162.2, + 48412125.56, + 47668135.12, + 46597053.1, + 51253099.81, + 46099747.96, + 46059299.15, + 44097050.8, + ] + ) + test_data = np.array( + [ + 47485712.35, + 47403448.22, + 47355041.71, + 47447284.43, + 47159585.95, + 48329845.27, + 44225749.26, + 44354360.72, + 43734839.24, + 47566467.14, + 46128371.47, + 45122294.9, + 45544036.17, + ] + ) self.train_data = TimeSeries.from_pd(pd.Series(train_data)) - self.test_data = TimeSeries.from_pd(pd.Series(test_data, - index=pd.RangeIndex(start=len(self.train_data), - stop= len(self.train_data)+test_data.shape[0]))) + self.test_data = TimeSeries.from_pd( + pd.Series( + test_data, + index=pd.RangeIndex(start=len(self.train_data), stop=len(self.train_data) + test_data.shape[0]), + ) + ) self.max_forecast_steps = len(self.test_data) self.autoets_model = AutoETS(AutoETSConfig(pval=0.1, max_lag=55, max_forecast_steps=self.max_forecast_steps)) self.ets_model = ETS(ETSConfig(seasonal_periods=4)) @@ -67,6 +192,7 @@ def test_forecast(self): smape_auto = ForecastMetric.sMAPE.value(self.test_data, forecast, target_seq_index=0) logger.info(f"sMAPE = {smape_auto:.4f} for {self.max_forecast_steps} step forecasting for AutoETS") self.assertAlmostEqual(smape_auto, 2.21, delta=1) + self.autoets_model.save(join(rootdir, "tmp", "autoets")) _, _ = self.ets_model.train(self.train_data) forecast, lb, ub = self.ets_model.forecast(self.max_forecast_steps, return_iqr=True) diff --git a/tests/forecast/test_autosarima.py b/tests/forecast/test_autosarima.py index 50bac1f4b..2a485cd9e 100644 --- a/tests/forecast/test_autosarima.py +++ b/tests/forecast/test_autosarima.py @@ -811,8 +811,7 @@ def run_test(self, auto_pqPQ: bool, seasonality_layer: bool, expected_sMAPE: flo # check automatic periodicity detection k = self.test_data.names[0] - m = autosarima_utils.multiperiodicity_detection(self.train_data.univariates[k].np_values) - self.assertEqual(m[0], 24) + self.assertEqual(self.model.base_model.config.seasonal_order[-1], 24) # check the length of forecasting results pred, err = self.model.forecast(self.max_forecast_steps) diff --git a/tests/forecast/test_forecast_ensemble.py b/tests/forecast/test_forecast_ensemble.py index 46075432b..46139146a 100644 --- a/tests/forecast/test_forecast_ensemble.py +++ b/tests/forecast/test_forecast_ensemble.py @@ -15,10 +15,11 @@ from merlion.models.ensemble.forecast import ForecasterEnsemble, ForecasterEnsembleConfig from merlion.models.ensemble.combine import ModelSelector, Mean from merlion.evaluate.forecast import ForecastMetric -from merlion.models.automl.autoprophet import AutoProphet, AutoProphetConfig, PeriodicityStrategy +from merlion.models.automl.autoprophet import AutoProphet, AutoProphetConfig from merlion.models.forecast.arima import Arima, ArimaConfig from merlion.models.factory import ModelFactory from merlion.transform.base import Identity +from merlion.transform.normalize import BoxCoxTransform from merlion.transform.resample import TemporalResample from merlion.utils.data_io import csv_to_time_series, TimeSeries @@ -39,7 +40,7 @@ def _test_mean(self, test_name): model0 = Arima(ArimaConfig(order=(6, 1, 2), max_forecast_steps=50, transform=TemporalResample("1h"))) model1 = Arima(ArimaConfig(order=(24, 1, 0), max_forecast_steps=50, transform=TemporalResample("10min"))) model2 = AutoProphet( - config=AutoProphetConfig(transform=Identity(), periodicity_strategy=PeriodicityStrategy.Max) + config=AutoProphetConfig(transform=Identity(), periodicity_strategy="All", information_criterion="BIC") ) self.ensemble = ForecasterEnsemble( models=[model0, model1, model2], config=ForecasterEnsembleConfig(combiner=Mean(abs_score=False)) @@ -51,9 +52,14 @@ def _test_mean(self, test_name): def _test_selector(self, test_name, expected_smapes): model0 = Arima(ArimaConfig(order=(6, 1, 2), max_forecast_steps=50, transform=TemporalResample("1h"))) - model1 = Arima(ArimaConfig(order=(24, 1, 0), transform=TemporalResample("10min"), max_forecast_steps=50)) + model1 = Arima(ArimaConfig(order=(24, 1, 0), max_forecast_steps=50, transform=TemporalResample("10min"))) model2 = AutoProphet( - config=AutoProphetConfig(target_seq_index=0, transform=Identity(), periodicity_strategy="Max") + config=AutoProphetConfig( + target_seq_index=0, + transform=BoxCoxTransform(lmbda=0), + periodicity_strategy="Max", + information_criterion="AICc", + ) ) self.ensemble = ForecasterEnsemble( config=ForecasterEnsembleConfig( @@ -69,7 +75,7 @@ def _test_selector(self, test_name, expected_smapes): def test_mean(self): print("-" * 80) logger.info("test_mean\n" + "-" * 80 + "\n") - self.expected_smape = 37 + self.expected_smape = 38 self._test_mean(test_name="test_mean") def test_mean_small_train(self): @@ -82,25 +88,25 @@ def test_mean_small_train(self): def test_univariate_selector(self): print("-" * 80) logger.info("test_univariate_selector\n" + "-" * 80 + "\n") - self.expected_smape = 35 - self._test_selector(test_name="test_univariate_selector", expected_smapes=[34.66, 39.81, 30.47]) + self.expected_smape = 20 + self._test_selector(test_name="test_univariate_selector", expected_smapes=[34.66, 39.81, 21.46]) def test_multivariate_selector(self): print("-" * 80) logger.info("test_multivariate_selector\n" + "-" * 80 + "\n") x = self.vals_train.to_pd() - self.expected_smape = 35 + self.expected_smape = 20 self.vals_train = TimeSeries.from_pd( pd.DataFrame(np.concatenate((x.values, x.values * 2), axis=1), columns=["A", "B"], index=x.index) ) - self._test_selector(test_name="test_multivariate_selector", expected_smapes=[34.66, 39.81, 30.47]) + self._test_selector(test_name="test_multivariate_selector", expected_smapes=[34.66, 39.81, 21.46]) def test_selector_small_train(self): print("-" * 80) logger.info("test_selector_small_train\n" + "-" * 80 + "\n") self.vals_train = self.vals_train[-8:] - self.expected_smape = 177 - self._test_selector(test_name="test_selector_small_train", expected_smapes=[np.inf, 7.27, 5.71]) + self.expected_smape = 194 + self._test_selector(test_name="test_selector_small_train", expected_smapes=[np.inf, 7.27, 6.16]) def run_test(self, test_name): logger.info("Training model...") diff --git a/tests/test_custom_dataset.py b/tests/test_custom_dataset.py index 71c317829..0a4b0995c 100644 --- a/tests/test_custom_dataset.py +++ b/tests/test_custom_dataset.py @@ -15,7 +15,7 @@ def test_custom_anom_dataset(): data_dir = os.path.join(rootdir, "data", "synthetic_anomaly") - dataset = CustomAnomalyDataset(root=data_dir, test_frac=0.75, time_unit="s", assume_no_anomaly=True) + dataset = CustomAnomalyDataset(rootdir=data_dir, test_frac=0.75, time_unit="s", assume_no_anomaly=True) assert len(dataset) == len(glob.glob(os.path.join(data_dir, "*.csv"))) assert all("anomaly" in md.columns and "trainval" in md.columns for ts, md in dataset) assert all(abs((~md.trainval).mean() - dataset.test_frac) < 2 / len(ts) for ts, md in dataset) @@ -24,8 +24,10 @@ def test_custom_anom_dataset(): def test_custom_dataset(): csv = os.path.join(rootdir, "data", "walmart", "walmart_mini.csv") index_cols = ["Store", "Dept"] + data_cols = ["Weekly_Sales", "Temperature", "CPI"] df = pd.read_csv(csv, index_col=[0, 1, 2], parse_dates=True) - dataset = CustomDataset(root=csv, test_frac=0.25, time_unit="s", index_cols=index_cols) + dataset = CustomDataset(rootdir=csv, test_frac=0.25, data_cols=data_cols, index_cols=index_cols) assert len(dataset) == len(df.groupby(index_cols).groups) + assert all(list(ts.columns) == data_cols for ts, md in dataset) assert all((c in md.columns for c in ["trainval"] + index_cols) for ts, md in dataset) assert all(abs((~md.trainval).mean() - dataset.test_frac) < 2 / len(ts) for ts, md in dataset) diff --git a/tests/test_plot.py b/tests/test_plot.py index 630b21305..7129d0e27 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -13,7 +13,7 @@ from merlion.transform.base import Identity from merlion.transform.moving_average import DifferenceTransform -from merlion.transform.normalize import PowerTransform +from merlion.transform.normalize import BoxCoxTransform from merlion.transform.resample import TemporalResample from merlion.models.anomaly.forecast_based.prophet import ProphetDetector, ProphetDetectorConfig from merlion.models.forecast.trees import LGBMForecaster, LGBMForecasterConfig @@ -51,7 +51,7 @@ def test_plot_transform_inv(self): print("-" * 80) logger.info("test_plot_transform_inv\n" + "-" * 80 + "\n") self.model = ProphetDetector( - ProphetDetectorConfig(transform=PowerTransform(), invert_transform=True, uncertainty_samples=1000) + ProphetDetectorConfig(transform=BoxCoxTransform(), invert_transform=True, uncertainty_samples=1000) ) self._test_plot(subdir="transform_inv") diff --git a/ts_datasets/ts_datasets/anomaly/__init__.py b/ts_datasets/ts_datasets/anomaly/__init__.py index a80aea7ea..1d7bba19a 100755 --- a/ts_datasets/ts_datasets/anomaly/__init__.py +++ b/ts_datasets/ts_datasets/anomaly/__init__.py @@ -33,7 +33,7 @@ ] -def get_dataset(dataset_name: str, rootdir: str = None) -> TSADBaseDataset: +def get_dataset(dataset_name: str, rootdir: str = None, **kwargs) -> TSADBaseDataset: """ :param dataset_name: the name of the dataset to load, formatted as ```` or ``_``, e.g. ``IOPsCompetition`` @@ -41,6 +41,7 @@ def get_dataset(dataset_name: str, rootdir: str = None) -> TSADBaseDataset: :param rootdir: the directory where the desired dataset is stored. Not required if the package :py:mod:`ts_datasets` is installed in editable mode, i.e. with flag ``-e``. + :param kwargs: keyword arguments for the data loader you are trying to load. :return: the data loader for the desired dataset (and subset) desired """ name_subset = dataset_name.split("_", maxsplit=1) @@ -60,5 +61,6 @@ def get_dataset(dataset_name: str, rootdir: str = None) -> TSADBaseDataset: f"specifying dataset name {dataset_name}." ) - kwargs = dict() if len(name_subset) == 1 else dict(subset=name_subset[1]) + if len(name_subset) > 1: + kwargs.update(subset=name_subset[1]) return cls(rootdir=rootdir, **kwargs) diff --git a/ts_datasets/ts_datasets/anomaly/custom.py b/ts_datasets/ts_datasets/anomaly/custom.py index 4c741b92e..765d12509 100644 --- a/ts_datasets/ts_datasets/anomaly/custom.py +++ b/ts_datasets/ts_datasets/anomaly/custom.py @@ -22,9 +22,18 @@ class CustomAnomalyDataset(CustomDataset, TSADBaseDataset): to get started. """ - def __init__(self, root, test_frac=0.5, assume_no_anomaly=False, time_col=None, time_unit="s", index_cols=None): + def __init__( + self, + rootdir, + test_frac=0.5, + assume_no_anomaly=False, + time_col=None, + time_unit="s", + data_cols=None, + index_cols=None, + ): """ - :param root: Filename of a single CSV, or a directory containing many CSVs. Each CSV must contain 1 + :param rootdir: Filename of a single CSV, or a directory containing many CSVs. Each CSV must contain 1 or more time series. :param test_frac: If we don't find a column "trainval" in the time series, this is the fraction of each time series which we use for testing. @@ -32,6 +41,7 @@ def __init__(self, root, test_frac=0.5, assume_no_anomaly=False, time_col=None, anomalies in the data if this value is ``True``, and we throw an exception if this value is ``False``. :param time_col: Name of the column used to index time. We use the first non-index, non-metadata column if none is given. + :param data_cols: Name of the columns to fetch from the dataset. If ``None``, use all non-time, non-index columns. :param time_unit: If the time column is numerical, we assume it is a timestamp expressed in this unit. :param index_cols: If a CSV file contains multiple time series, these are the columns used to index those time series. For example, a CSV file may contain time series of sales for many (store, department) pairs. @@ -39,7 +49,14 @@ def __init__(self, root, test_frac=0.5, assume_no_anomaly=False, time_col=None, to the metadata of the data loader. """ self.assume_no_anomaly = assume_no_anomaly - super().__init__(root=root, test_frac=test_frac, time_col=time_col, time_unit=time_unit, index_cols=index_cols) + super().__init__( + rootdir=rootdir, + test_frac=test_frac, + time_col=time_col, + time_unit=time_unit, + data_cols=data_cols, + index_cols=index_cols, + ) @property def metadata_cols(self): diff --git a/ts_datasets/ts_datasets/forecast/__init__.py b/ts_datasets/ts_datasets/forecast/__init__.py index f597c7a40..fe3e743ca 100644 --- a/ts_datasets/ts_datasets/forecast/__init__.py +++ b/ts_datasets/ts_datasets/forecast/__init__.py @@ -18,13 +18,14 @@ __all__ = ["get_dataset", "CustomDataset", "M4", "EnergyPower", "SeattleTrail", "SolarPlant"] -def get_dataset(dataset_name: str, rootdir: str = None) -> BaseDataset: +def get_dataset(dataset_name: str, rootdir: str = None, **kwargs) -> BaseDataset: """ :param dataset_name: the name of the dataset to load, formatted as ```` or ``_``, e.g. ``EnergyPower`` or ``M4_Hourly`` :param rootdir: the directory where the desired dataset is stored. Not required if the package :py:mod:`ts_datasets` is installed in editable mode, i.e. with flag ``-e``. + :param kwargs: keyword arguments for the data loader you are trying to load. :return: the data loader for the desired dataset (and subset) desired """ name_subset = dataset_name.split("_", maxsplit=1) @@ -44,5 +45,6 @@ def get_dataset(dataset_name: str, rootdir: str = None) -> BaseDataset: f"specifying dataset name {dataset_name}." ) - kwargs = dict() if len(name_subset) == 1 else dict(subset=name_subset[1]) + if len(name_subset) > 1: + kwargs.update(subset=name_subset[1]) return cls(rootdir=rootdir, **kwargs) diff --git a/ts_datasets/ts_datasets/forecast/custom.py b/ts_datasets/ts_datasets/forecast/custom.py index 8c29b4046..c39f500b7 100644 --- a/ts_datasets/ts_datasets/forecast/custom.py +++ b/ts_datasets/ts_datasets/forecast/custom.py @@ -17,15 +17,16 @@ class CustomDataset(BaseDataset): Wrapper to load a custom dataset. Please review the `tutorial ` to get started. """ - def __init__(self, root, test_frac=0.5, time_col=None, time_unit="s", index_cols=None): + def __init__(self, rootdir, test_frac=0.5, time_col=None, time_unit="s", data_cols=None, index_cols=None): """ - :param root: Filename of a single CSV, or a directory containing many CSVs. Each CSV must contain 1 + :param rootdir: Filename of a single CSV, or a directory containing many CSVs. Each CSV must contain 1 or more time series. :param test_frac: If we don't find a column "trainval" in the time series, this is the fraction of each time series which we use for testing. :param time_col: Name of the column used to index time. We use the first non-index, non-metadata column if none is given. :param time_unit: If the time column is numerical, we assume it is a timestamp expressed in this unit. + :param data_cols: Name of the columns to fetch from the dataset. If ``None``, use all non-time, non-index columns. :param index_cols: If a CSV file contains multiple time series, these are the columns used to index those time series. For example, a CSV file may contain time series of sales for many (store, department) pairs. In this case, ``index_cols`` may be ``["Store", "Dept"]``. The values of the index columns will be added @@ -33,11 +34,12 @@ def __init__(self, root, test_frac=0.5, time_col=None, time_unit="s", index_cols """ super().__init__() assert ( - root is not None and os.path.isfile(root) or os.path.isdir(root) + rootdir is not None and os.path.isfile(rootdir) or os.path.isdir(rootdir) ), "You must give CSV file or directory where the data lives." - csvs = sorted(glob.glob(os.path.join(root, "*.csv*"))) if os.path.isdir(root) else [root] - assert len(csvs) > 0, f"The rootdir {root} must contain at least 1 CSV file. None found." + csvs = sorted(glob.glob(os.path.join(rootdir, "*.csv*"))) if os.path.isdir(rootdir) else [rootdir] + assert len(csvs) > 0, f"The rootdir {rootdir} must contain at least 1 CSV file. None found." self.test_frac = test_frac + self.rootdir = rootdir for csv in csvs: df = pd.read_csv(csv) @@ -54,6 +56,7 @@ def __init__(self, root, test_frac=0.5, time_col=None, time_unit="s", index_cols # Split into multiple time series dataframes based on index df.set_index(index_cols + [time_col], inplace=True) + df = df.loc[:, data_cols] if data_cols is not None else df if len(index_cols) > 0: dfs = [df.loc[idx] for idx in df.groupby(index_cols).groups.values()] else: