@@ -7,6 +7,9 @@
 """
 Deep autoencoding Gaussian mixture model for anomaly detection (DAGMM)
 """
+import random
+from typing import Dict, Any, List
+
 try:
     import torch
     import torch.nn as nn
@@ -24,6 +27,6 @@
 from merlion.utils import UnivariateTimeSeries, TimeSeries
 from merlion.models.base import NormalizingConfig
-from merlion.models.anomaly.base import DetectorBase, DetectorConfig
+from merlion.models.anomaly.base import DetectorBase, DetectorConfig, MultipleTimeseriesDetectorMixin
 from merlion.post_process.threshold import AggregateAlarms
 from merlion.utils.misc import ProgressBar, initializer
 from merlion.models.anomaly.utils import InputData, batch_detect
@@ -63,7 +66,7 @@ def __init__(
         super().__init__(**kwargs)
 
 
-class DAGMM(DetectorBase):
+class DAGMM(DetectorBase, MultipleTimeseriesDetectorMixin):
     """
     Deep autoencoding Gaussian mixture model for anomaly detection (DAGMM).
     DAGMM combines an autoencoder with a Gaussian mixture model to model the distribution
@@ -129,12 +132,13 @@ def _train(self, X):
         data_loader = DataLoader(
             dataset=dataset, batch_size=self.batch_size, shuffle=True, collate_fn=InputData.collate_func
         )
-        self.dagmm = self._build_model(X.shape[1]).to(self.device)
-        self.optimizer = torch.optim.Adam(self.dagmm.parameters(), lr=self.lr)
+        if self.dagmm is None and self.optimizer is None:
+            self.dagmm = self._build_model(X.shape[1]).to(self.device)
+            self.optimizer = torch.optim.Adam(self.dagmm.parameters(), lr=self.lr)
+        self.dagmm.train()
         self.data_dim = X.shape[1]
         bar = ProgressBar(total=self.num_epochs)
 
-        self.dagmm.train()
         for epoch in range(self.num_epochs):
             total_loss, recon_error = 0, 0
             for input_data in data_loader:
@@ -187,7 +191,6 @@ def train(
         :param post_rule_train_config: The config to use for training the
             model's post-rule. The model's default post-rule train config is
             used if none is supplied here.
-
         :return: A `TimeSeries` of the model's anomaly scores on the training
             data.
         """
@@ -203,6 +206,54 @@ def train(
         )
         return train_scores
 
+    def train_multiple(
+        self, multiple_train_data: List[TimeSeries], anomaly_labels: List[TimeSeries] = None,
+        train_config=None, post_rule_train_config=None
+    ) -> List[TimeSeries]:
+        """
+        Trains the anomaly detector (unsupervised) and its post-rule
+        (supervised, if labels are given) on multiple input time series.
+
+        :param multiple_train_data: a list of `TimeSeries` of metric values to train the model.
+        :param anomaly_labels: a list of `TimeSeries` indicating which timestamps are
+            anomalous. Optional.
+        :param train_config: Additional training config dict with keys:
+
+            * | "n_epochs": ``int`` indicating how many times the model should be
+              | trained on the time series in ``multiple_train_data``. Defaults to 1.
+            * | "shuffle": ``bool`` indicating whether ``multiple_train_data`` should be
+              | shuffled before every epoch. Defaults to True if "n_epochs" > 1.
+        :param post_rule_train_config: The config to use for training the
+            model's post-rule. The model's default post-rule train config is
+            used if none is supplied here.
+
+        :return: A list of `TimeSeries` of the model's anomaly scores on the training
+            data, where each element corresponds to a time series from ``multiple_train_data``.
+        """
+        if train_config is None:
+            train_config = dict()
+        n_epochs = train_config.pop("n_epochs", 1)
+        shuffle = train_config.pop("shuffle", n_epochs > 1)
+
+        if anomaly_labels is not None:
+            assert len(multiple_train_data) == len(anomaly_labels)
+        else:
+            anomaly_labels = [None] * len(multiple_train_data)
+        train_scores_list = []
+        for _ in range(n_epochs):
+            if shuffle:
+                random.shuffle(multiple_train_data)
+            for train_data, anomaly_series in zip(multiple_train_data, anomaly_labels):
+                train_scores_list.append(
+                    self.train(
+                        train_data=train_data, anomaly_labels=anomaly_series,
+                        train_config=train_config, post_rule_train_config=post_rule_train_config
+                        # FIXME: the post-rule (calibrator and threshold) is trained individually on each
+                        # time series, but ideally it should be re-trained on all of `train_scores_list`
+                    )
+                )
+        return train_scores_list
+
     def get_anomaly_score(self, time_series: TimeSeries, time_series_prev: TimeSeries = None) -> TimeSeries:
         """
         :param time_series: The `TimeSeries` we wish to predict anomaly scores for.
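
As a quick orientation, the sketch below shows how the new `train_multiple` API is intended to be called. It is an illustrative assumption built on Merlion's existing public interfaces (`TimeSeries.from_pd`, `DAGMM`, `DAGMMConfig`), not code that ships with this change:

import numpy as np
import pandas as pd

from merlion.utils import TimeSeries
from merlion.models.anomaly.dagmm import DAGMM, DAGMMConfig

# Three toy hourly metrics standing in for a collection of related time series.
rng = np.random.default_rng(0)
series_list = [
    TimeSeries.from_pd(
        pd.Series(
            rng.normal(size=200),
            index=pd.date_range("2021-01-01", periods=200, freq="h"),
            name="metric",
        )
    )
    for _ in range(3)
]

model = DAGMM(DAGMMConfig())

# Two passes over the collection, shuffling the order of the series before each pass.
train_scores = model.train_multiple(series_list, train_config={"n_epochs": 2, "shuffle": True})

# One score TimeSeries is appended per (pass, series) visit, so with "n_epochs": 2
# the returned list holds 2 * len(series_list) elements.
print(len(train_scores))  # 6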
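
The class docstring above summarizes DAGMM as an autoencoder combined with a Gaussian mixture model. For readers unfamiliar with that scoring idea, here is a rough, self-contained sketch of the GMM "sample energy" that DAGMM-style detectors use as the anomaly score; it is illustrative only and is not the estimation-network or energy code in this repository:

import torch

def gmm_sample_energy(z: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
    """z: (N, D) latent features; gamma: (N, K) soft mixture memberships from an estimation net."""
    n, d = z.shape
    sum_gamma = gamma.sum(dim=0)                      # (K,) effective count per component
    phi = sum_gamma / n                               # mixture weights
    mu = (gamma.unsqueeze(-1) * z.unsqueeze(1)).sum(dim=0) / sum_gamma.unsqueeze(-1)  # (K, D) means
    diff = z.unsqueeze(1) - mu.unsqueeze(0)           # (N, K, D)
    cov = torch.einsum("nk,nki,nkj->kij", gamma, diff, diff) / sum_gamma.view(-1, 1, 1)
    cov = cov + 1e-6 * torch.eye(d)                   # keep the covariances invertible
    comp = torch.distributions.MultivariateNormal(mu, covariance_matrix=cov)
    log_prob = comp.log_prob(z.unsqueeze(1))          # (N, K): log N(z_i | mu_k, cov_k)
    # Energy E(z_i) = -log sum_k phi_k N(z_i | mu_k, cov_k); higher energy means more anomalous.
    return -torch.logsumexp(torch.log(phi) + log_prob, dim=1)

# Toy check with random latents and softmax memberships.
torch.manual_seed(0)
energy = gmm_sample_energy(torch.randn(256, 4), torch.softmax(torch.randn(256, 3), dim=1))
print(energy.shape)  # torch.Size([256])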