Skip to content

Commit

Permalink
Add incremental training option to DAGMM model (#65)
Browse files Browse the repository at this point in the history
* Add incremental training option to DAGMM model

* Fix type hints

* Initial implementation of `train_multiple` method using mixins

* Fix docstring

* Address PR comments, fix docs

* Default empty dict for `train_config`

* Add basic unit test for `train_multiple` method

* Revert `train_config` default value

Co-authored-by: Aadyot Bhatnagar <[email protected]>
  • Loading branch information
isenilov and aadyotb authored Feb 4, 2022
1 parent 7af892c commit 37fb75c
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 9 deletions.
32 changes: 30 additions & 2 deletions merlion/models/anomaly/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from copy import copy, deepcopy
import inspect
import logging
from typing import Any, Dict
from typing import Any, Dict, List

from scipy.stats import norm

from merlion.models.base import Config, ModelBase
from merlion.models.base import Config, ModelBase, MultipleTimeseriesModelMixin
from merlion.plot import Figure, MTSFigure
from merlion.post_process.calibrate import AnomScoreCalibrator
from merlion.post_process.factory import PostRuleFactory
Expand Down Expand Up @@ -352,3 +352,31 @@ def plot_anomaly_plotly(
scores = f(time_series, time_series_prev=time_series_prev)
fig = MTSFigure(y=time_series, y_prev=time_series_prev, anom=scores)
return fig.plot_plotly(title=title, figsize=figsize)


class MultipleTimeseriesDetectorMixin(MultipleTimeseriesModelMixin):
    """
    Abstract mixin for anomaly detectors which can be trained on a collection
    of time series instead of a single one.
    """

    @abstractmethod
    def train_multiple(
        self,
        multiple_train_data: List[TimeSeries],
        anomaly_labels: List[TimeSeries] = None,
        train_config=None,
        post_rule_train_config=None,
    ) -> List[TimeSeries]:
        """
        Trains the anomaly detector (unsupervised) and its post-rule
        (supervised, if labels are given) on multiple input time series.

        :param multiple_train_data: a list of `TimeSeries` of metric values
            used to train the model.
        :param anomaly_labels: a list of `TimeSeries` indicating which
            timestamps are anomalous. Optional.
        :param train_config: Additional training configs, if needed. Only
            required for some models.
        :param post_rule_train_config: The config to use for training the
            model's post-rule. The model's default post-rule train config is
            used if none is supplied here.

        :return: A list of `TimeSeries` of the model's anomaly scores on the
            training data, where each element corresponds to a time series
            from ``multiple_train_data``.
        """
        raise NotImplementedError
63 changes: 57 additions & 6 deletions merlion/models/anomaly/dagmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
"""
Deep autoencoding Gaussian mixture model for anomaly detection (DAGMM)
"""
import random
from typing import Dict, Any, List

try:
import torch
import torch.nn as nn
Expand All @@ -23,7 +26,7 @@

from merlion.utils import UnivariateTimeSeries, TimeSeries
from merlion.models.base import NormalizingConfig
from merlion.models.anomaly.base import DetectorBase, DetectorConfig
from merlion.models.anomaly.base import DetectorBase, DetectorConfig, MultipleTimeseriesDetectorMixin
from merlion.post_process.threshold import AggregateAlarms
from merlion.utils.misc import ProgressBar, initializer
from merlion.models.anomaly.utils import InputData, batch_detect
Expand Down Expand Up @@ -63,7 +66,7 @@ def __init__(
super().__init__(**kwargs)


class DAGMM(DetectorBase):
class DAGMM(DetectorBase, MultipleTimeseriesDetectorMixin):
"""
Deep autoencoding Gaussian mixture model for anomaly detection (DAGMM).
DAGMM combines an autoencoder with a Gaussian mixture model to model the distribution
Expand Down Expand Up @@ -129,12 +132,13 @@ def _train(self, X):
data_loader = DataLoader(
dataset=dataset, batch_size=self.batch_size, shuffle=True, collate_fn=InputData.collate_func
)
self.dagmm = self._build_model(X.shape[1]).to(self.device)
self.optimizer = torch.optim.Adam(self.dagmm.parameters(), lr=self.lr)
if self.dagmm is None and self.optimizer is None:
self.dagmm = self._build_model(X.shape[1]).to(self.device)
self.optimizer = torch.optim.Adam(self.dagmm.parameters(), lr=self.lr)
self.dagmm.train()
self.data_dim = X.shape[1]
bar = ProgressBar(total=self.num_epochs)

self.dagmm.train()
for epoch in range(self.num_epochs):
total_loss, recon_error = 0, 0
for input_data in data_loader:
Expand Down Expand Up @@ -187,7 +191,6 @@ def train(
:param post_rule_train_config: The config to use for training the
model's post-rule. The model's default post-rule train config is
used if none is supplied here.
:return: A `TimeSeries` of the model's anomaly scores on the training
data.
"""
Expand All @@ -203,6 +206,54 @@ def train(
)
return train_scores

def train_multiple(
    self,
    multiple_train_data: List[TimeSeries],
    anomaly_labels: List[TimeSeries] = None,
    train_config=None,
    post_rule_train_config=None,
) -> List[TimeSeries]:
    """
    Trains the anomaly detector (unsupervised) and its post-rule
    (supervised, if labels are given) on the input multiple time series.

    Neither ``multiple_train_data`` nor ``train_config`` is modified: the
    visiting order is shuffled through an index permutation, and the config
    is copied before the keys below are removed from it.

    :param multiple_train_data: a list of `TimeSeries` of metric values to
        train the model.
    :param anomaly_labels: a list of `TimeSeries` indicating which
        timestamps are anomalous. Optional. If given, it must have the same
        length as ``multiple_train_data``; element ``i`` labels series ``i``.
    :param train_config: Additional training config dict with keys:

        * | "n_epochs": ``int`` indicating how many times the model must be
          | trained on the timeseries in ``multiple_train_data``. Defaults to 1.
        * | "shuffle": ``bool`` indicating if the ``multiple_train_data`` collection
          | should be shuffled before every epoch. Defaults to True if "n_epochs" > 1.

        Any remaining keys are forwarded to ``train`` unchanged.
    :param post_rule_train_config: The config to use for training the
        model's post-rule. The model's default post-rule train config is
        used if none is supplied here.

    :return: A list of `TimeSeries` of the model's anomaly scores on the
        training data. Element ``i`` holds the scores produced for
        ``multiple_train_data[i]`` during the last epoch, regardless of the
        order in which the series were visited.
    :raises ValueError: if ``anomaly_labels`` is given but its length differs
        from that of ``multiple_train_data``.
    """
    # Work on a copy so the caller's dict keeps its "n_epochs"/"shuffle" keys.
    train_config = {} if train_config is None else dict(train_config)
    n_epochs = train_config.pop("n_epochs", 1)
    shuffle = train_config.pop("shuffle", n_epochs > 1)

    n_series = len(multiple_train_data)
    if anomaly_labels is None:
        anomaly_labels = [None] * n_series
    elif len(anomaly_labels) != n_series:
        # Explicit check instead of `assert`, which is stripped under -O and
        # would let `zip` silently drop the unmatched series.
        raise ValueError(
            f"Expected one anomaly label series per training series, but got "
            f"{len(anomaly_labels)} label series for {n_series} training series."
        )

    # Shuffle indices rather than the data itself, so every series stays
    # paired with its own labels and the caller's list is left untouched.
    visit_order = list(range(n_series))
    latest_scores = {}
    for _ in range(n_epochs):
        if shuffle:
            random.shuffle(visit_order)
        for idx in visit_order:
            # FIXME: the post-rule (calibrator and threshold) is trained individually on each time series
            #  but ideally it needs to be re-trained on all of the train scores
            latest_scores[idx] = self.train(
                train_data=multiple_train_data[idx],
                anomaly_labels=anomaly_labels[idx],
                train_config=train_config,
                post_rule_train_config=post_rule_train_config,
            )
    # Input order; empty if no epoch was run.
    return [latest_scores[idx] for idx in sorted(latest_scores)]

def get_anomaly_score(self, time_series: TimeSeries, time_series_prev: TimeSeries = None) -> TimeSeries:
"""
:param time_series: The `TimeSeries` we wish to predict anomaly scores for.
Expand Down
18 changes: 17 additions & 1 deletion merlion/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import logging
import os
from os.path import abspath, join
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, List

import dill
import pandas as pd
Expand Down Expand Up @@ -442,3 +442,19 @@ def __deepcopy__(self, memodict={}):
state_dict.pop("config", None)
new_model.__setstate__(state_dict)
return new_model


class MultipleTimeseriesModelMixin(metaclass=AutodocABCMeta):
    """
    Abstract mixin for models which can be trained on a collection of time
    series instead of a single one.
    """

    @abstractmethod
    def train_multiple(
        self,
        multiple_train_data: List[TimeSeries],
        train_config=None,
    ):
        """
        Trains the model on multiple time series, optionally with some
        additional implementation-specific config options ``train_config``.

        :param multiple_train_data: a list of `TimeSeries` to use as a
            training set
        :param train_config: additional configurations (if needed)
        """
        raise NotImplementedError
3 changes: 3 additions & 0 deletions tests/anomaly/multivariate/test_dagmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def __init__(self, *args, **kwargs):
train_ts = TimeSeries.from_pd(self.train_df)
self.model.train(train_ts)

logger.info("Training multiple timeseries model...\n")
self.model.train_multiple([train_ts] * 10)

def test_score(self):
print("-" * 80)
logger.info("test_score\n" + "-" * 80 + "\n")
Expand Down

0 comments on commit 37fb75c

Please sign in to comment.