Skip to content

Commit

Permalink
Make AutoProphet more sophisticated & simplify AutoETS code. (#121)
Browse files Browse the repository at this point in the history
* Enhancments to layered forecasting detectors.

* Automatically choose add/mul mode for AutoProphet.

* Older acceptable prophet version for py36.

This is for compatibility with CentOS images which may have older
versions of GCC which can't compile pystan 2.19.

* Simplify train_post_process()

* Simplify autoprophet implementation using _train()

* Allow auto-detection of best lambda for Box-Cox.

* Rename PowerTransform to BoxCoxTransform.

* Simplify AutoETS implementation.

* Allow use of custom datasets with benchmark code.

* Fix boxcox bugs.

* Update version to 1.3.0.

* Update expected test result.

* Move seasonality detection to SeasonalityLayer.

* Simplify serialization of enums in config objects.

* Use info criterion for AutoProphet model selection

We also unify the implementations of AutoETS and AutoProphet.

* Create base class for IC-based model selection.

* Minor bugfix.
  • Loading branch information
aadyotb committed Sep 1, 2022
1 parent 344a3e5 commit eae7bd4
Show file tree
Hide file tree
Showing 31 changed files with 635 additions and 364 deletions.
41 changes: 23 additions & 18 deletions benchmark_anomaly.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def parse_args():
"in ts_datasets/ts_datasets/anomaly/__init__.py for "
"valid options.",
)
parser.add_argument("--data_root", default=None, help="Root directory/file of dataset.")
parser.add_argument("--data_kwargs", default="{}", help="JSON of keyword arguemtns for the data loader.")
parser.add_argument(
"--models",
type=str,
Expand Down Expand Up @@ -148,6 +150,8 @@ def parse_args():
args.metric = TSADMetric[args.metric]
args.pointwise_metric = TSADMetric[args.pointwise_metric]
args.visualize = args.visualize and not args.eval_only
args.data_kwargs = json.loads(args.data_kwargs)
assert isinstance(args.data_kwargs, dict)
if args.retrain_freq.lower() in ["", "none", "null"]:
args.retrain_freq = None
elif args.retrain_freq != "default":
Expand All @@ -164,10 +168,14 @@ def parse_args():
return args


def dataset_to_name(dataset: TSADBaseDataset):
if dataset.subset is not None:
return f"{type(dataset).__name__}_{dataset.subset}"
return type(dataset).__name__
def get_dataset_name(dataset: TSADBaseDataset):
name = type(dataset).__name__
if hasattr(dataset, "subset") and dataset.subset is not None:
name += "_" + dataset.subset
if isinstance(dataset, CustomAnomalyDataset):
root = dataset.rootdir
name = os.path.join(name, os.path.basename(os.path.dirname(root) if os.path.isfile(root) else root))
return name


def dataset_to_threshold(dataset: TSADBaseDataset, tune_on_test=False):
Expand Down Expand Up @@ -269,14 +277,13 @@ def train_model(
unsupervised=False,
tune_on_test=False,
):
"""Trains a model on the time series dataset given, and save their predictions
to a dataset."""
"""Trains a model on the time series dataset given, and save their predictions to a dataset."""
resampler = None
if isinstance(dataset, IOpsCompetition):
resampler = TemporalResample("5min")

model_name = resolve_model_name(model_name)
dataset_name = dataset_to_name(dataset)
dataset_name = get_dataset_name(dataset)
model_dir = model_name if retrain_freq is None else f"{model_name}_{retrain_freq}"
dirname = os.path.join("results", "anomaly", model_dir)
csv = os.path.join(dirname, f"pred_{dataset_name}.csv.gz")
Expand Down Expand Up @@ -337,16 +344,17 @@ def train_model(
if not visualize:
if i == i0 == 0:
os.makedirs(os.path.dirname(csv), exist_ok=True)
os.makedirs(os.path.dirname(checkpoint), exist_ok=True)
df = pd.DataFrame({"timestamp": [], "y": [], "trainval": [], "idx": []})
df.to_csv(csv, index=False)

df = pd.read_csv(csv)
ts_df = train_scores.to_pd().append(test_scores.to_pd())
ts_df = pd.concat((train_scores.to_pd(), test_scores.to_pd()))
ts_df.columns = ["y"]
ts_df.loc[:, "timestamp"] = ts_df.index.view(int) // 1e9
ts_df.loc[:, "trainval"] = [j < len(train_scores) for j in range(len(ts_df))]
ts_df.loc[:, "idx"] = i
df = df.append(ts_df, ignore_index=True)
df = pd.concat((df, ts_df), ignore_index=True)
df.to_csv(csv, index=False)

# Start from time series i+1 if loading a checkpoint.
Expand All @@ -358,7 +366,7 @@ def train_model(
score = test_scores if tune_on_test else train_scores
label = test_anom if tune_on_test else train_anom
model.train_post_process(
train_vals, train_result=score, anomaly_labels=label, post_rule_train_config=post_rule_train_config
train_result=score, anomaly_labels=label, post_rule_train_config=post_rule_train_config
)

# Log (many) evaluation metrics for the time series
Expand Down Expand Up @@ -433,7 +441,7 @@ def read_model_predictions(dataset: TSADBaseDataset, model_dir: str):
Returns a list of lists all_preds, where all_preds[i] is the model's raw
anomaly scores for time series i in the dataset.
"""
csv = os.path.join("results", "anomaly", model_dir, f"pred_{dataset_to_name(dataset)}.csv.gz")
csv = os.path.join("results", "anomaly", model_dir, f"pred_{get_dataset_name(dataset)}.csv.gz")
preds = pd.read_csv(csv, dtype={"trainval": bool, "idx": int})
preds["timestamp"] = to_pd_datetime(preds["timestamp"])
return [preds[preds["idx"] == i].set_index("timestamp") for i in sorted(preds["idx"].unique())]
Expand All @@ -460,7 +468,6 @@ def evaluate_predictions(
for i, (true, md) in enumerate(tqdm(dataset)):
# Get time series for the train & test splits of the ground truth
idx = ~md.trainval if tune_on_test else md.trainval
train_vals = df_to_merlion(true[idx], md[idx], transform=resampler)
true_train = df_to_merlion(true[idx], md[idx], get_ground_truth=True)
true_test = df_to_merlion(true[~md.trainval], md[~md.trainval], get_ground_truth=True)

Expand Down Expand Up @@ -501,9 +508,7 @@ def evaluate_predictions(
m.threshold = m.threshold.to_simple_threshold()
if tune_on_test and not unsupervised:
m.calibrator.train(TimeSeries.from_pd(og_pred["y"][og_pred["trainval"]]))
m.train_post_process(
train_vals, train_result=train, anomaly_labels=true_train, post_rule_train_config=prtc
)
m.train_post_process(train_result=train, anomaly_labels=true_train, post_rule_train_config=prtc)
models.append(m)

# Get the lead & lag time for the dataset
Expand Down Expand Up @@ -538,7 +543,6 @@ def evaluate_predictions(
model.threshold = model.threshold.to_simple_threshold()
model.threshold.alm_threshold = threshold
model.train_post_process(
train_vals,
train_result=pred_train,
anomaly_labels=true_train,
post_rule_train_config=ensemble_threshold_train_config,
Expand Down Expand Up @@ -653,7 +657,7 @@ def main():
logging.basicConfig(
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", stream=sys.stdout, level=level
)
dataset = get_dataset(args.dataset)
dataset = get_dataset(args.dataset, rootdir=args.data_root, **args.data_kwargs)
retrain_freq, train_window = args.retrain_freq, args.train_window
univariate = dataset[0][0].shape[1] == 1
if retrain_freq == "default":
Expand Down Expand Up @@ -697,10 +701,11 @@ def main():
)

model_name = "+".join(sorted(resolve_model_name(m) for m in args.models))
summary = os.path.join("results", "anomaly", f"{dataset_to_name(dataset)}_summary.csv")
summary = os.path.join("results", "anomaly", f"{get_dataset_name(dataset)}_summary.csv")
if os.path.exists(summary):
df = pd.read_csv(summary, index_col=0)
else:
os.makedirs(os.path.dirname(summary), exist_ok=True)
df = pd.DataFrame()
if retrain_freq:
model_name += f"_{retrain_freq}"
Expand Down
27 changes: 18 additions & 9 deletions benchmark_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def parse_args():
"in ts_datasets/ts_datasets/forecast/__init__.py for "
"valid options.",
)
parser.add_argument("--data_root", default=None, help="Root directory/file of dataset.")
parser.add_argument("--data_kwargs", default="{}", help="JSON of keyword arguments for the data loader.")
parser.add_argument(
"--models",
type=str,
Expand Down Expand Up @@ -122,6 +124,8 @@ def parse_args():
)

args = parser.parse_args()
args.data_kwargs = json.loads(args.data_kwargs)
assert isinstance(args.data_kwargs, dict)

# If not summarizing all results, we need at least one model to evaluate
if args.summarize and args.models is None:
Expand All @@ -139,6 +143,9 @@ def get_dataset_name(dataset: BaseDataset) -> str:
name = type(dataset).__name__
if hasattr(dataset, "subset") and dataset.subset is not None:
name += "_" + dataset.subset
if isinstance(dataset, CustomDataset):
root = dataset.rootdir
name = os.path.join(name, os.path.basename(os.path.dirname(root) if os.path.isfile(root) else root))
return name


Expand Down Expand Up @@ -238,6 +245,7 @@ def train_model(
i0 = pd.read_csv(csv).idx.max()
else:
i0 = -1
os.makedirs(os.path.dirname(csv), exist_ok=True)
with open(csv, "w") as f:
f.write("idx,name,horizon,retrain_type,n_retrain,RMSE,sMAPE\n")

Expand Down Expand Up @@ -422,7 +430,6 @@ def join_dfs(name2df: Dict[str, pd.DataFrame]) -> pd.DataFrame:
def summarize_full_df(full_df: pd.DataFrame) -> pd.DataFrame:
# Get the names of all algorithms which have full results
algs = [col[len("sMAPE") :] for col in full_df.columns if col.startswith("sMAPE") and not full_df[col].isna().any()]

summary_df = pd.DataFrame({alg.lstrip("_"): [] for alg in algs})

# Compute pooled (per time series) mean/median sMAPE, RMSE
Expand Down Expand Up @@ -454,7 +461,7 @@ def main():
stream=sys.stdout,
level=logging.DEBUG if args.debug else logging.INFO,
)
dataset = get_dataset(args.dataset)
dataset = get_dataset(args.dataset, rootdir=args.data_root, **args.data_kwargs)
dataset_name = get_dataset_name(dataset)

if len(args.models) > 0:
Expand Down Expand Up @@ -507,25 +514,27 @@ def main():
f"before trying to summarize their results."
)
for csv in sorted(csvs):
model_name = os.path.basename(os.path.dirname(csv))
suffix = re.search(f"(?<={dataset_name}).*(?=\\.csv)", os.path.basename(csv)).group(0)
basename = re.search(f"{dataset_name}.*\\.csv", csv).group(0)
model_name = os.path.basename(os.path.dirname(csv[: -len(basename)]))
suffix = re.search(f"(?<={dataset_name}).*(?=\\.csv)", basename).group(0)
try:
name2df[model_name + suffix] = pd.read_csv(csv)
except Exception as e:
logger.warning(f'Caught {type(e).__name__}: "{e}". Skipping csv file {csv}.')
continue

# Join all the dataframes into one
# Join all the dataframes into one & summarize the results
dirname = os.path.join(MERLION_ROOT, "results", "forecast")
full_df = join_dfs(name2df)
full_df.to_csv(os.path.join(dirname, f"{dataset_name}_full.csv"), index=False)

# Summarize the joined dataframe
summary_df = summarize_full_df(full_df)
summary_df.to_csv(os.path.join(dirname, f"{dataset_name}_summary.csv"), index=True)
if args.summarize:
print(summary_df)

full_fname, summary_fname = [os.path.join(dirname, f"{dataset_name}_{x}.csv") for x in ["full", "summary"]]
os.makedirs(os.path.dirname(full_fname), exist_ok=True)
full_df.to_csv(full_fname, index=False)
summary_df.to_csv(summary_fname, index=True)


if __name__ == "__main__":
main()
26 changes: 23 additions & 3 deletions examples/CustomDataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@
"source": [
"from ts_datasets.anomaly import CustomAnomalyDataset\n",
"dataset = CustomAnomalyDataset(\n",
" root=anom_dir, # where the data is stored\n",
" rootdir=anom_dir, # where the data is stored\n",
" test_frac=0.75, # use 75% of each time series for testing. \n",
" # overridden if the column `trainval` is in the actual CSV.\n",
" time_unit=\"s\", # the timestamp column (automatically detected) is in units of seconds\n",
Expand Down Expand Up @@ -956,7 +956,7 @@
"source": [
"from ts_datasets.forecast import CustomDataset\n",
"dataset = CustomDataset(\n",
" root=csv, # where the data is stored\n",
" rootdir=csv, # where the data is stored\n",
" index_cols=[\"Store\", \"Dept\"], # Individual time series are indexed by store & department\n",
" test_frac=0.75, # use 25% of each time series for testing. \n",
" # overridden if the column `trainval` is in the actual CSV.\n",
Expand Down Expand Up @@ -1409,7 +1409,27 @@
"source": [
"## Broader Takeaways\n",
"\n",
"In general, a dataset can contain any number of CSVs stored under a single root directory. Each CSV can contain one or more time series, where the different time series within a single file are indicated by different values of the index column. Note that this works for anomaly detection as well! You just need to make sure that your CSVs all contain the `anomaly` column. In general, all features supported by `CustomDataset` are also supported by `CustomAnomalyDataset`, as long as your CSV files have the `anomaly` column."
"In general, a dataset can contain any number of CSVs stored under a single root directory. Each CSV can contain one or more time series, where the different time series within a single file are indicated by different values of the index column. Note that this works for anomaly detection as well! You just need to make sure that your CSVs all contain the `anomaly` column. In general, all features supported by `CustomDataset` are also supported by `CustomAnomalyDataset`, as long as your CSV files have the `anomaly` column.\n",
"\n",
"If you want to either of the above custom datasets for benchmarking, you can call\n",
"\n",
"```\n",
"python benchmark_anomaly.py --model IsolationForest --retrain_freq 7d \\\n",
" --dataset CustomAnomalyDataset --data_root data/synthetic_anomaly \\\n",
" --data_kwargs '{\"assume_no_anomaly\": true, \"test_frac\": 0.75}'\n",
"```\n",
"\n",
"or \n",
"\n",
"```\n",
"python benchmark_forecast.py --model AutoETS \\\n",
" --dataset CustomDataset --data_root data/walmart/walmart_mini.csv \\\n",
" --data_kwargs '{\"test_frac\": 0.25, \\\n",
" \"index_cols\": [\"Store\", \"Dept\"], \\\n",
" \"data_cols\": [\"Weekly_Sales\"]}'\n",
"```\n",
"\n",
"Note in the example above, we specify \"data_cols\" as \"Weekly_Sales\". This indicates that we want"
]
}
],
Expand Down
6 changes: 3 additions & 3 deletions merlion/evaluate/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def __init__(
:param lb (optional): lower bound of 95% prediction interval. This value is used for computing MSIS
:param target_seq_index (optional): the index of the target sequence, for multivariate.
"""
ground_truth = ground_truth.to_ts() if isinstance(ground_truth, UnivariateTimeSeries) else ground_truth
predict = predict.to_ts() if isinstance(predict, UnivariateTimeSeries) else predict
insample = insample.to_ts() if isinstance(insample, UnivariateTimeSeries) else insample
ground_truth = TimeSeries.from_pd(ground_truth)
predict = TimeSeries.from_pd(predict)
insample = TimeSeries.from_pd(insample)
t0, tf = predict.t0, predict.tf
ground_truth = ground_truth.window(t0, tf, include_tf=True).align()
if target_seq_index is not None:
Expand Down
7 changes: 1 addition & 6 deletions merlion/models/anomaly/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,16 +199,11 @@ def train(
)

def train_post_process(
self,
train_data: TimeSeries,
train_result: Union[TimeSeries, pd.DataFrame],
anomaly_labels=None,
post_rule_train_config=None,
self, train_result: Union[TimeSeries, pd.DataFrame], anomaly_labels=None, post_rule_train_config=None
) -> TimeSeries:
"""
Converts the train result (anom scores on train data) into a TimeSeries object and trains the post-rule.
:param train_data: a `TimeSeries` of metric values to train the model.
:param train_result: Raw anomaly scores on the training data.
:param anomaly_labels: a `TimeSeries` indicating which timestamps are anomalous. Optional.
:param post_rule_train_config: The config to use for training the model's post-rule. The model's default
Expand Down
7 changes: 0 additions & 7 deletions merlion/models/anomaly/change_point/bocpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,6 @@ def __init__(
self.lag = lag
super().__init__(max_forecast_steps=max_forecast_steps, **kwargs)

def to_dict(self, _skipped_keys=None):
_skipped_keys = _skipped_keys if _skipped_keys is not None else set()
config_dict = super().to_dict(_skipped_keys.union({"change_kind"}))
if "change_kind" not in _skipped_keys:
config_dict["change_kind"] = self.change_kind.name
return config_dict

@property
def change_kind(self) -> ChangeKind:
return self._change_kind
Expand Down
11 changes: 5 additions & 6 deletions merlion/models/anomaly/forecast_based/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Base class for anomaly detectors based on forecasting models.
"""
import logging
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -65,17 +65,16 @@ def forecast_to_anom_score(

def train_post_process(
self,
train_data: TimeSeries,
train_result: Tuple[pd.DataFrame, pd.DataFrame],
train_result: Tuple[Union[TimeSeries, pd.DataFrame], Optional[Union[TimeSeries, pd.DataFrame]]],
anomaly_labels=None,
post_rule_train_config=None,
) -> TimeSeries:
if isinstance(train_result, tuple) and len(train_result) == 2:
train_pred, train_err = ForecasterBase.train_post_process(self, train_data, train_result)
train_data = train_data if self.invert_transform else self.transform(train_data)
train_pred, train_err = ForecasterBase.train_post_process(self, train_result)
train_data = self.train_data if self.invert_transform else self.transform(self.train_data)
train_result = self.forecast_to_anom_score(train_data, train_pred, train_err)
return DetectorBase.train_post_process(
self, train_data, train_result, anomaly_labels=anomaly_labels, post_rule_train_config=post_rule_train_config
self, train_result, anomaly_labels=anomaly_labels, post_rule_train_config=post_rule_train_config
)

def train(
Expand Down
4 changes: 1 addition & 3 deletions merlion/models/anomaly/forecast_based/prophet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2021 salesforce.com, inc.
# Copyright (c) 2022 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
Expand All @@ -12,11 +12,9 @@
from merlion.models.anomaly.base import DetectorConfig
from merlion.models.forecast.prophet import ProphetConfig, Prophet
from merlion.post_process.threshold import AggregateAlarms
from merlion.transform.moving_average import DifferenceTransform


class ProphetDetectorConfig(ProphetConfig, DetectorConfig):
_default_transform = DifferenceTransform()
_default_threshold = AggregateAlarms(alm_threshold=3)


Expand Down
Loading

0 comments on commit eae7bd4

Please sign in to comment.