Skip to content

Commit 37fb75c

Browse files
isenilov and aadyotb authored
Add incremental training option to DAGMM model (#65)
* Add incremental training option to DAGMM model
* Fix type hints
* Initial implementation of `train_multiple` method using mixins
* Fix docstring
* Address PR comments, fix docs
* Default empty dict for `train_config`
* Add basic unit test for `train_multiple` method
* Revert `train_config` default value

Co-authored-by: Aadyot Bhatnagar <[email protected]>
1 parent 7af892c commit 37fb75c

File tree

4 files changed

+107
-9
lines changed

4 files changed

+107
-9
lines changed

merlion/models/anomaly/base.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
from copy import copy, deepcopy
1212
import inspect
1313
import logging
14-
from typing import Any, Dict
14+
from typing import Any, Dict, List
1515

1616
from scipy.stats import norm
1717

18-
from merlion.models.base import Config, ModelBase
18+
from merlion.models.base import Config, ModelBase, MultipleTimeseriesModelMixin
1919
from merlion.plot import Figure, MTSFigure
2020
from merlion.post_process.calibrate import AnomScoreCalibrator
2121
from merlion.post_process.factory import PostRuleFactory
@@ -352,3 +352,31 @@ def plot_anomaly_plotly(
352352
scores = f(time_series, time_series_prev=time_series_prev)
353353
fig = MTSFigure(y=time_series, y_prev=time_series_prev, anom=scores)
354354
return fig.plot_plotly(title=title, figsize=figsize)
355+
356+
357+
class MultipleTimeseriesDetectorMixin(MultipleTimeseriesModelMixin):
    """
    Abstract mixin for anomaly detectors which can be trained on a collection
    of time series rather than a single one.
    """

    @abstractmethod
    def train_multiple(
        self,
        multiple_train_data: List[TimeSeries],
        anomaly_labels: List[TimeSeries] = None,
        train_config=None,
        post_rule_train_config=None,
    ) -> List[TimeSeries]:
        """
        Trains the anomaly detector (unsupervised) and its post-rule
        (supervised, if labels are given) on several time series.

        NOTE(review): ``anomaly_labels`` comes before ``train_config`` in this
        signature, while ``MultipleTimeseriesModelMixin.train_multiple`` takes
        ``train_config`` as its second parameter. Passing ``train_config`` by
        keyword avoids any ambiguity between the two.

        :param multiple_train_data: a list of `TimeSeries` of metric values
            used to train the model.
        :param anomaly_labels: a list of `TimeSeries` indicating which
            timestamps are anomalous. Optional.
        :param train_config: Additional training configs, if needed. Only
            required for some models.
        :param post_rule_train_config: The config to use for training the
            model's post-rule. The model's default post-rule train config is
            used if none is supplied here.

        :return: A list of `TimeSeries` of the model's anomaly scores on the
            training data, each element corresponding to a time series from
            ``multiple_train_data``.
        """
        raise NotImplementedError

merlion/models/anomaly/dagmm.py

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
"""
88
Deep autoencoding Gaussian mixture model for anomaly detection (DAGMM)
99
"""
10+
import random
11+
from typing import Dict, Any, List
12+
1013
try:
1114
import torch
1215
import torch.nn as nn
@@ -23,7 +26,7 @@
2326

2427
from merlion.utils import UnivariateTimeSeries, TimeSeries
2528
from merlion.models.base import NormalizingConfig
26-
from merlion.models.anomaly.base import DetectorBase, DetectorConfig
29+
from merlion.models.anomaly.base import DetectorBase, DetectorConfig, MultipleTimeseriesDetectorMixin
2730
from merlion.post_process.threshold import AggregateAlarms
2831
from merlion.utils.misc import ProgressBar, initializer
2932
from merlion.models.anomaly.utils import InputData, batch_detect
@@ -63,7 +66,7 @@ def __init__(
6366
super().__init__(**kwargs)
6467

6568

66-
class DAGMM(DetectorBase):
69+
class DAGMM(DetectorBase, MultipleTimeseriesDetectorMixin):
6770
"""
6871
Deep autoencoding Gaussian mixture model for anomaly detection (DAGMM).
6972
DAGMM combines an autoencoder with a Gaussian mixture model to model the distribution
@@ -129,12 +132,13 @@ def _train(self, X):
129132
data_loader = DataLoader(
130133
dataset=dataset, batch_size=self.batch_size, shuffle=True, collate_fn=InputData.collate_func
131134
)
132-
self.dagmm = self._build_model(X.shape[1]).to(self.device)
133-
self.optimizer = torch.optim.Adam(self.dagmm.parameters(), lr=self.lr)
135+
if self.dagmm is None and self.optimizer is None:
136+
self.dagmm = self._build_model(X.shape[1]).to(self.device)
137+
self.optimizer = torch.optim.Adam(self.dagmm.parameters(), lr=self.lr)
138+
self.dagmm.train()
134139
self.data_dim = X.shape[1]
135140
bar = ProgressBar(total=self.num_epochs)
136141

137-
self.dagmm.train()
138142
for epoch in range(self.num_epochs):
139143
total_loss, recon_error = 0, 0
140144
for input_data in data_loader:
@@ -187,7 +191,6 @@ def train(
187191
:param post_rule_train_config: The config to use for training the
188192
model's post-rule. The model's default post-rule train config is
189193
used if none is supplied here.
190-
191194
:return: A `TimeSeries` of the model's anomaly scores on the training
192195
data.
193196
"""
@@ -203,6 +206,54 @@ def train(
203206
)
204207
return train_scores
205208

209+
def train_multiple(
210+
self, multiple_train_data: List[TimeSeries], anomaly_labels: List[TimeSeries] = None,
211+
train_config=None, post_rule_train_config=None
212+
) -> List[TimeSeries]:
213+
"""
214+
Trains the anomaly detector (unsupervised) and its post-rule
215+
(supervised, if labels are given) on the input multiple time series.
216+
217+
:param multiple_train_data: a list of `TimeSeries` of metric values to train the model.
218+
:param anomaly_labels: a list of `TimeSeries` indicating which timestamps are
219+
anomalous. Optional.
220+
:param train_config: Additional training config dict with keys:
221+
222+
* | "n_epochs": ``int`` indicating how many times the model must be
223+
| trained on the timeseries in ``multiple_train_data``. Defaults to 1.
224+
* | "shuffle": ``bool`` indicating if the ``multiple_train_data`` collection
225+
| should be shuffled before every epoch. Defaults to True if "n_epochs" > 1.
226+
:param post_rule_train_config: The config to use for training the
227+
model's post-rule. The model's default post-rule train config is
228+
used if none is supplied here.
229+
230+
:return: A list of `TimeSeries` of the model's anomaly scores on the training
231+
data with each element corresponds to time series from ``multiple_train_data``.
232+
"""
233+
if train_config is None:
234+
train_config = dict()
235+
n_epochs = train_config.pop("n_epochs", 1)
236+
shuffle = train_config.pop("shuffle", n_epochs > 1)
237+
238+
if anomaly_labels is not None:
239+
assert len(multiple_train_data) == len(anomaly_labels)
240+
else:
241+
anomaly_labels = [None] * len(multiple_train_data)
242+
train_scores_list = []
243+
for _ in range(n_epochs):
244+
if shuffle:
245+
random.shuffle(multiple_train_data)
246+
for train_data, anomaly_series in zip(multiple_train_data, anomaly_labels):
247+
train_scores_list.append(
248+
self.train(
249+
train_data=train_data, anomaly_labels=anomaly_series,
250+
train_config=train_config, post_rule_train_config=post_rule_train_config
251+
# FIXME: the post-rule (calibrator and threshold) is trained individually on each time series
252+
# but ideally it needs to be re-trained on all of the `train_scores_list`
253+
)
254+
)
255+
return train_scores_list
256+
206257
def get_anomaly_score(self, time_series: TimeSeries, time_series_prev: TimeSeries = None) -> TimeSeries:
207258
"""
208259
:param time_series: The `TimeSeries` we wish to predict anomaly scores for.

merlion/models/base.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import logging
1414
import os
1515
from os.path import abspath, join
16-
from typing import Any, Dict, Optional, Tuple
16+
from typing import Any, Dict, Optional, Tuple, List
1717

1818
import dill
1919
import pandas as pd
@@ -442,3 +442,19 @@ def __deepcopy__(self, memodict={}):
442442
state_dict.pop("config", None)
443443
new_model.__setstate__(state_dict)
444444
return new_model
445+
446+
447+
class MultipleTimeseriesModelMixin(metaclass=AutodocABCMeta):
    """
    Abstract mixin for models which can be trained on a collection of time
    series rather than a single one.
    """

    @abstractmethod
    def train_multiple(
        self,
        multiple_train_data: List[TimeSeries],
        train_config=None,
    ):
        """
        Trains the model on multiple time series. Implementation-specific
        options may be supplied through ``train_config``.

        :param multiple_train_data: a list of `TimeSeries` to use as a
            training set
        :param train_config: additional configurations (if needed)
        """
        raise NotImplementedError

tests/anomaly/multivariate/test_dagmm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ def __init__(self, *args, **kwargs):
4747
train_ts = TimeSeries.from_pd(self.train_df)
4848
self.model.train(train_ts)
4949

50+
logger.info("Training multiple timeseries model...\n")
51+
self.model.train_multiple([train_ts] * 10)
52+
5053
def test_score(self):
5154
print("-" * 80)
5255
logger.info("test_score\n" + "-" * 80 + "\n")

0 commit comments

Comments
 (0)