diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d9337d0f6..af8092293 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/psf/black - rev: '19.3b0' + rev: '22.10.0' hooks: - id: black args: ["--line-length", "120"] diff --git a/docs/source/merlion.models.anomaly.forecast_based.rst b/docs/source/merlion.models.anomaly.forecast_based.rst index 1ea5c5dee..eda0d73db 100644 --- a/docs/source/merlion.models.anomaly.forecast_based.rst +++ b/docs/source/merlion.models.anomaly.forecast_based.rst @@ -12,7 +12,6 @@ anomaly.forecast\_based sarima ets prophet - lstm mses anomaly.forecast\_based.base @@ -55,14 +54,6 @@ anomaly.forecast\_based.prophet :undoc-members: :show-inheritance: -anomaly.forecast\_based.lstm ----------------------------- - -.. automodule:: merlion.models.anomaly.forecast_based.lstm - :members: - :undoc-members: - :show-inheritance: - anomaly.forecast\_based.mses ---------------------------- diff --git a/docs/source/merlion.models.forecast.rst b/docs/source/merlion.models.forecast.rst index f41aa4510..151aebf1d 100644 --- a/docs/source/merlion.models.forecast.rst +++ b/docs/source/merlion.models.forecast.rst @@ -10,6 +10,7 @@ Base classes: .. autosummary:: base + deep_base sklearn_base Univariate models: @@ -20,13 +21,18 @@ Univariate models: ets prophet smoother - lstm `Multivariate ` models: .. autosummary:: vector_ar trees + deep_ar + autoformer + etsformer + informer + transformer + `Exogenous regressor ` models: @@ -37,6 +43,16 @@ Univariate models: vector_ar arima +Deep Learning models: + +.. autosummary:: + deep_ar + autoformer + etsformer + informer + transformer + + Note that the AutoML variants :py:mod:`AutoSarima ` and :py:mod:`AutoProphet ` @@ -53,6 +69,13 @@ forecast.base :undoc-members: :show-inheritance: +forecast.deep\_base +^^^^^^^^^^^^^^^^^^^ +.. automodule:: merlion.models.forecast.deep_base + :members: + :undoc-members: + :show-inheritance: + forecast.sklearn\_base ^^^^^^^^^^^^^^^^^^^^^^ .. automodule:: merlion.models.forecast.sklearn_base @@ -99,13 +122,6 @@ forecast.smoother :undoc-members: :show-inheritance: -forecast.lstm -^^^^^^^^^^^^^ -.. automodule:: merlion.models.forecast.lstm - :members: - :undoc-members: - :show-inheritance: - Multivariate models ------------------- @@ -122,3 +138,43 @@ forecast.trees :members: :undoc-members: :show-inheritance: + +forecast.deep\_ar +^^^^^^^^^^^^^^^^^ +.. automodule:: merlion.models.forecast.deep_ar + :members: + :undoc-members: + :show-inheritance: + +forecast.autoformer +^^^^^^^^^^^^^^^^^^^ +.. automodule:: merlion.models.forecast.autoformer + :members: + :undoc-members: + :show-inheritance: + +forecast.etsformer +^^^^^^^^^^^^^^^^^^ +.. automodule:: merlion.models.forecast.etsformer + :members: + :undoc-members: + :show-inheritance: + +forecast.informer +^^^^^^^^^^^^^^^^^ +.. automodule:: merlion.models.forecast.informer + :members: + :undoc-members: + :show-inheritance: + +forecast.transformer +^^^^^^^^^^^^^^^^^^^^ +.. automodule:: merlion.models.forecast.transformer + :members: + :undoc-members: + :show-inheritance: + + + + + diff --git a/docs/source/merlion.models.rst b/docs/source/merlion.models.rst index d6dca10b5..429a5e12b 100644 --- a/docs/source/merlion.models.rst +++ b/docs/source/merlion.models.rst @@ -60,6 +60,7 @@ Finally, we support ensembles of models in :py:mod:`merlion.models.ensemble`. 
defaults factory base + deep_base layers anomaly anomaly.change_point @@ -111,6 +112,14 @@ base :undoc-members: :show-inheritance: +deep\_base +---------- + +.. automodule:: merlion.models.deep_base + :members: + :undoc-members: + :show-inheritance: + layers ------ @@ -118,3 +127,4 @@ layers :members: :undoc-members: :show-inheritance: + diff --git a/docs/source/merlion.models.utils.rst b/docs/source/merlion.models.utils.rst index da6dcfeee..f1f3f4838 100644 --- a/docs/source/merlion.models.utils.rst +++ b/docs/source/merlion.models.utils.rst @@ -7,10 +7,20 @@ utils :show-inheritance: .. autosummary:: + time_features rolling_window_dataset + early_stopping autosarima_utils +utils.time\_features +-------------------- +.. automodule:: merlion.models.utils.time_features + :members: + :undoc-members: + :show-inheritance: + + utils.rolling\_window\_dataset ------------------------------ @@ -20,11 +30,18 @@ utils.rolling\_window\_dataset :show-inheritance: +utils.early\_stopping +--------------------- +.. automodule:: merlion.models.utils.early_stopping + :members: + :undoc-members: + :show-inheritance: + + utils.autosarima\_utils ----------------------- .. automodule:: merlion.models.utils.autosarima_utils :members: :undoc-members: - :show-inheritance: - + :show-inheritance: \ No newline at end of file diff --git a/merlion/dashboard/models/anomaly.py b/merlion/dashboard/models/anomaly.py index 17aa853ac..5e7045719 100644 --- a/merlion/dashboard/models/anomaly.py +++ b/merlion/dashboard/models/anomaly.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2022 salesforce.com, inc. +# Copyright (c) 2023 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -25,7 +25,6 @@ class AnomalyModel(ModelMixin, DataMixin): "DynamicBaseline", "IsolationForest", "ETSDetector", - "LSTMDetector", "MSESDetector", "ProphetDetector", "RandomCutForest", diff --git a/merlion/dashboard/models/forecast.py b/merlion/dashboard/models/forecast.py index a48153ff5..68401803b 100755 --- a/merlion/dashboard/models/forecast.py +++ b/merlion/dashboard/models/forecast.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2022 salesforce.com, inc. +# Copyright (c) 2023 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -28,7 +28,6 @@ class ForecastModel(ModelMixin, DataMixin): "Prophet", "AutoProphet", "Sarima", - "LSTM", "VectorAR", "RandomForestForecaster", "ExtraTreesForecaster", diff --git a/merlion/models/anomaly/__init__.py b/merlion/models/anomaly/__init__.py index 1ed1758a3..c0d761f2e 100644 --- a/merlion/models/anomaly/__init__.py +++ b/merlion/models/anomaly/__init__.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2023 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -47,8 +47,8 @@ - trains the model on the time series ``train_data`` - ``anomaly_labels`` (optional): a time series aligned with ``train_data``, which indicates whether each time stamp is anomalous - - ``train_config`` (optional): extra configuration describing how the model should be trained (e.g. learning rate - for the `LSTMDetector`). Not used for all models. Class-level default provided for models which do use it. 
+ - ``train_config`` (optional): extra configuration describing how the model should be trained. + Not used for all models. Class-level default provided for models which do use it. - ``post_rule_train_config``: extra configuration describing how to train the model's post-rule. Class-level default is provided for all models. - returns a time series of anomaly scores produced by the model on ``train_data``. diff --git a/merlion/models/anomaly/forecast_based/lstm.py b/merlion/models/anomaly/forecast_based/lstm.py deleted file mode 100644 index 0b2124bee..000000000 --- a/merlion/models/anomaly/forecast_based/lstm.py +++ /dev/null @@ -1,23 +0,0 @@ -# -# Copyright (c) 2021 salesforce.com, inc. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -# -""" -Adaptation of a LSTM neural net forecaster, to the task of anomaly detection. -""" -from merlion.models.anomaly.forecast_based.base import ForecastingDetectorBase -from merlion.models.anomaly.base import DetectorConfig -from merlion.models.forecast.lstm import LSTMConfig, LSTMTrainConfig, LSTM -from merlion.post_process.threshold import AggregateAlarms - -# Note: we import LSTMTrainConfig just to get it into the namespace - - -class LSTMDetectorConfig(LSTMConfig, DetectorConfig): - _default_threshold = AggregateAlarms(alm_threshold=2.5) - - -class LSTMDetector(ForecastingDetectorBase, LSTM): - config_class = LSTMDetectorConfig diff --git a/merlion/models/deep_base.py b/merlion/models/deep_base.py new file mode 100644 index 000000000..51f777d98 --- /dev/null +++ b/merlion/models/deep_base.py @@ -0,0 +1,237 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +""" +Contains the base classes for all deep learning models. +""" +import io +import json +import copy +import logging + +from typing import List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +from scipy.stats import norm +from abc import abstractmethod +from enum import Enum + +try: + import torch + import torch.nn as nn +except ImportError as e: + err = ( + "Try installing Merlion with optional dependencies using `pip install salesforce-merlion[deep-learning]` or " + "`pip install `salesforce-merlion[all]`" + ) + raise ImportError(str(e) + ". " + err) + +from merlion.models.base import Config, ModelBase +from merlion.plot import Figure +from merlion.transform.base import TransformBase, Identity +from merlion.transform.factory import TransformFactory +from merlion.utils.misc import initializer +from merlion.utils.time_series import to_pd_datetime, to_timestamp, TimeSeries, AggregationPolicy, MissingValuePolicy + +logger = logging.getLogger(__name__) + + +class Optimizer(Enum): + """ + Optimizers for learning model parameters. + """ + + Adam = torch.optim.Adam + AdamW = torch.optim.AdamW + SGD = torch.optim.SGD + Adagrad = torch.optim.Adagrad + RMSprop = torch.optim.RMSprop + + +class LossFunction(Enum): + """ + Loss functions for learning model parameters. + """ + + mse = nn.MSELoss + l1 = nn.L1Loss + huber = nn.HuberLoss + guassian_nll = nn.GaussianNLLLoss + + +class DeepConfig(Config): + """ + Config object used to define a deep learning (pytorch) model. 
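+
+    A minimal configuration sketch; the values below are illustrative rather than recommended
+    defaults, and in practice one usually instantiates a concrete subclass (e.g. a forecaster
+    config) rather than ``DeepConfig`` itself::
+
+        config = DeepConfig(
+            batch_size=64,
+            num_epochs=20,
+            optimizer="AdamW",   # strings are coerced to the Optimizer enum by the setter below
+            loss_fn="huber",     # strings are coerced to the LossFunction enum
+            lr=1e-3,
+            early_stop_patience=5,
+        )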
+ """ + + @initializer + def __init__( + self, + batch_size: int = 32, + num_epochs: int = 10, + optimizer: Union[str, Optimizer] = Optimizer.Adam, + loss_fn: Union[str, LossFunction] = LossFunction.mse, + clip_gradient: Optional[float] = None, + use_gpu: bool = False, + ts_encoding: Union[None, str] = "h", + lr: float = 1e-4, + weight_decay: float = 0.0, + valid_fraction: float = 0.2, + early_stop_patience: Union[None, int] = None, + **kwargs, + ): + """ + :param batch_size: Batch size of a batch for stochastic training of deep models + :param num_epochs: Total number of epochs for training. + :param optimizer: The optimizer for learning the parameters of the deep learning models. The value of optimizer + can be ``Adam``, ``AdamW``, ``SGD``, ``Adagrad``, ``RMSprop``. + :param loss_fn: Loss function for optimizing deep learning models. The value of loss_fn can be + ``mse`` for l2 loss, ``l1`` for l1 loss, ``huber`` for huber loss. + :param clip_gradient: Clipping gradient norm of model parameters before updating. If ``clip_gradient is None``, + then the gradient will not be clipped. + :param use_gpu: Whether to use gpu for training deep models. If ``use_gpu = True`` while thre is no GPU device, + the model will use CPU for training instead. + :param ts_encoding: whether the timestamp should be encoded to a float vector, which can be used + for training deep learning based time series models; if ``None``, the timestamp is not encoded. + If not ``None``, it represents the frequency for time features encoding options:[s:secondly, t:minutely, h:hourly, + d:daily, b:business days, w:weekly, m:monthly] + :param lr: Learning rate for optimizing deep learning models. + :param weight_decay: Weight decay (L2 penalty) (default: 0) + :param valid_fraction: Fraction of validation set to be split from training data + :param early_stop_patience: Number of epochs with no improvement after which training will be stopped for + early stopping function. If ``early_stop_patience = None``, the training process will not stop early. + """ + super().__init__(**kwargs) + + @property + def optimizer(self) -> Optimizer: + return self._optimizer + + @optimizer.setter + def optimizer(self, optimizer: Union[str, Optimizer]): + if isinstance(optimizer, str): + valid = set(Optimizer.__members__.keys()) + if optimizer not in valid: + raise KeyError(f"{optimizer} is not a valid optimizer that supported. Valid optimizers are: {valid}") + optimizer = Optimizer[optimizer] + self._optimizer = optimizer + + @property + def loss_fn(self) -> LossFunction: + return self._loss_fn + + @loss_fn.setter + def loss_fn(self, loss_fn: Union[str, LossFunction]): + if isinstance(loss_fn, str): + valid = set(LossFunction.__members__.keys()) + if loss_fn not in valid: + raise KeyError(f"{loss_fn} is not a valid loss that supported. 
Valid optimizers are: {valid}") + loss_fn = LossFunction[loss_fn] + self._loss_fn = loss_fn + + +class TorchModel(nn.Module): + """ + Abstract base class for Pytorch deep learning models + """ + + def __init__(self, config: DeepConfig): + super(TorchModel, self).__init__() + self.config = config + + @abstractmethod + def forward(self, past, past_timestamp, future_timestamp, *args, **kwargs): + raise NotImplementedError + + @property + def device(self): + return next(self.parameters()).device + + +class DeepModelBase(ModelBase): + """ + Abstract base class for all deep learning models + """ + + config_class = DeepConfig + deep_model_class = TorchModel + + def __init__(self, config: DeepConfig): + super().__init__(config) + self.deep_model = None + + def _create_model(self): + """ + Create and initialize deep models and neccessary components for training + """ + + self.deep_model = self.deep_model_class(self.config) + + self.optimizer = self.config.optimizer.value( + self.deep_model.parameters(), + lr=self.config.lr, + weight_decay=self.config.weight_decay, + ) + + self.loss_fn = self.config.loss_fn.value() + + if self.config.use_gpu: + self.to_gpu() + else: + self.to_cpu() + + @abstractmethod + def _get_batch_model_loss_and_outputs(self, batch): + """ + Calculate optimizing loss and get the output of the deep_model, given a batch of data + """ + raise NotImplementedError + + def to_gpu(self): + """ + Move deep model to GPU + """ + if torch.cuda.is_available(): + if self.deep_model is not None: + device = torch.device("cuda") + self.deep_model = self.deep_model.to(device) + else: + logger.warning("GPU not available, using CPU instead") + self.to_cpu() + + def to_cpu(self): + """ + Move deep model to CPU + """ + if self.deep_model is not None: + device = torch.device("cpu") + self.deep_model = self.deep_model.to(device) + + def __getstate__(self): + state = copy.copy(self.__dict__) + deep_model = state.pop("deep_model", None) + state.pop("optimizer", None) + state.pop("loss_fn", None) + state = copy.deepcopy(state) + + if deep_model is not None: + state["deep_model_state_dict"] = deep_model.state_dict() + + return state + + def __setstate__(self, state): + deep_model_state_dict = state.pop("deep_model_state_dict", None) + super().__setstate__(state) + + if deep_model_state_dict: + if self.deep_model is None: + self._create_model() + + buffer = io.BytesIO() + torch.save(deep_model_state_dict, buffer) + buffer.seek(0) + self.deep_model.load_state_dict(torch.load(buffer, map_location=self.deep_model.device)) diff --git a/merlion/models/factory.py b/merlion/models/factory.py index c9bea69e0..22db4a661 100644 --- a/merlion/models/factory.py +++ b/merlion/models/factory.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2022 salesforce.com, inc. +# Copyright (c) 2023 salesforce.com, inc. # All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -28,7 +28,6 @@ IsolationForest="merlion.models.anomaly.isolation_forest:IsolationForest", # Forecast-based anomaly detection models ETSDetector="merlion.models.anomaly.forecast_based.ets:ETSDetector", - LSTMDetector="merlion.models.anomaly.forecast_based.lstm:LSTMDetector", MSESDetector="merlion.models.anomaly.forecast_based.mses:MSESDetector", ProphetDetector="merlion.models.anomaly.forecast_based.prophet:ProphetDetector", RandomCutForest="merlion.models.anomaly.random_cut_forest:RandomCutForest", @@ -47,7 +46,6 @@ # Forecasting models Arima="merlion.models.forecast.arima:Arima", ETS="merlion.models.forecast.ets:ETS", - LSTM="merlion.models.forecast.lstm:LSTM", MSES="merlion.models.forecast.smoother:MSES", Prophet="merlion.models.forecast.prophet:Prophet", Sarima="merlion.models.forecast.sarima:Sarima", @@ -56,6 +54,11 @@ RandomForestForecaster="merlion.models.forecast.trees:RandomForestForecaster", ExtraTreesForecaster="merlion.models.forecast.trees:ExtraTreesForecaster", LGBMForecaster="merlion.models.forecast.trees:LGBMForecaster", + TransformerForecaster="merlion.models.forecast.transformer:TransformerForecaster", + InformerForecaster="merlion.models.forecast.informer:InformerForecaster", + AutoformerForecaster="merlion.models.forecast.autoformer:AutoformerForecaster", + ETSformerForecaster="merlion.models.forecast.etsformer:ETSformerForecaster", + DeepARForecaster="merlion.models.forecast.deep_ar:DeepARForecaster", # Ensembles DetectorEnsemble="merlion.models.ensemble.anomaly:DetectorEnsemble", ForecasterEnsemble="merlion.models.ensemble.forecast:ForecasterEnsemble", diff --git a/merlion/models/forecast/__init__.py b/merlion/models/forecast/__init__.py index 782046bd3..14b1b79b2 100644 --- a/merlion/models/forecast/__init__.py +++ b/merlion/models/forecast/__init__.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2022 salesforce.com, inc. +# Copyright (c) 2023 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -30,8 +30,8 @@ 3. ``model.train(train_data, train_config=None)`` - trains the model on the `TimeSeries` ``train_data`` - - ``train_config`` (optional): extra configuration describing how the model should be trained (e.g. learning rate - for `LSTM`). Not used for all models. Class-level default provided for models which do use it. + - ``train_config`` (optional): extra configuration describing how the model should be trained. + Not used for all models. Class-level default provided for models which do use it. - returns the model's prediction ``train_data``, in the same format as if you called `ForecasterBase.forecast` on the time stamps of ``train_data`` """ diff --git a/merlion/models/forecast/autoformer.py b/merlion/models/forecast/autoformer.py new file mode 100644 index 000000000..0469d79af --- /dev/null +++ b/merlion/models/forecast/autoformer.py @@ -0,0 +1,239 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +""" +Implementation of Autoformer. 
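+
+A minimal usage sketch; ``train_ts`` is assumed to be a `TimeSeries`, and the hyperparameter
+values are placeholders rather than recommendations::
+
+    from merlion.models.forecast.autoformer import AutoformerConfig, AutoformerForecaster
+
+    config = AutoformerConfig(n_past=96, max_forecast_steps=24)
+    model = AutoformerForecaster(config)
+    model.train(train_ts)
+    forecast, stderr = model.forecast(time_stamps=24)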
+""" +import copy +import logging +import math + +import numpy as np +import pandas as pd +from scipy.stats import norm + +from typing import List, Optional, Tuple, Union +from abc import abstractmethod + +try: + import torch + import torch.nn as nn + import torch.nn.functional as F +except ImportError as e: + err = ( + "Try installing Merlion with optional dependencies using `pip install salesforce-merlion[deep-learning]` or " + "`pip install `salesforce-merlion[all]`" + ) + raise ImportError(str(e) + ". " + err) + + +from merlion.models.base import NormalizingConfig +from merlion.models.deep_base import TorchModel +from merlion.models.forecast.deep_base import DeepForecasterConfig, DeepForecaster + +from merlion.models.utils.nn_modules import ( + AutoCorrelation, + AutoCorrelationLayer, + SeriesDecomposeBlock, + SeasonalLayernorm, + DataEmbeddingWoPos, +) + +from merlion.models.utils.nn_modules.enc_dec_autoformer import Encoder, Decoder, EncoderLayer, DecoderLayer + + +from merlion.utils.misc import initializer + +logger = logging.getLogger(__name__) + + +class AutoformerConfig(DeepForecasterConfig, NormalizingConfig): + """ + Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting: https://arxiv.org/abs/2106.13008. + Code adapted from https://github.com/thuml/Autoformer. + """ + + @initializer + def __init__( + self, + n_past, + max_forecast_steps: int = None, + moving_avg: int = 25, + encoder_input_size: int = None, + decoder_input_size: int = None, + num_encoder_layers: int = 2, + num_decoder_layers: int = 1, + start_token_len: int = 0, + factor: int = 3, + model_dim: int = 512, + embed: str = "timeF", + dropout: float = 0.05, + activation: str = "gelu", + n_heads: int = 8, + fcn_dim: int = 2048, + **kwargs + ): + """ + :param n_past: # of past steps used for forecasting future. + :param max_forecast_steps: Max # of steps we would like to forecast for. + :param moving_avg: Window size of moving average for Autoformer. + :param encoder_input_size: Input size of encoder. If ``encoder_input_size = None``, + then the model will automatically use ``config.dim``, which is the dimension of the input data. + :param decoder_input_size: Input size of decoder. If ``decoder_input_size = None``, + then the model will automatically use ``config.dim``, which is the dimension of the input data. + :param num_encoder_layers: Number of encoder layers. + :param num_decoder_layers: Number of decoder layers. + :param start_token_len: Length of start token for deep transformer encoder-decoder based models. + The start token is similar to the special tokens for NLP models (e.g., bos, sep, eos tokens). + :param factor: Attention factor. + :param model_dim: Dimension of the model. + :param embed: Time feature encoding type, options include ``timeF``, ``fixed`` and ``learned``. + :param dropout: dropout rate. + :param activation: Activation function, can be ``gelu``, ``relu``, ``sigmoid``, etc. + :param n_heads: Number of heads of the model. + :param fcn_dim: Hidden dimension of the MLP layer in the model. + """ + + super().__init__(n_past=n_past, max_forecast_steps=max_forecast_steps, **kwargs) + + +class AutoformerModel(TorchModel): + """ + Implementaion of Autoformer deep torch model. 
+ """ + + def __init__(self, config: AutoformerConfig): + super().__init__(config) + + if config.dim is not None: + config.encoder_input_size = config.dim if config.encoder_input_size is None else config.encoder_input_size + config.decoder_input_size = ( + config.encoder_input_size if config.decoder_input_size is None else config.decoder_input_size + ) + + config.c_out = config.encoder_input_size + + self.n_past = config.n_past + self.start_token_len = config.start_token_len + self.max_forecast_steps = config.max_forecast_steps + + kernel_size = config.moving_avg + self.decomp = SeriesDecomposeBlock(kernel_size) + + # Embedding + # The series-wise connection inherently contains the sequential information. + # Thus, we can discard the position embedding of transformers. + self.enc_embedding = DataEmbeddingWoPos( + config.encoder_input_size, config.model_dim, config.embed, config.ts_encoding, config.dropout + ) + + self.dec_embedding = DataEmbeddingWoPos( + config.decoder_input_size, config.model_dim, config.embed, config.ts_encoding, config.dropout + ) + + # Encoder + self.encoder = Encoder( + [ + EncoderLayer( + AutoCorrelationLayer( + AutoCorrelation(False, config.factor, attention_dropout=config.dropout, output_attention=False), + config.model_dim, + config.n_heads, + ), + config.model_dim, + config.fcn_dim, + moving_avg=config.moving_avg, + dropout=config.dropout, + activation=config.activation, + ) + for l in range(config.num_encoder_layers) + ], + norm_layer=SeasonalLayernorm(config.model_dim), + ) + + # Decoder + self.decoder = Decoder( + [ + DecoderLayer( + AutoCorrelationLayer( + AutoCorrelation(True, config.factor, attention_dropout=config.dropout, output_attention=False), + config.model_dim, + config.n_heads, + ), + AutoCorrelationLayer( + AutoCorrelation(False, config.factor, attention_dropout=config.dropout, output_attention=False), + config.model_dim, + config.n_heads, + ), + config.model_dim, + config.c_out, + config.fcn_dim, + moving_avg=config.moving_avg, + dropout=config.dropout, + activation=config.activation, + ) + for l in range(config.num_decoder_layers) + ], + norm_layer=SeasonalLayernorm(config.model_dim), + projection=nn.Linear(config.model_dim, config.c_out, bias=True), + ) + + def forward( + self, + past, + past_timestamp, + future_timestamp, + enc_self_mask=None, + dec_self_mask=None, + dec_enc_mask=None, + **kwargs + ): + config = self.config + + future_timestamp = torch.cat( + [past_timestamp[:, (past_timestamp.shape[1] - self.start_token_len) :], future_timestamp], dim=1 + ) + + # decomp init + mean = torch.mean(past, dim=1).unsqueeze(1).repeat(1, self.max_forecast_steps, 1) + zeros = torch.zeros( + [past.shape[0], self.max_forecast_steps, past.shape[2]], dtype=torch.float, device=self.device + ) + seasonal_init, trend_init = self.decomp(past) + # decoder input + trend_init = torch.cat([trend_init[:, (trend_init.shape[1] - self.start_token_len) :, :], mean], dim=1) + seasonal_init = torch.cat( + [seasonal_init[:, (seasonal_init.shape[1] - self.start_token_len) :, :], zeros], dim=1 + ) + + # enc + enc_out = self.enc_embedding(past, past_timestamp) + enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) + # dec + dec_out = self.dec_embedding(seasonal_init, future_timestamp) + seasonal_part, trend_part = self.decoder( + dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask, trend=trend_init + ) + + # final + dec_out = trend_part + seasonal_part + + if self.config.target_seq_index is not None: + return dec_out[:, -self.max_forecast_steps :, :1] 
+ else: + return dec_out[:, -self.max_forecast_steps :, :] # [B, L, D] + + +class AutoformerForecaster(DeepForecaster): + """ + Implementaion of Autoformer deep forecaster. + """ + + config_class = AutoformerConfig + deep_model_class = AutoformerModel + + def __init__(self, config: AutoformerConfig): + super().__init__(config) diff --git a/merlion/models/forecast/base.py b/merlion/models/forecast/base.py index 4c4248434..ceae77a43 100644 --- a/merlion/models/forecast/base.py +++ b/merlion/models/forecast/base.py @@ -106,6 +106,13 @@ def require_univariate(self) -> bool: """ return False + @property + def support_multivariate_output(self) -> bool: + """ + Indicating whether the forecasting model can forecast multivariate output. + """ + return False + def resample_time_stamps(self, time_stamps: Union[int, List[int]], time_series_prev: TimeSeries = None): assert self.timedelta is not None and self.last_train_time is not None, ( "train() must be called before you can call forecast(). " @@ -157,19 +164,27 @@ def train_pre_process( self, train_data: TimeSeries, exog_data: TimeSeries = None, return_exog=None ) -> Union[TimeSeries, Tuple[TimeSeries, Union[TimeSeries, None]]]: train_data = super().train_pre_process(train_data) + if self.dim == 1: self.config.target_seq_index = 0 - elif self.target_seq_index is None: + elif self.target_seq_index is None and not self.support_multivariate_output: raise RuntimeError( - f"Attempting to use a forecaster on a {train_data.dim}-variable " + f"Attempting to use a forecaster that does not support multivariate outputs " + f"on a {train_data.dim}-variable " f"time series, but didn't specify a `target_seq_index` " f"indicating which univariate is the target." ) - assert 0 <= self.target_seq_index < train_data.dim, ( - f"Expected `target_seq_index` to be between 0 and {train_data.dim} " - f"(the dimension of the transformed data), but got {self.target_seq_index}" + + assert self.support_multivariate_output or (0 <= self.target_seq_index < train_data.dim), ( + f"Expected `support_multivariate_output = True`," + f"or `target_seq_index` to be between 0 and {train_data.dim}" + f"(the dimension of the transformed data), but got {self.target_seq_index} " ) - self.target_name = train_data.names[self.target_seq_index] + + if self.support_multivariate_output and self.target_seq_index is None: + self.target_name = str(train_data.names) + else: + self.target_name = train_data.names[self.target_seq_index] # Handle exogenous data if return_exog is None: @@ -315,7 +330,7 @@ def forecast( # Format the return values and reset the transform's inversion state if self.invert_transform and time_series_prev is None: time_series_prev = self.transform(self.train_data) - if time_series_prev is not None: + if time_series_prev is not None and self.target_seq_index is not None: time_series_prev = pd.DataFrame(time_series_prev.univariates[time_series_prev.names[self.target_seq_index]]) ret = self._process_forecast(forecast, err, time_series_prev, return_prev=return_prev, return_iqr=return_iqr) self.transform.inversion_state = old_inversion_state diff --git a/merlion/models/forecast/deep_ar.py b/merlion/models/forecast/deep_ar.py new file mode 100644 index 000000000..dacbc24b8 --- /dev/null +++ b/merlion/models/forecast/deep_ar.py @@ -0,0 +1,253 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +""" +Implementation of Deep AR +""" +import copy +import logging +import math + +import numpy as np +import pandas as pd + +from typing import List, Optional, Tuple, Union + +try: + import torch + import torch.nn as nn + import torch.nn.functional as F +except ImportError as e: + err = ( + "Try installing Merlion with optional dependencies using `pip install salesforce-merlion[deep-learning]` or " + "`pip install `salesforce-merlion[all]`" + ) + raise ImportError(str(e) + ". " + err) + +from merlion.utils.misc import initializer + +from merlion.models.base import NormalizingConfig +from merlion.models.deep_base import TorchModel, LossFunction +from merlion.models.forecast.deep_base import DeepForecasterConfig, DeepForecaster + + +logger = logging.getLogger(__name__) + + +class DeepARConfig(DeepForecasterConfig, NormalizingConfig): + """ + DeepAR: Probabilistic Forecasting with Autoregressive Recurrent Networks: https://arxiv.org/abs/1704.04110 + """ + + @initializer + def __init__( + self, + n_past, + max_forecast_steps: int = None, + hidden_size: Union[int, None] = 32, + num_hidden_layers: int = 2, + lags_seq: List[int] = [1], + num_prediction_samples: int = 10, + loss_fn: Union[str, LossFunction] = LossFunction.guassian_nll, + **kwargs, + ): + """ + :param n_past: # of past steps used for forecasting future. + :param max_forecast_steps: Max # of steps we would like to forecast for. + :param hidden_size: hidden_size of the LSTM layers + :param num_hidden_layers: # of hidden layers in LSTM + :param lags_seq: Indices of the lagged observations that the RNN takes as input. For example, + ``[1]`` indicates that the RNN only takes the observation at time ``t-1`` to produce the + output for time ``t``. + :param num_prediction_samples: # of samples to produce the forecasting + """ + + super().__init__(n_past=n_past, max_forecast_steps=max_forecast_steps, loss_fn=loss_fn, **kwargs) + + +class DeepARModel(TorchModel): + """ + Implementaion of Deep AR model + """ + + def __init__(self, config: DeepARConfig): + super().__init__(config) + + assert len(config.lags_seq) > 0, "lags_seq must not be empty!" 
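+        # With lags_seq = [1], the RNN input at step t is the observation at t - 1; additional lags
+        # append extra lagged copies of the series to the input features. The usable context length
+        # below (n_context) shrinks by the largest lag, so n_past must exceed max(lags_seq).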
+ self.lags_seq = config.lags_seq + self.n_past = config.n_past + self.n_context = config.n_past - max(self.lags_seq) + self.max_forecast_steps = config.max_forecast_steps + self.output_size = config.dim + + freq_map = {"h": 4, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3} + + input_size = len(self.lags_seq) * self.output_size + freq_map[config.ts_encoding] + + # for decoding the lags are shifted by one, at the first time-step + # of the decoder a lag of one corresponds to the last target value + self.shifted_lags = [l - 1 for l in self.lags_seq] + self.num_prediction_samples = config.num_prediction_samples + + if config.hidden_size is None: + hidden_size = int(4 * (1 + math.pow(math.log(config.dim), 4))) + else: + hidden_size = config.hidden_size + + self.rnn = nn.LSTM( + input_size, + hidden_size=hidden_size, + num_layers=config.num_hidden_layers, + batch_first=True, + ) + + self.distr_proj = nn.Sequential( + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, self.output_size * 2), + ) + + self.loss_fn = self.config.loss_fn.value() + + @staticmethod + def get_lagged_subsequences( + sequence, + sequence_length, + indices: List[int], + subsequences_length: int = 1, + ) -> torch.Tensor: + assert max(indices) + subsequences_length <= sequence_length, ( + f"lags cannot go further than n_past, found lag {max(indices)} " f"while n_past is only {sequence_length}" + ) + assert all(lag_index >= 0 for lag_index in indices) + + lagged_values = [] + for lag_index in indices: + begin_index = -lag_index - subsequences_length + end_index = -lag_index if lag_index > 0 else None + lagged_values.append(sequence[:, begin_index:end_index, ...]) + return torch.stack(lagged_values, dim=-1) + + def unroll_encoder(self, past, past_timestamp, future_timestamp, future=None): + if future_timestamp is None or future is None: + time_features = past_timestamp[:, (self.n_past - self.n_context) :, :] + sequence = past + sequence_length = self.n_past + subsequences_length = self.n_context + else: + time_features = torch.cat((past_timestamp[:, (self.n_past - self.n_context) :, :], future_timestamp), dim=1) + sequence = torch.cat((past, future), dim=1) + sequence_length = self.n_past + self.max_forecast_steps + subsequences_length = self.n_context + self.max_forecast_steps + + lags = self.get_lagged_subsequences( + sequence=sequence, + sequence_length=sequence_length, + indices=self.lags_seq, + subsequences_length=subsequences_length, + ) + + input_lags = lags.reshape((-1, subsequences_length, len(self.lags_seq) * self.output_size)) + rnn_inputs = torch.cat((input_lags, time_features), dim=-1) + outputs, states = self.rnn(rnn_inputs) + + return outputs, states + + def calculate_loss(self, past, past_timestamp, future, future_timestamp): + rnn_outputs, _ = self.unroll_encoder(past, past_timestamp, future_timestamp, future) + distr_proj_out = self.distr_proj(rnn_outputs) + + mu, log_sigma = torch.split(distr_proj_out, self.config.dim, dim=-1) + sigma = torch.log(1 + torch.exp(log_sigma)) + 1e-07 + + target_future = torch.cat((past[:, (self.n_past - self.n_context) :, :], future), dim=1) + + loss = self.loss_fn(mu, target_future, torch.square(sigma)) + + return loss + + def sampling_decoder(self, past, time_features, begin_states): + repeated_past = past.repeat_interleave( + repeats=self.num_prediction_samples, + dim=0, + ) + + repeated_time_features = time_features.repeat_interleave(repeats=self.num_prediction_samples, dim=0) + + repeated_states = 
[s.repeat_interleave(repeats=self.num_prediction_samples, dim=1) for s in begin_states] + + future_samples = [] + + for k in range(self.max_forecast_steps): + lags = self.get_lagged_subsequences( + sequence=repeated_past, + sequence_length=self.n_past + k, + indices=self.shifted_lags, + subsequences_length=1, + ) + input_lags = lags.reshape(-1, 1, len(self.lags_seq) * self.output_size) + + decoder_input = torch.cat((input_lags, repeated_time_features[:, k : k + 1, :]), dim=-1) + rnn_outputs, repeated_states = self.rnn(decoder_input, repeated_states) + + distr_proj_out = self.distr_proj(rnn_outputs) + + mu, log_sigma = torch.split(distr_proj_out, self.config.dim, dim=-1) + sigma = torch.log(1 + torch.exp(log_sigma)) + 1e-07 + + new_samples = mu + torch.randn_like(mu, dtype=torch.float, device=self.device) * sigma + + repeated_past = torch.cat((repeated_past, new_samples), dim=1) + future_samples.append(new_samples) + + samples = torch.cat(future_samples, dim=1) + + return samples.reshape((-1, self.num_prediction_samples, self.max_forecast_steps, self.output_size)) + + def forward(self, past, past_timestamp, future_timestamp, mean_samples=True): + _, states = self.unroll_encoder(past, past_timestamp, future_timestamp) + + forecast_samples = self.sampling_decoder( + past=past, + time_features=future_timestamp, + begin_states=states, + ) + + target_idx = self.config.target_seq_index + if target_idx is not None: + forecast_samples = forecast_samples[:, :, :, target_idx : target_idx + 1] + + return forecast_samples.mean(dim=1) if mean_samples else forecast_samples + + +class DeepARForecaster(DeepForecaster): + """ + Implementaion of Deep AR model forecaster + """ + + config_class = DeepARConfig + deep_model_class = DeepARModel + + def __init__(self, config: DeepARConfig): + super().__init__(config) + + def _get_batch_model_loss_and_outputs(self, batch): + config = self.config + past, past_timestamp, future, future_timestamp = batch + + model_output = self.deep_model(past, past_timestamp, future_timestamp) + + if future is None: + return None, model_output, None + + # Calcuating the loss with maximum likelihood, + # which is seperate from the sampling procedure of deep AR models + loss = self.deep_model.calculate_loss(past, past_timestamp, future, future_timestamp) + + if self.target_seq_index is not None: + future = future[:, :, self.target_seq_index : self.target_seq_index + 1] + + return loss, model_output, future diff --git a/merlion/models/forecast/deep_base.py b/merlion/models/forecast/deep_base.py new file mode 100644 index 000000000..37f64e215 --- /dev/null +++ b/merlion/models/forecast/deep_base.py @@ -0,0 +1,241 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +""" +Base class for Deep Learning Forecasting Models +""" +import copy + +import logging +import numpy as np +import pandas as pd + + +from scipy.stats import norm +from typing import List, Optional, Tuple, Union + +try: + import torch + import torch.nn as nn +except ImportError as e: + err = ( + "Try installing Merlion with optional dependencies using `pip install salesforce-merlion[deep-learning]` or " + "`pip install `salesforce-merlion[all]`" + ) + raise ImportError(str(e) + ". 
" + err) + + +from merlion.models.deep_base import DeepConfig, DeepModelBase +from merlion.models.forecast.base import ForecasterBase, ForecasterConfig +from merlion.models.utils.rolling_window_dataset import RollingWindowDataset +from merlion.models.utils.time_features import get_time_features +from merlion.models.utils.early_stopping import EarlyStopping + +from merlion.transform.base import TransformBase, Identity +from merlion.transform.factory import TransformFactory +from merlion.utils.misc import initializer, ProgressBar +from merlion.utils.time_series import to_pd_datetime, to_timestamp, TimeSeries, AggregationPolicy, MissingValuePolicy + + +logger = logging.getLogger(__name__) + + +class DeepForecasterConfig(DeepConfig, ForecasterConfig): + """ + Config object used to define a forecaster with deep model + """ + + def __init__( + self, + n_past: int, + **kwargs, + ): + """ + :param n_past: # of past steps used for forecasting future. + """ + super().__init__( + **kwargs, + ) + self.n_past = n_past + + +class DeepForecaster(DeepModelBase, ForecasterBase): + """ + Base class for a deep forecaster model + """ + + config_class = DeepForecasterConfig + + def __init__(self, config: DeepForecasterConfig): + super().__init__(config) + + def _get_np_loss_and_prediction(self, eval_dataset: RollingWindowDataset): + """ + Get numpy prediction and loss with evaluation mode for a given dataset or data + + :param eval_dataset: Evaluation dataset + + :return: The numpy prediction of the model and the average loss for the given dataset. + """ + self.deep_model.eval() + all_preds = [] + total_loss = [] + for i, batch in enumerate(eval_dataset): + with torch.no_grad(): + loss, outputs, y_true = self._get_batch_model_loss_and_outputs(self._convert_batch_to_tensors(batch)) + pred = outputs.detach().cpu().numpy() + all_preds.append(pred) + total_loss.append(loss.item()) + + preds = np.concatenate(all_preds, axis=0) + return preds, np.average(total_loss) + + @property + def support_multivariate_output(self) -> bool: + """ + Deep models support multivariate output by default. 
+ """ + return True + + def _convert_batch_to_tensors(self, batch): + device = self.deep_model.device + + past, past_timestamp, future, future_timestamp = batch + + past = torch.tensor(past, dtype=torch.float, device=device) + future = future if future is None else torch.tensor(future, dtype=torch.float, device=device) + + past_timestamp = torch.tensor(past_timestamp, dtype=torch.float, device=device) + future_timestamp = torch.tensor(future_timestamp, dtype=torch.float, device=device) + + return past, past_timestamp, future, future_timestamp + + def _train(self, train_data: pd.DataFrame, train_config=None) -> pd.DataFrame: + config = self.config + + # creating model before the training + self._create_model() + + total_dataset = RollingWindowDataset( + train_data, + n_past=config.n_past, + n_future=config.max_forecast_steps, + batch_size=config.batch_size, + target_seq_index=None, # have to set None, we use target_seq_index later in the training, if not this is a bug + ts_encoding=config.ts_encoding, + valid_fraction=config.valid_fraction, + flatten=False, + shuffle=True, + validation=False, + ) + + train_steps = len(total_dataset) + logger.info(f"Training steps each epoch: {train_steps}") + + bar = ProgressBar(total=config.num_epochs) + early_stopping = EarlyStopping(patience=config.early_stop_patience) if config.early_stop_patience else None + + # start training + for epoch in range(config.num_epochs): + train_loss = [] + self.deep_model.train() + total_dataset.seed = epoch + 1 + + for i, batch in enumerate(total_dataset): + self.optimizer.zero_grad() + + loss, _, _ = self._get_batch_model_loss_and_outputs(self._convert_batch_to_tensors(batch)) + train_loss.append(loss.item()) + + loss.backward() + if config.clip_gradient is not None: + torch.nn.utils.clip_grad_norm(self.model.parameters(), config.clip_gradient) + + self.optimizer.step() + + train_loss = np.average(train_loss) + + # set validation flag + total_dataset.validation = True + _, val_loss = self._get_np_loss_and_prediction(total_dataset) + total_dataset.validation = False + + if bar is not None: + bar.print( + epoch + 1, prefix="", suffix=f"Train Loss: {train_loss: .4f}, Validation Loss: {val_loss: .4f}" + ) + + if early_stopping is not None: + early_stopping(val_loss, self.deep_model) + if early_stopping.early_stop: + logger.info(f"Early stopping with {config.early_stop_patience} patience") + break + + if early_stopping is not None: + early_stopping.load_best_model(self.deep_model) + logger.info(f"Load the best model with validation loss: {early_stopping.val_loss_min: .4f}") + + logger.info("End of the training loop") + + # get predictions + total_dataset.shuffle = False + total_dataset.validation = None + pred, _ = self._get_np_loss_and_prediction(total_dataset) + + # since the model predicts multiple steps, we concatenate all the first steps together + columns = train_data.columns if self.target_seq_index is None else [self.target_name] + column_index = train_data.index[config.n_past : (len(train_data) - config.max_forecast_steps + 1)] + return pd.DataFrame(pred[:, 0], index=column_index, columns=columns), None + + def _get_batch_model_loss_and_outputs(self, batch): + """ + For loss calculation and output prediction + + :param batch: a batch contains `(past, past_timestamp, future, future_timestamp)` used for calculating loss and model outputs + + :return: calculated loss, deep model outputs and targeted ground truth future + """ + past, past_timestamp, future, future_timestamp = batch + model_output = self.deep_model(past, 
past_timestamp, future_timestamp) + + if future is None: + return None, model_output, None + + if self.target_seq_index is not None: + future = future[:, :, self.target_seq_index : self.target_seq_index + 1] + + loss = self.loss_fn(model_output, future) + return loss, model_output, future + + @property + def require_even_sampling(self) -> bool: + return False + + def _forecast( + self, time_stamps: List[int], time_series_prev: pd.DataFrame = None, return_prev=False + ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + + if time_series_prev is None: + time_series_prev = self.transform(self.train_data).to_pd().iloc[-self.config.n_past :] + + # convert to vector feature + prev_timestamp = get_time_features(time_series_prev.index, self.config.ts_encoding) + future_timestamp = get_time_features(to_pd_datetime(time_stamps), self.config.ts_encoding) + + # preparing data + past = np.expand_dims(time_series_prev.values, 0) + past_timestamp = np.expand_dims(prev_timestamp, 0) + future_timestamp = np.expand_dims(future_timestamp, 0) + + self.deep_model.eval() + batch = (past, past_timestamp, None, future_timestamp) + _, model_output, _ = self._get_batch_model_loss_and_outputs(self._convert_batch_to_tensors(batch)) + + preds = model_output.detach().cpu().numpy().squeeze() + columns = time_series_prev.columns if self.target_seq_index is None else [self.target_name] + pd_pred = pd.DataFrame(preds, index=to_pd_datetime(time_stamps), columns=columns) + + return pd_pred, None diff --git a/merlion/models/forecast/ets.py b/merlion/models/forecast/ets.py index 1b6d4ecdf..4ecad2438 100644 --- a/merlion/models/forecast/ets.py +++ b/merlion/models/forecast/ets.py @@ -137,7 +137,9 @@ def _train(self, train_data: pd.DataFrame, train_config=None): name = self.target_name train_data = train_data[name] times = train_data.index - self.model = self._instantiate_model(pd.Series(train_data.values)).fit(disp=False) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.model = self._instantiate_model(pd.Series(train_data.values)).fit(disp=False) # get forecast for the training data self._last_val = train_data[-1] @@ -166,10 +168,12 @@ def _forecast( # the default setting of refit=False is fast and conducts exponential smoothing with given parameters, # while the setting of refit=True is slow and refits the model on time_series_prev. model = self._instantiate_model(val_prev) - if self.config.refit and len(time_series_prev) > self._max_lookback: - model = model.fit(start_params=self.model.params, disp=False) - else: - model = model.smooth(params=self.model.params) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if self.config.refit and len(time_series_prev) > self._max_lookback: + model = model.fit(start_params=self.model.params, disp=False) + else: + model = model.smooth(params=self.model.params) # Run forecasting. forecast_result = model.get_prediction(start=start, end=start + len(time_stamps) - 1) diff --git a/merlion/models/forecast/etsformer.py b/merlion/models/forecast/etsformer.py new file mode 100644 index 000000000..443031668 --- /dev/null +++ b/merlion/models/forecast/etsformer.py @@ -0,0 +1,204 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +""" +Implementation of ETSformer. 
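+
+A minimal usage sketch; ``train_ts`` is assumed to be a `TimeSeries`, and the hyperparameter
+values are placeholders (here the forecast is restricted to the first variable via
+``target_seq_index``)::
+
+    from merlion.models.forecast.etsformer import ETSformerConfig, ETSformerForecaster
+
+    config = ETSformerConfig(n_past=96, max_forecast_steps=24, target_seq_index=0)
+    model = ETSformerForecaster(config)
+    model.train(train_ts)
+    forecast, stderr = model.forecast(time_stamps=24)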
+""" +import copy +import logging +import math + +import numpy as np +import pandas as pd +from scipy.stats import norm + +from typing import List, Optional, Tuple, Union +from abc import abstractmethod + +try: + import torch + import torch.nn as nn + import torch.nn.functional as F +except ImportError as e: + err = ( + "Try installing Merlion with optional dependencies using `pip install salesforce-merlion[deep-learning]` or " + "`pip install `salesforce-merlion[all]`" + ) + raise ImportError(str(e) + ". " + err) + + +from merlion.models.base import NormalizingConfig +from merlion.models.deep_base import TorchModel +from merlion.models.forecast.deep_base import DeepForecasterConfig, DeepForecaster + +from merlion.models.utils.nn_modules import ETSEmbedding +from merlion.models.utils.nn_modules.enc_dec_etsformer import EncoderLayer, Encoder, DecoderLayer, Decoder + +from merlion.utils.misc import initializer + +logger = logging.getLogger(__name__) + + +class ETSformerConfig(DeepForecasterConfig, NormalizingConfig): + """ + ETSformer: Exponential Smoothing Transformers for Time-series Forecasting: https://arxiv.org/abs/2202.01381 + Code adapted from https://github.com/salesforce/ETSformer. + """ + + @initializer + def __init__( + self, + n_past, + max_forecast_steps: int = None, + encoder_input_size: int = None, + decoder_input_size: int = None, + num_encoder_layers: int = 2, + num_decoder_layers: int = 2, + model_dim: int = 512, + dropout: float = 0.2, + n_heads: int = 8, + fcn_dim: int = 2048, + top_K: int = 1, # Top-K Fourier bases + sigma=0.2, + **kwargs + ): + """ + :param n_past: # of past steps used for forecasting future. + :param max_forecast_steps: Max # of steps we would like to forecast for. + :param encoder_input_size: Input size of encoder. If ``encoder_input_size = None``, + then the model will automatically use ``config.dim``, which is the dimension of the input data. + :param decoder_input_size: Input size of decoder. If ``decoder_input_size = None``, + then the model will automatically use ``config.dim``, which is the dimension of the input data. + :param num_encoder_layers: Number of encoder layers. + :param num_decoder_layers: Number of decoder layers. + :param model_dim: Dimension of the model. + :param dropout: dropout rate. + :param n_heads: Number of heads of the model. + :param fcn_dim: Hidden dimension of the MLP layer in the model. + :param top_K: Top-K Frequent Fourier basis. + :param sigma: Standard derivation for ETS input data transform. + """ + super().__init__(n_past=n_past, max_forecast_steps=max_forecast_steps, **kwargs) + + +class ETSformerModel(TorchModel): + """ + Implementaion of ETSformer deep torch model. + """ + + def __init__(self, config: ETSformerConfig): + super().__init__(config) + + assert ( + config.num_encoder_layers == config.num_decoder_layers + ), "The number of encoder and decoder layers must be equal!" 
+ if config.dim is not None: + config.encoder_input_size = config.dim if config.encoder_input_size is None else config.encoder_input_size + config.decoder_input_size = ( + config.encoder_input_size if config.decoder_input_size is None else config.decoder_input_size + ) + + config.c_out = config.encoder_input_size + + self.n_past = config.n_past + self.max_forecast_steps = config.max_forecast_steps + + self.enc_embedding = ETSEmbedding(config.encoder_input_size, config.model_dim, dropout=config.dropout) + + self.encoder = Encoder( + [ + EncoderLayer( + config.model_dim, + config.n_heads, + config.c_out, + config.n_past, + config.max_forecast_steps, + config.top_K, + dim_feedforward=config.fcn_dim, + dropout=config.dropout, + output_attention=False, + ) + for _ in range(config.num_encoder_layers) + ] + ) + + # Decoder + self.decoder = Decoder( + [ + DecoderLayer( + config.model_dim, + config.n_heads, + config.c_out, + config.max_forecast_steps, + dropout=config.dropout, + output_attention=False, + ) + for _ in range(config.num_decoder_layers) + ], + ) + + def forward( + self, + past, + past_timestamp, + future_timestamp, + enc_self_mask=None, + dec_self_mask=None, + dec_enc_mask=None, + attention=False, + **kwargs + ): + with torch.no_grad(): + if self.training: + past = self.transform(past) + res = self.enc_embedding(past) + level, growths, seasons, season_attns, growth_attns = self.encoder(res, past, attn_mask=enc_self_mask) + + growth, season, growth_dampings = self.decoder(growths, seasons) + + preds = level[:, -1:] + growth + season + + # maybe remove later + if attention: + decoder_growth_attns = [] + for growth_attn, growth_damping in zip(growth_attns, growth_dampings): + decoder_growth_attns.append(torch.einsum("bth,oh->bhot", [growth_attn.squeeze(-1), growth_damping])) + + season_attns = torch.stack(season_attns, dim=0)[:, :, -self.pred_len :] + season_attns = reduce(season_attns, "l b d o t -> b o t", reduction="mean") + decoder_growth_attns = torch.stack(decoder_growth_attns, dim=0)[:, :, -self.pred_len :] + decoder_growth_attns = reduce(decoder_growth_attns, "l b d o t -> b o t", reduction="mean") + return preds, season_attns, decoder_growth_attns + + if self.config.target_seq_index is not None: + return preds[:, :, :1] + else: + return preds + + @torch.no_grad() + def transform(self, x): + return self.jitter(self.shift(self.scale(x))) + + def jitter(self, x): + return x + (torch.randn(x.shape).to(x.device) * self.config.sigma) + + def scale(self, x): + return x * (torch.randn(x.size(-1)).to(x.device) * self.config.sigma + 1) + + def shift(self, x): + return x + (torch.randn(x.size(-1)).to(x.device) * self.config.sigma) + + +class ETSformerForecaster(DeepForecaster): + """ + Implementaion of ETSformer deep forecaster. + """ + + config_class = ETSformerConfig + deep_model_class = ETSformerModel + + def __init__(self, config: ETSformerConfig): + super().__init__(config) diff --git a/merlion/models/forecast/informer.py b/merlion/models/forecast/informer.py new file mode 100644 index 000000000..ce0f0b0e4 --- /dev/null +++ b/merlion/models/forecast/informer.py @@ -0,0 +1,218 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +""" +Implementation of Informer. 
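+
+A minimal usage sketch via the model factory, where this forecaster is registered by this change;
+``train_ts`` is assumed to be a `TimeSeries` and the hyperparameter values are placeholders::
+
+    from merlion.models.factory import ModelFactory
+
+    model = ModelFactory.create("InformerForecaster", n_past=96, max_forecast_steps=24)
+    model.train(train_ts)
+    forecast, stderr = model.forecast(time_stamps=24)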
+""" +import copy +import logging +import math + +import numpy as np +import pandas as pd +from scipy.stats import norm + +from typing import List, Optional, Tuple, Union +from abc import abstractmethod + +try: + import torch + import torch.nn as nn + import torch.nn.functional as F +except ImportError as e: + err = ( + "Try installing Merlion with optional dependencies using `pip install salesforce-merlion[deep-learning]` or " + "`pip install `salesforce-merlion[all]`" + ) + raise ImportError(str(e) + ". " + err) + +from merlion.utils.misc import initializer + +from merlion.models.base import NormalizingConfig +from merlion.models.deep_base import TorchModel +from merlion.models.forecast.deep_base import DeepForecasterConfig, DeepForecaster + +from merlion.models.utils.nn_modules import ProbAttention, AttentionLayer, DataEmbedding, ConvLayer +from merlion.models.utils.nn_modules.enc_dec_transformer import ( + Decoder, + DecoderLayer, + Encoder, + EncoderLayer, +) + + +logger = logging.getLogger(__name__) + + +class InformerConfig(DeepForecasterConfig, NormalizingConfig): + """ + Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting: https://arxiv.org/abs/2012.07436 + Code adapted from https://github.com/thuml/Autoformer. + """ + + @initializer + def __init__( + self, + n_past, + max_forecast_steps: int = None, + encoder_input_size: int = None, + decoder_input_size: int = None, + num_encoder_layers: int = 2, + num_decoder_layers: int = 1, + start_token_len: int = 0, + factor: int = 3, + model_dim: int = 512, + embed: str = "timeF", + dropout: float = 0.05, + activation: str = "gelu", + n_heads: int = 8, + fcn_dim: int = 2048, + distil: bool = True, + **kwargs + ): + """ + :param n_past: # of past steps used for forecasting future. + :param max_forecast_steps: Max # of steps we would like to forecast for. + :param encoder_input_size: Input size of encoder. If ``encoder_input_size = None``, + then the model will automatically use ``config.dim``, which is the dimension of the input data. + :param decoder_input_size: Input size of decoder. If ``decoder_input_size = None``, + then the model will automatically use ``config.dim``, which is the dimension of the input data. + :param num_encoder_layers: Number of encoder layers. + :param num_decoder_layers: Number of decoder layers. + :param start_token_len: Length of start token for deep transformer encoder-decoder based models. + The start token is similar to the special tokens for NLP models (e.g., bos, sep, eos tokens). + :param factor: Attention factor. + :param model_dim: Dimension of the model. + :param embed: Time feature encoding type, options include ``timeF``, ``fixed`` and ``learned``. + :param dropout: dropout rate. + :param activation: Activation function, can be ``gelu``, ``relu``, ``sigmoid``, etc. + :param n_heads: Number of heads of the model. + :param fcn_dim: Hidden dimension of the MLP layer in the model. + :param distil: whether to use distilling in the encoder of the model. + """ + + super().__init__(n_past=n_past, max_forecast_steps=max_forecast_steps, **kwargs) + + +class InformerModel(TorchModel): + """ + Implementaion of informer deep torch model. 
+ """ + + def __init__(self, config: InformerConfig): + super().__init__(config) + + if config.dim is not None: + config.encoder_input_size = config.dim if config.encoder_input_size is None else config.encoder_input_size + config.decoder_input_size = ( + config.encoder_input_size if config.decoder_input_size is None else config.decoder_input_size + ) + + config.c_out = config.encoder_input_size + + self.n_past = config.n_past + self.start_token_len = config.start_token_len + self.max_forecast_steps = config.max_forecast_steps + + self.enc_embedding = DataEmbedding( + config.encoder_input_size, config.model_dim, config.embed, config.ts_encoding, config.dropout + ) + + self.dec_embedding = DataEmbedding( + config.decoder_input_size, config.model_dim, config.embed, config.ts_encoding, config.dropout + ) + + # Encoder + self.encoder = Encoder( + [ + EncoderLayer( + AttentionLayer( + ProbAttention(False, config.factor, attention_dropout=config.dropout, output_attention=False), + config.model_dim, + config.n_heads, + ), + config.model_dim, + config.fcn_dim, + dropout=config.dropout, + activation=config.activation, + ) + for l in range(config.num_encoder_layers) + ], + [ConvLayer(config.model_dim) for l in range(config.num_encoder_layers - 1)] if config.distil else None, + norm_layer=torch.nn.LayerNorm(config.model_dim), + ) + + # Decoder + self.decoder = Decoder( + [ + DecoderLayer( + AttentionLayer( + ProbAttention(True, config.factor, attention_dropout=config.dropout, output_attention=False), + config.model_dim, + config.n_heads, + ), + AttentionLayer( + ProbAttention(False, config.factor, attention_dropout=config.dropout, output_attention=False), + config.model_dim, + config.n_heads, + ), + config.model_dim, + config.fcn_dim, + dropout=config.dropout, + activation=config.activation, + ) + for l in range(config.num_decoder_layers) + ], + norm_layer=torch.nn.LayerNorm(config.model_dim), + projection=nn.Linear(config.model_dim, config.c_out, bias=True), + ) + + self.config = config + + def forward( + self, + past, + past_timestamp, + future_timestamp, + enc_self_mask=None, + dec_self_mask=None, + dec_enc_mask=None, + **kwargs + ): + config = self.config + + start_token = past[:, past.shape[1] - self.start_token_len :] + dec_inp = torch.zeros( + past.shape[0], self.max_forecast_steps, config.decoder_input_size, dtype=torch.float, device=self.device + ) + dec_inp = torch.cat([start_token, dec_inp], dim=1) + + future_timestamp = torch.cat( + [past_timestamp[:, (past_timestamp.shape[1] - self.start_token_len) :], future_timestamp], dim=1 + ) + + enc_out = self.enc_embedding(past, past_timestamp) + enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) + + dec_out = self.dec_embedding(dec_inp, future_timestamp) + dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask) + + if self.config.target_seq_index is not None: + return dec_out[:, -self.max_forecast_steps :, :1] + else: + return dec_out[:, -self.max_forecast_steps :, :] + + +class InformerForecaster(DeepForecaster): + """ + Implementaion of Informer deep forecaster. + """ + + config_class = InformerConfig + deep_model_class = InformerModel + + def __init__(self, config: InformerConfig): + super().__init__(config) diff --git a/merlion/models/forecast/lstm.py b/merlion/models/forecast/lstm.py deleted file mode 100644 index d2e5887f9..000000000 --- a/merlion/models/forecast/lstm.py +++ /dev/null @@ -1,413 +0,0 @@ -# -# Copyright (c) 2022 salesforce.com, inc. -# All rights reserved. 
-# SPDX-License-Identifier: BSD-3-Clause -# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -# -""" -A forecaster based on a LSTM neural net. -""" -try: - import torch - import torch.nn as nn - import torch.nn.functional as F - from torch.utils.data import DataLoader, Dataset -except ImportError as e: - err = ( - "Try installing Merlion with optional dependencies using `pip install salesforce-merlion[deep-learning]` or " - "`pip install `salesforce-merlion[all]`" - ) - raise ImportError(str(e) + ". " + err) - -import bisect -import datetime -import logging -import os -from typing import List, Tuple, Union - -import numpy as np -import pandas as pd -from tqdm import tqdm - -from merlion.models.forecast.base import ForecasterConfig, ForecasterBase -from merlion.transform.normalize import MeanVarNormalize -from merlion.transform.moving_average import DifferenceTransform -from merlion.transform.resample import TemporalResample -from merlion.transform.sequence import TransformSequence -from merlion.utils.time_series import to_pd_datetime, to_timestamp, UnivariateTimeSeries - -logger = logging.getLogger(__name__) - - -class LSTMConfig(ForecasterConfig): - """ - Configuration class for `LSTM`. - """ - - _default_transform = TransformSequence( - [ - TemporalResample(granularity=None, trainable_granularity=True), - DifferenceTransform(), - MeanVarNormalize(normalize_bias=True, normalize_scale=True), - ] - ) - - def __init__(self, max_forecast_steps: int, nhid=1024, model_strides=(1,), **kwargs): - """ - :param nhid: hidden dimension of LSTM - :param model_strides: tuple indicating the stride(s) at which we would - like to subsample the input data before giving it to the model. - """ - self.model_strides = list(model_strides) - self.nhid = nhid - super().__init__(max_forecast_steps=max_forecast_steps, **kwargs) - - -class LSTMTrainConfig(object): - """ - LSTM training configuration. - """ - - def __init__( - self, - lr=1e-5, - batch_size=128, - epochs=128, - seq_len=256, - data_stride=1, - valid_split=0.2, - checkpoint_file="checkpoint.pt", - ): - assert 0 < valid_split < 1 - self.lr = lr - self.batch_size = batch_size # 8 - self.epochs = epochs - self.seq_len = seq_len - self.data_stride = data_stride - self.checkpoint_file = checkpoint_file - self.valid_split = valid_split - - -class Corpus(Dataset): - """ - Build a torch corpus from an input sequence - - :meta private: - """ - - def __init__(self, sequence, seq_len=32, stride=1): - """ - :param sequence: a list of items - :param seq_len: the sequence length used in the LSTM models - :param stride: stride if you want to subsample the sequence up front - """ - super().__init__() - self.seq_len = seq_len - self.stride = stride - self.sequence = sequence - if len(self) == 0: - raise RuntimeError( - f"Zero length dataset! This typically occurs when " - f"seq_len > len(sequence). Here seq_len={seq_len}, " - f"len(sequence)={len(sequence)}." - ) - logger.info(f"Dataset length: {len(self)}") - - def __len__(self): - n = len(self.sequence) - (self.seq_len - 1) * self.stride - return max(0, n) - - def __getitem__(self, idx): - max_idx = idx + (self.seq_len - 1) * self.stride + 1 - return torch.tensor(self.sequence[idx : max_idx : self.stride], dtype=torch.float) - - -class _LSTMBase(nn.Module): - """ - Two layer LSTM + a linear output layer. The model assumes equal time - intervals across the whole input sequence, so time stamps are ignored. 
- - :meta private: - """ - - def __init__(self, nhid=51): - """ - :param nhid: number of hidden neurons in each of the LSTM cells - """ - super().__init__() - self.nhid = nhid - self.add_module("lstm1", nn.LSTMCell(1, self.nhid)) - self.add_module("lstm2", nn.LSTMCell(self.nhid, self.nhid)) - self.add_module("linear", nn.Linear(self.nhid, 1)) - self.h_t, self.c_t, self.h_t2, self.c_t2 = None, None, None, None - - def forward(self, input): - outputs = [] - self.reset(bsz=input.size(0)) - - for i, input_t in enumerate(input.chunk(input.size(1), dim=1)): - self.h_t, self.c_t = self.lstm1(input_t, (self.h_t, self.c_t)) - self.h_t2, self.c_t2 = self.lstm2(self.h_t, (self.h_t2, self.c_t2)) - output = self.linear(self.h_t2) - outputs += [output] - - outputs = torch.stack(outputs, 1).squeeze(2) - return outputs - - def generate(self, input): - self.h_t, self.c_t = self.lstm1(input, (self.h_t, self.c_t)) - self.h_t2, self.c_t2 = self.lstm2(self.h_t, (self.h_t2, self.c_t2)) - output = self.linear(self.h_t2) - return torch.stack([output], 1).squeeze(2) - - def reset(self, bsz): - if torch.cuda.is_available(): - device = torch.device("cuda") - else: - device = torch.device("cpu") - self.h_t = torch.zeros(bsz, self.nhid, dtype=torch.float, device=device) - self.c_t = torch.zeros(bsz, self.nhid, dtype=torch.float, device=device) - self.h_t2 = torch.zeros(bsz, self.nhid, dtype=torch.float, device=device) - self.c_t2 = torch.zeros(bsz, self.nhid, dtype=torch.float, device=device) - - -class _LSTMMultiScale(nn.Module): - """ - Multi-Scale LSTM Modeling - the model decomposes the input sequence using different granularities specified in strides - for each granularity it models the sequence using a two-layer LSTM - then the output from all the models are summed up to produce the result - - :meta private: - """ - - def __init__(self, strides=(1, 16, 32), nhid=51): - """ - :param strides: an iterable of strides - :param nhid: number of hidden neurons used in the LSTM cell - """ - super().__init__() - self.strides = strides - self.nhid = nhid - self.rnns = nn.ModuleList([_LSTMBase(nhid=self.nhid) for _ in strides]) - - def forward(self, input, future=0): - """ - :param input: batch_size * sequence_length - :param future: number of future steps for forecasting - :return: the predicted values including both 1-step predictions and the future step predictions - """ - outputs = [rnn(input[:, ::stride]) for stride, rnn in zip(self.strides, self.rnns)] - batch_sz, dim = outputs[0].shape - preds = [ - output.view(batch_sz, -1, 1).repeat(1, 1, stride).view(batch_sz, -1)[:, :dim] - for output, stride in zip(outputs, self.strides) - ] - - outputs = torch.stack(preds, dim=2).sum(dim=2) - futures = [] - prev = outputs[:, -1].view(batch_sz, -1) - - preds = [x[:, -1].view(batch_sz, -1) for x in preds] - - for i in range(future): - for j, (stride, rnn) in enumerate(zip(self.strides, self.rnns)): - if (i + dim) % stride == 0: - preds[j] = rnn.generate(prev) - - prev = torch.stack(preds, dim=2).sum(dim=2) - futures.append(prev) - futures = torch.cat(futures, dim=1) - return torch.cat([outputs, futures], dim=1) - - -def auto_stride(time_stamps, resolution=48): - """ - automatically set the sequence stride - experiments show LSTM does not work when the input sequence has super long period - in this case we may need to subsample the sequence so that the period is not too long - this function returns a stride suitable for LSTM modeling given the model period is daily. 
- - :param time_stamps: a list of UTC timestamps (in seconds) - :param resolution: maximum number of points in each day. (default to 48 so that it is a 30 min prediction) - :return: the selected stride - - :meta private: - """ - day_delta = datetime.timedelta(days=1).total_seconds() - start_day = bisect.bisect_left(time_stamps, time_stamps[-1] - day_delta) - day_stamps = len(time_stamps) - start_day - stride = day_stamps // resolution - return stride - - -class LSTM(ForecasterBase): - """ - LSTM forecaster: this assume the input time series has equal intervals across all its values - so that we can use sequence modeling to make forecast. - """ - - config_class = LSTMConfig - - def __init__(self, config: LSTMConfig): - super().__init__(config) - self.model = _LSTMMultiScale(strides=config.model_strides, nhid=config.nhid) - if torch.cuda.is_available(): - self.model.cuda() - self.optimizer = None - self.seq_len = None - self._forecast_vals = [0.0 for _ in range(self.max_forecast_steps)] - - @property - def require_even_sampling(self) -> bool: - return True - - @property - def _default_train_config(self): - return LSTMTrainConfig() - - def _train(self, train_data: pd.DataFrame, train_config: LSTMTrainConfig = None): - train_data = train_data[self.target_name] - train_values = train_data.values - - valid_len = int(np.ceil(len(train_data) * train_config.valid_split)) - - stride = train_config.data_stride - self.seq_len = train_config.seq_len - - # Get initial time & update the timedelta based on the stride - i0 = (len(train_data) - 1) % stride - t0 = to_timestamp(train_data.index[i0]) - self.last_train_time = train_data.index[-1] - self.timedelta = (train_data.index[1] - train_data.index[0]) * stride - - ############# - train_scores = train_values[:-valid_len] - _train_data = Corpus(sequence=train_scores, seq_len=self.seq_len, stride=stride) - train_dataloader = DataLoader(_train_data, batch_size=train_config.batch_size, shuffle=True) - - ############### - valid_scores = train_values[-valid_len:] - _valid_data = Corpus(sequence=valid_scores, seq_len=self.seq_len, stride=stride) - valid_dataloader = DataLoader(_valid_data, batch_size=train_config.batch_size, shuffle=False) - - ################ - no_progress_count = 0 - loss_best = 1e20 - self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=train_config.lr, momentum=0.9) - - for epoch in range(1, train_config.epochs + 1): - self.model.train() - total_loss = 0 - with tqdm(total=len(train_dataloader)) as pbar: - for batch_idx, batch in enumerate(train_dataloader): - if torch.cuda.is_available(): - batch = batch.cuda() - self.optimizer.zero_grad() - out = self.model(batch[:, : -(self.max_forecast_steps + 1)], future=self.max_forecast_steps) - loss = F.l1_loss(out, batch[:, 1:]) - loss.backward() - - self.optimizer.step() - pbar.update(1) - total_loss += loss.item() - loss = total_loss / (batch_idx + 1) - pbar.set_description(f"Epoch {epoch}|mae={loss:.4f}") - - # Validate model (n-step prediction) after this epoch - loss, count = 0, 0 - self.model.eval() - with torch.no_grad(): - for batch in valid_dataloader: - if torch.cuda.is_available(): - batch = batch.cuda() - feat = batch[:, : -(self.max_forecast_steps + 1)] - target = batch[:, -self.max_forecast_steps :] - out = self.model(feat, future=self.max_forecast_steps) - out = out[:, -self.max_forecast_steps :] - loss += F.l1_loss(out, target, reduction="sum").item() - count += target.shape[0] * target.shape[1] - - loss_eval = loss / count - logger.info(f"val |mae={loss_eval:.4f}") - - if 
loss_eval < loss_best: - logger.info(f"saving model |epoch={epoch} |mae={loss_eval:.4f}") - dirname = os.path.dirname(train_config.checkpoint_file) - if len(dirname) > 0: - os.makedirs(dirname, exist_ok=True) - torch.save(self.model.state_dict(), train_config.checkpoint_file) - loss_best = loss_eval - else: - no_progress_count += 1 - - if no_progress_count > 64: - logger.info("Dividing learning rate by 10") - self.optimizer.param_groups[0]["lr"] /= 10.0 - no_progress_count = 0 - - state_dict = torch.load(train_config.checkpoint_file, map_location=lambda storage, loc: storage) - os.remove(train_config.checkpoint_file) - self.model: _LSTMMultiScale - self.model.load_state_dict(state_dict) - for rnn in self.model.rnns: - rnn.h_t_default = rnn.h_t - rnn.c_t_default = rnn.c_t - rnn.h_t2_default = rnn.h_t2 - rnn.c_t2.default = rnn.c_t2 - - if not isinstance(self.transform, TransformSequence): - self.transform = TransformSequence([self.transform]) - done = False - for f in self.transform.transforms: - if isinstance(f, TemporalResample): - f.granularity = self.timedelta - f.origin = t0 - f.trainable_granularity = False - done = True - if not done: - self.transform.append(TemporalResample(granularity=self.timedelta, origin=t0, trainable_granularity=False)) - - # FORECASTING: forecast for next n steps using lstm model. - # Since we've updated the transform's granularity, re-apply it on - # the original train data (self.train_data) before proceeding. - ts = self.transform(self.train_data).univariates[self.target_name] - vals = torch.tensor([ts.np_values], dtype=torch.float) - if torch.cuda.is_available(): - vals = vals.cuda() - - with torch.no_grad(): - n = self.max_forecast_steps - preds = self.model(vals[:, :-n], future=n).squeeze().tolist() - self._forecast_vals = self.model(vals, future=n).squeeze().tolist()[-n:] - - return pd.DataFrame(preds, index=ts.index, columns=[self.target_name]), None - - def _forecast( - self, time_stamps: List[int], time_series_prev: pd.DataFrame = None, return_prev=False - ) -> Tuple[pd.DataFrame, None]: - n = len(time_stamps) - - if time_series_prev is None: - yhat = self._forecast_vals[:n] - yhat = pd.DataFrame(yhat, index=to_pd_datetime(time_stamps), columns=[self.target_name]) - return yhat, None - - # TODO: should we truncate time_series_prev to just the last - # (self.seq_len - self.max_forecast_steps) time steps? 
- # This would better match the training distribution - time_series_prev = time_series_prev.iloc[:, self.target_seq_index] - vals = torch.tensor([time_series_prev.values], dtype=torch.float) - if torch.cuda.is_available(): - vals = vals.cuda() - self.model.cuda() - with torch.no_grad(): - yhat = self.model(vals, future=n).squeeze().tolist() - - if return_prev: - time_stamps = to_timestamp(time_series_prev.index) + time_stamps - else: - yhat = yhat[-n:] - - yhat = pd.DataFrame(yhat, index=to_pd_datetime(time_stamps), columns=[self.target_name]) - return yhat, None diff --git a/merlion/models/forecast/sarima.py b/merlion/models/forecast/sarima.py index 897167edc..eae63929b 100644 --- a/merlion/models/forecast/sarima.py +++ b/merlion/models/forecast/sarima.py @@ -14,7 +14,7 @@ import numpy as np import pandas as pd -from statsmodels.tsa.arima.model import ARIMA as sm_Sarima +import statsmodels.api as sm from merlion.models.automl.seasonality import SeasonalityModel from merlion.models.forecast.base import ForecasterExogBase, ForecasterExogConfig @@ -61,7 +61,11 @@ def __init__(self, config: SarimaConfig): @property def require_even_sampling(self) -> bool: - return True + return False + + @property + def _default_train_config(self): + return dict(enforce_stationarity=False, enforce_invertibility=False) @property def order(self) -> Tuple[int, int, int]: @@ -99,10 +103,10 @@ def _train_with_exog( train_config = train_config or {} with warnings.catch_warnings(): warnings.simplefilter("ignore") - model = sm_Sarima( + model = sm.tsa.SARIMAX( train_data, exog=exog_data, order=self.order, seasonal_order=self.seasonal_order, **train_config ) - self.model = model.fit(method_kwargs={"disp": 0}) + self.model = model.fit(disp=0) # FORECASTING: forecast for next n steps using Sarima model self._last_val = train_data[-1] diff --git a/merlion/models/forecast/transformer.py b/merlion/models/forecast/transformer.py new file mode 100644 index 000000000..6c9e2fb89 --- /dev/null +++ b/merlion/models/forecast/transformer.py @@ -0,0 +1,211 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +""" +Implementation of Transformer for time series data. +""" +import copy +import logging +import math + +import numpy as np +import pandas as pd +from scipy.stats import norm + +from typing import List, Optional, Tuple, Union +from abc import abstractmethod + +try: + import torch + import torch.nn as nn + import torch.nn.functional as F +except ImportError as e: + err = ( + "Try installing Merlion with optional dependencies using `pip install salesforce-merlion[deep-learning]` or " + "`pip install `salesforce-merlion[all]`" + ) + raise ImportError(str(e) + ". " + err) + + +from merlion.models.base import NormalizingConfig +from merlion.models.deep_base import TorchModel +from merlion.models.forecast.deep_base import DeepForecasterConfig, DeepForecaster + +from merlion.models.utils.nn_modules import FullAttention, AttentionLayer, DataEmbedding, ConvLayer + +from merlion.models.utils.nn_modules.enc_dec_transformer import Encoder, EncoderLayer, Decoder, DecoderLayer + +from merlion.utils.misc import initializer + +logger = logging.getLogger(__name__) + + +class TransformerConfig(DeepForecasterConfig, NormalizingConfig): + """ + Transformer for time series forecasting. + Code adapted from https://github.com/thuml/Autoformer. 
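+ This is a vanilla encoder-decoder Transformer with full attention, applied directly to
+ time series forecasting.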
+ """ + + @initializer + def __init__( + self, + n_past, + max_forecast_steps: int = None, + encoder_input_size: int = None, + decoder_input_size: int = None, + num_encoder_layers: int = 2, + num_decoder_layers: int = 1, + start_token_len: int = 0, + factor: int = 3, + model_dim: int = 512, + embed: str = "timeF", + dropout: float = 0.05, + activation: str = "gelu", + n_heads: int = 8, + fcn_dim: int = 2048, + distil: bool = True, + **kwargs + ): + """ + :param n_past: # of past steps used for forecasting future. + :param max_forecast_steps: Max # of steps we would like to forecast for. + :param encoder_input_size: Input size of encoder. If ``encoder_input_size = None``, + then the model will automatically use ``config.dim``, which is the dimension of the input data. + :param decoder_input_size: Input size of decoder. If ``decoder_input_size = None``, + then the model will automatically use ``config.dim``, which is the dimension of the input data. + :param num_encoder_layers: Number of encoder layers. + :param num_decoder_layers: Number of decoder layers. + :param start_token_len: Length of start token for deep transformer encoder-decoder based models. + The start token is similar to the special tokens for NLP models (e.g., bos, sep, eos tokens). + :param factor: Attention factor. + :param model_dim: Dimension of the model. + :param embed: Time feature encoding type, options include ``timeF``, ``fixed`` and ``learned``. + :param dropout: dropout rate. + :param activation: Activation function, can be ``gelu``, ``relu``, ``sigmoid``, etc. + :param n_heads: Number of heads of the model. + :param fcn_dim: Hidden dimension of the MLP layer in the model. + :param distil: whether to use distilling in the encoder of the model. + """ + + super().__init__(n_past=n_past, max_forecast_steps=max_forecast_steps, **kwargs) + + +class TransformerModel(TorchModel): + """ + Implementaion of Transformer deep torch model. 
+ """ + + def __init__(self, config: TransformerConfig): + super().__init__(config) + + if config.dim is not None: + config.encoder_input_size = config.dim if config.encoder_input_size is None else config.encoder_input_size + config.decoder_input_size = ( + config.encoder_input_size if config.decoder_input_size is None else config.decoder_input_size + ) + + config.c_out = config.encoder_input_size + + self.n_past = config.n_past + self.start_token_len = config.start_token_len + self.max_forecast_steps = config.max_forecast_steps + + self.enc_embedding = DataEmbedding( + config.encoder_input_size, config.model_dim, config.embed, config.ts_encoding, config.dropout + ) + + self.dec_embedding = DataEmbedding( + config.decoder_input_size, config.model_dim, config.embed, config.ts_encoding, config.dropout + ) + + # Encoder + self.encoder = Encoder( + [ + EncoderLayer( + AttentionLayer( + FullAttention(False, config.factor, attention_dropout=config.dropout, output_attention=False), + config.model_dim, + config.n_heads, + ), + config.model_dim, + config.fcn_dim, + dropout=config.dropout, + activation=config.activation, + ) + for l in range(config.num_encoder_layers) + ], + norm_layer=torch.nn.LayerNorm(config.model_dim), + ) + + # Decoder + self.decoder = Decoder( + [ + DecoderLayer( + AttentionLayer( + FullAttention(True, config.factor, attention_dropout=config.dropout, output_attention=False), + config.model_dim, + config.n_heads, + ), + AttentionLayer( + FullAttention(False, config.factor, attention_dropout=config.dropout, output_attention=False), + config.model_dim, + config.n_heads, + ), + config.model_dim, + config.fcn_dim, + dropout=config.dropout, + activation=config.activation, + ) + for l in range(config.num_decoder_layers) + ], + norm_layer=torch.nn.LayerNorm(config.model_dim), + projection=nn.Linear(config.model_dim, config.c_out, bias=True), + ) + + def forward( + self, + past, + past_timestamp, + future_timestamp, + enc_self_mask=None, + dec_self_mask=None, + dec_enc_mask=None, + **kwargs + ): + config = self.config + + start_token = past[:, (past.shape[1] - self.start_token_len) :] + dec_inp = torch.zeros( + past.shape[0], self.max_forecast_steps, config.decoder_input_size, dtype=torch.float, device=self.device + ) + dec_inp = torch.cat([start_token, dec_inp], dim=1) + + future_timestamp = torch.cat( + [past_timestamp[:, (past_timestamp.shape[1] - self.start_token_len) :], future_timestamp], dim=1 + ) + + enc_out = self.enc_embedding(past, past_timestamp) + enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) + + dec_out = self.dec_embedding(dec_inp, future_timestamp) + dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask) + + if self.config.target_seq_index is not None: + return dec_out[:, -self.max_forecast_steps :, :1] + else: + return dec_out[:, -self.max_forecast_steps :, :] + + +class TransformerForecaster(DeepForecaster): + """ + Implementaion of Transformer deep forecaster + """ + + config_class = TransformerConfig + deep_model_class = TransformerModel + + def __init__(self, config: TransformerConfig): + super().__init__(config) diff --git a/merlion/models/utils/autosarima_utils.py b/merlion/models/utils/autosarima_utils.py index a5b3ef06c..9dc537852 100644 --- a/merlion/models/utils/autosarima_utils.py +++ b/merlion/models/utils/autosarima_utils.py @@ -105,6 +105,8 @@ def _fit_sarima_model(y, order, seasonal_order, trend, method, maxiter, informat seasonal_order=seasonal_order, trend=trend, validate_specification=False, + 
enforce_stationarity=False, + enforce_invertibility=False, **kwargs, ) try: @@ -185,7 +187,7 @@ def detect_maxiter_sarima_model(y, d, D, m, method, information_criterion, exog= maxiter = 10 ic = np.inf model_spec = sm.tsa.SARIMAX( - endog=y, exog=exog, order=order, seasonal_order=seasonal_order, trend="c", validate_specification=False + endog=y, exog=exog, order=order, seasonal_order=seasonal_order, trend="c", validate_specification=False, enforce_stationarity=False, enforce_invertibility=False ) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/merlion/models/utils/early_stopping.py b/merlion/models/utils/early_stopping.py new file mode 100644 index 000000000..650514dd0 --- /dev/null +++ b/merlion/models/utils/early_stopping.py @@ -0,0 +1,70 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +""" +Earlying Stopping +""" +import logging + +try: + import torch + import torch.nn as nn + import torch.nn.functional as F +except ImportError as e: + err = ( + "Try installing Merlion with optional dependencies using `pip install salesforce-merlion[deep-learning]` or " + "`pip install `salesforce-merlion[all]`" + ) + raise ImportError(str(e) + ". " + err) + +import numpy as np + + +logger = logging.getLogger(__name__) + + +class EarlyStopping: + """ + Early stopping for deep model training + """ + + def __init__(self, patience=7, delta=0): + """ + :param patience: Number of epochs with no improvement after which training will be stopped. + :param delta: Minimum change in the monitored quantity to qualify as an improvement, + i.e. an absolute change of less than min_delta, will count as no improvement. + """ + + self.patience = patience + self.counter = 0 + self.best_score = None + self.early_stop = False + self.val_loss_min = np.Inf + self.delta = delta + self.best_model_state_dict = None + + def __call__(self, val_loss, model): + score = -val_loss + if self.best_score is None: + self.best_score = score + self.save_best_state_and_dict(val_loss, model) + elif score < self.best_score + self.delta: + self.counter += 1 + logger.info(f"EarlyStopping counter: {self.counter} out of {self.patience}") + if self.counter >= self.patience: + self.early_stop = True + else: + self.best_score = score + self.save_best_state_and_dict(val_loss, model) + self.counter = 0 + + def save_best_state_and_dict(self, val_loss, model): + self.best_model_state_dict = model.state_dict() + + self.val_loss_min = val_loss + + def load_best_model(self, model): + model.load_state_dict(self.best_model_state_dict) diff --git a/merlion/models/utils/nn_modules/__init__.py b/merlion/models/utils/nn_modules/__init__.py new file mode 100644 index 000000000..dbdd65f6b --- /dev/null +++ b/merlion/models/utils/nn_modules/__init__.py @@ -0,0 +1,17 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +from .blocks import ( + AutoCorrelation, + SeasonalLayernorm, + SeriesDecomposeBlock, + MovingAverageBlock, + FullAttention, + ProbAttention, +) +from .layers import AutoCorrelationLayer, ConvLayer, AttentionLayer + +from .embed import DataEmbedding, DataEmbeddingWoPos, ETSEmbedding diff --git a/merlion/models/utils/nn_modules/blocks.py b/merlion/models/utils/nn_modules/blocks.py new file mode 100644 index 000000000..5750d9110 --- /dev/null +++ b/merlion/models/utils/nn_modules/blocks.py @@ -0,0 +1,416 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +import os +import math +import numpy as np + +try: + import torch + import torch.nn as nn + import torch.fft as fft + import torch.nn.functional as F + from einops import rearrange, reduce, repeat +except ImportError as e: + err = ( + "Try installing Merlion with optional dependencies using `pip install salesforce-merlion[deep-learning]` or " + "`pip install `salesforce-merlion[all]`" + ) + raise ImportError(str(e) + ". " + err) + +from math import sqrt +from scipy.fftpack import next_fast_len + + +class AutoCorrelation(nn.Module): + """ + AutoCorrelation Mechanism with the following two phases: + (1) period-based dependencies discovery + (2) time delay aggregation + This block can replace the self-attention family mechanism seamlessly. + """ + + def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False): + super(AutoCorrelation, self).__init__() + self.factor = factor + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def time_delay_agg_training(self, values, corr): + """ + SpeedUp version of Autocorrelation (a batch-normalization style design) + This is for the training phase. + """ + head = values.shape[1] + channel = values.shape[2] + length = values.shape[3] + # find top k + top_k = int(self.factor * math.log(length)) + mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) + index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1] + weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1) + # update corr + tmp_corr = torch.softmax(weights, dim=-1) + # aggregation + tmp_values = values + delays_agg = torch.zeros_like(values).float() + for i in range(top_k): + pattern = torch.roll(tmp_values, -int(index[i]), -1) + delays_agg = delays_agg + pattern * ( + tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) + ) + return delays_agg + + def time_delay_agg_inference(self, values, corr): + """ + SpeedUp version of Autocorrelation (a batch-normalization style design) + This is for the inference phase. 
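+ Unlike the training version, the top-k delays are selected per sample instead of being averaged
+ over the batch, and the delayed values are gathered with ``torch.gather`` rather than ``torch.roll``.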
+ """ + batch = values.shape[0] + head = values.shape[1] + channel = values.shape[2] + length = values.shape[3] + # index init + init_index = ( + torch.arange(length) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0) + .repeat(batch, head, channel, 1) + .to(values.device) + ) + # find top k + top_k = int(self.factor * math.log(length)) + mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) + weights, delay = torch.topk(mean_value, top_k, dim=-1) + # update corr + tmp_corr = torch.softmax(weights, dim=-1) + # aggregation + tmp_values = values.repeat(1, 1, 1, 2) + delays_agg = torch.zeros_like(values).float() + for i in range(top_k): + tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) + pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) + delays_agg = delays_agg + pattern * ( + tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) + ) + return delays_agg + + def time_delay_agg_full(self, values, corr): + """ + Standard version of Autocorrelation + """ + batch = values.shape[0] + head = values.shape[1] + channel = values.shape[2] + length = values.shape[3] + # index init + init_index = ( + torch.arange(length) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0) + .repeat(batch, head, channel, 1) + .to(values.device) + ) + # find top k + top_k = int(self.factor * math.log(length)) + weights, delay = torch.topk(corr, top_k, dim=-1) + # update corr + tmp_corr = torch.softmax(weights, dim=-1) + # aggregation + tmp_values = values.repeat(1, 1, 1, 2) + delays_agg = torch.zeros_like(values).float() + for i in range(top_k): + tmp_delay = init_index + delay[..., i].unsqueeze(-1) + pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) + delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1)) + return delays_agg + + def forward(self, queries, keys, values, attn_mask): + B, L, H, E = queries.shape + _, S, _, D = values.shape + if L > S: + zeros = torch.zeros_like(queries[:, : (L - S), :]).float() + values = torch.cat([values, zeros], dim=1) + keys = torch.cat([keys, zeros], dim=1) + else: + values = values[:, :L, :, :] + keys = keys[:, :L, :, :] + + # period-based dependencies + q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1) + k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1) + res = q_fft * torch.conj(k_fft) + corr = torch.fft.irfft(res, dim=-1) + + # time delay agg + if self.training: + V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) + else: + V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) + + if self.output_attention: + return (V.contiguous(), corr.permute(0, 3, 1, 2)) + else: + return (V.contiguous(), None) + + +# Attention building blocks + + +class TriangularCausalMask: + def __init__(self, B, L, device="cpu"): + mask_shape = [B, 1, L, L] + with torch.no_grad(): + self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) + + @property + def mask(self): + return self._mask + + +class ProbMask: + def __init__(self, B, H, L, index, scores, device="cpu"): + _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1) + _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1]) + indicator = _mask_ex[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :].to(device) + self._mask = indicator.view(scores.shape).to(device) + + @property + def mask(self): + return 
self._mask + + +class FullAttention(nn.Module): + def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): + super(FullAttention, self).__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def forward(self, queries, keys, values, attn_mask): + B, L, H, E = queries.shape + _, S, _, D = values.shape + scale = self.scale or 1.0 / sqrt(E) + + scores = torch.einsum("blhe,bshe->bhls", queries, keys) + + if self.mask_flag: + if attn_mask is None: + attn_mask = TriangularCausalMask(B, L, device=queries.device) + + scores.masked_fill_(attn_mask.mask, -np.inf) + + A = self.dropout(torch.softmax(scale * scores, dim=-1)) + V = torch.einsum("bhls,bshd->blhd", A, values) + + if self.output_attention: + return (V.contiguous(), A) + else: + return (V.contiguous(), None) + + +class ProbAttention(nn.Module): + def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): + super(ProbAttention, self).__init__() + self.factor = factor + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) + # Q [B, H, L, D] + B, H, L_K, E = K.shape + _, _, L_Q, _ = Q.shape + + # calculate the sampled Q_K + K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) + index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q + K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :] + Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze() + + # find the Top_k query with sparisty measurement + M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) + M_top = M.topk(n_top, sorted=False)[1] + + # use the reduced Q to calculate Q_K + Q_reduce = Q[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :] # factor*ln(L_q) + Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k + + return Q_K, M_top + + def _get_initial_context(self, V, L_Q): + B, H, L_V, D = V.shape + if not self.mask_flag: + # V_sum = V.sum(dim=-2) + V_sum = V.mean(dim=-2) + contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone() + else: # use mask + assert L_Q == L_V # requires that L_Q == L_V, i.e. 
for self-attention only + contex = V.cumsum(dim=-2) + return contex + + def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): + B, H, L_V, D = V.shape + + if self.mask_flag: + attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device) + scores.masked_fill_(attn_mask.mask, -np.inf) + + attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) + + context_in[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = torch.matmul( + attn, V + ).type_as(context_in) + if self.output_attention: + attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device) + attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn + return (context_in, attns) + else: + return (context_in, None) + + def forward(self, queries, keys, values, attn_mask): + B, L_Q, H, D = queries.shape + _, L_K, _, _ = keys.shape + + queries = queries.transpose(2, 1) + keys = keys.transpose(2, 1) + values = values.transpose(2, 1) + + U_part = self.factor * np.ceil(np.log(L_K)).astype("int").item() # c*ln(L_k) + u = self.factor * np.ceil(np.log(L_Q)).astype("int").item() # c*ln(L_q) + + U_part = U_part if U_part < L_K else L_K + u = u if u < L_Q else L_Q + + scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u) + + # add scale factor + scale = self.scale or 1.0 / sqrt(D) + if scale is not None: + scores_top = scores_top * scale + # get the context + context = self._get_initial_context(values, L_Q) + # update the context with selected top_k queries + context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask) + + return context.contiguous(), attn + + +class SeasonalLayernorm(nn.Module): + """ + Special designed layernorm for the seasonal part + Build for Autoformer + """ + + def __init__(self, channels): + super(SeasonalLayernorm, self).__init__() + self.layernorm = nn.LayerNorm(channels) + + def forward(self, x): + x_hat = self.layernorm(x) + bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1) + return x_hat - bias + + +class MovingAverageBlock(nn.Module): + """ + Moving average block to highlight the trend of time series + Build for Autoformer + """ + + def __init__(self, kernel_size, stride): + super(MovingAverageBlock, self).__init__() + self.kernel_size = kernel_size + self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) + + def forward(self, x): + # padding on the both ends of time series + front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) + end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) + x = torch.cat([front, x, end], dim=1) + x = self.avg(x.permute(0, 2, 1)) + x = x.permute(0, 2, 1) + return x + + +class SeriesDecomposeBlock(nn.Module): + """ + Series decomposition block + Build for Autoformer + """ + + def __init__(self, kernel_size): + super(SeriesDecomposeBlock, self).__init__() + self.moving_avg = MovingAverageBlock(kernel_size, stride=1) + + def forward(self, x): + moving_mean = self.moving_avg(x) + res = x - moving_mean + return res, moving_mean + + +def conv1d_fft(f, g, dim=-1): + N = f.size(dim) + M = g.size(dim) + + fast_len = next_fast_len(N + M - 1) + + F_f = fft.rfft(f, fast_len, dim=dim) + F_g = fft.rfft(g, fast_len, dim=dim) + + F_fg = F_f * F_g.conj() + out = fft.irfft(F_fg, fast_len, dim=dim) + out = out.roll((-1,), dims=(dim,)) + idx = torch.as_tensor(range(fast_len - N, fast_len)).to(out.device) + out = out.index_select(dim, idx) + + return out + + +class ExponentialSmoothing(nn.Module): + def 
__init__(self, dim, nhead, dropout=0.1, aux=False): + super().__init__() + self._smoothing_weight = nn.Parameter(torch.randn(nhead, 1)) + self.v0 = nn.Parameter(torch.randn(1, 1, nhead, dim)) + self.dropout = nn.Dropout(dropout) + if aux: + self.aux_dropout = nn.Dropout(dropout) + + def forward(self, values, aux_values=None): + b, t, h, d = values.shape + + init_weight, weight = self.get_exponential_weight(t) + output = conv1d_fft(self.dropout(values), weight, dim=1) + output = init_weight * self.v0 + output + + if aux_values is not None: + aux_weight = weight / (1 - self.weight) * self.weight + aux_output = conv1d_fft(self.aux_dropout(aux_values), aux_weight) + output = output + aux_output + + return output + + def get_exponential_weight(self, T): + # Generate array [0, 1, ..., T-1] + powers = torch.arange(T, dtype=torch.float, device=self.weight.device) + + # (1 - \alpha) * \alpha^t, for all t = T-1, T-2, ..., 0] + weight = (1 - self.weight) * (self.weight ** torch.flip(powers, dims=(0,))) + + # \alpha^t for all t = 1, 2, ..., T + init_weight = self.weight ** (powers + 1) + + return rearrange(init_weight, "h t -> 1 t h 1"), rearrange(weight, "h t -> 1 t h 1") + + @property + def weight(self): + return torch.sigmoid(self._smoothing_weight) diff --git a/merlion/models/utils/nn_modules/embed.py b/merlion/models/utils/nn_modules/embed.py new file mode 100644 index 000000000..80386a789 --- /dev/null +++ b/merlion/models/utils/nn_modules/embed.py @@ -0,0 +1,164 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +try: + import torch + import torch.nn as nn + import torch.nn.functional as F +except ImportError as e: + err = ( + "Try installing Merlion with optional dependencies using `pip install salesforce-merlion[deep-learning]` or " + "`pip install `salesforce-merlion[all]`" + ) + raise ImportError(str(e) + ". " + err) + +import math + + +class PositionalEmbedding(nn.Module): + def __init__(self, d_model, max_len=5000): + super(PositionalEmbedding, self).__init__() + # Compute the positional encodings once in log space. 
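+ # PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
+ # PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))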
+ pe = torch.zeros(max_len, d_model).float() + pe.require_grad = False + + position = torch.arange(0, max_len).float().unsqueeze(1) + div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, x): + return self.pe[:, : x.size(1)] + + +class TokenEmbedding(nn.Module): + def __init__(self, c_in, d_model): + super(TokenEmbedding, self).__init__() + padding = 1 if torch.__version__ >= "1.5.0" else 2 + self.tokenConv = nn.Conv1d( + in_channels=c_in, out_channels=d_model, kernel_size=3, padding=padding, padding_mode="circular", bias=False + ) + for m in self.modules(): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="leaky_relu") + + def forward(self, x): + x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2) + return x + + +class FixedEmbedding(nn.Module): + def __init__(self, c_in, d_model): + super(FixedEmbedding, self).__init__() + + w = torch.zeros(c_in, d_model).float() + w.require_grad = False + + position = torch.arange(0, c_in).float().unsqueeze(1) + div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() + + w[:, 0::2] = torch.sin(position * div_term) + w[:, 1::2] = torch.cos(position * div_term) + + self.emb = nn.Embedding(c_in, d_model) + self.emb.weight = nn.Parameter(w, requires_grad=False) + + def forward(self, x): + return self.emb(x).detach() + + +class TemporalEmbedding(nn.Module): + def __init__(self, d_model, embed_type="fixed", freq="h"): + super(TemporalEmbedding, self).__init__() + + minute_size = 4 + hour_size = 24 + weekday_size = 7 + day_size = 32 + month_size = 13 + + Embed = FixedEmbedding if embed_type == "fixed" else nn.Embedding + if freq == "t": + self.minute_embed = Embed(minute_size, d_model) + self.hour_embed = Embed(hour_size, d_model) + self.weekday_embed = Embed(weekday_size, d_model) + self.day_embed = Embed(day_size, d_model) + self.month_embed = Embed(month_size, d_model) + + def forward(self, x): + x = x.long() + + minute_x = self.minute_embed(x[:, :, 4]) if hasattr(self, "minute_embed") else 0.0 + hour_x = self.hour_embed(x[:, :, 3]) + weekday_x = self.weekday_embed(x[:, :, 2]) + day_x = self.day_embed(x[:, :, 1]) + month_x = self.month_embed(x[:, :, 0]) + + return hour_x + weekday_x + day_x + month_x + minute_x + + +class TimeFeatureEmbedding(nn.Module): + def __init__(self, d_model, embed_type="timeF", freq="h"): + super(TimeFeatureEmbedding, self).__init__() + + freq_map = {"h": 4, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3} + d_inp = freq_map[freq] + self.embed = nn.Linear(d_inp, d_model, bias=False) + + def forward(self, x): + return self.embed(x) + + +class DataEmbedding(nn.Module): + def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1): + super(DataEmbedding, self).__init__() + + self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) + self.position_embedding = PositionalEmbedding(d_model=d_model) + self.temporal_embedding = ( + TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) + if embed_type != "timeF" + else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) + ) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, x_mark): + x = self.value_embedding(x) + self.temporal_embedding(x_mark) + self.position_embedding(x) + return self.dropout(x) + + +class 
DataEmbeddingWoPos(nn.Module): + def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1): + super(DataEmbeddingWoPos, self).__init__() + + self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) + self.position_embedding = PositionalEmbedding(d_model=d_model) + self.temporal_embedding = ( + TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) + if embed_type != "timeF" + else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) + ) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, x_mark): + x = self.value_embedding(x) + self.temporal_embedding(x_mark) + return self.dropout(x) + + +class ETSEmbedding(nn.Module): + def __init__(self, c_in, d_model, dropout=0.1): + super().__init__() + self.conv = nn.Conv1d(in_channels=c_in, out_channels=d_model, kernel_size=3, padding=2, bias=False) + self.dropout = nn.Dropout(p=dropout) + nn.init.kaiming_normal_(self.conv.weight) + + def forward(self, x): + x = self.conv(x.permute(0, 2, 1))[..., :-2] + return self.dropout(x.transpose(1, 2)) diff --git a/merlion/models/utils/nn_modules/enc_dec_autoformer.py b/merlion/models/utils/nn_modules/enc_dec_autoformer.py new file mode 100644 index 000000000..048028aa5 --- /dev/null +++ b/merlion/models/utils/nn_modules/enc_dec_autoformer.py @@ -0,0 +1,144 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +try: + import torch + import torch.nn as nn + import torch.nn.functional as F +except ImportError as e: + err = ( + "Try installing Merlion with optional dependencies using `pip install salesforce-merlion[deep-learning]` or " + "`pip install `salesforce-merlion[all]`" + ) + raise ImportError(str(e) + ". 
" + err) + +from merlion.models.utils.nn_modules.blocks import SeriesDecomposeBlock, MovingAverageBlock + + +class EncoderLayer(nn.Module): + """ + Autoformer encoder layer with the progressive decomposition architecture + """ + + def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"): + super(EncoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.attention = attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) + self.decomp1 = SeriesDecomposeBlock(moving_avg) + self.decomp2 = SeriesDecomposeBlock(moving_avg) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, attn_mask=None): + new_x, attn = self.attention(x, x, x, attn_mask=attn_mask) + x = x + self.dropout(new_x) + x, _ = self.decomp1(x) + y = x + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + res, _ = self.decomp2(x + y) + return res, attn + + +class Encoder(nn.Module): + """ + Autoformer encoder + """ + + def __init__(self, attn_layers, conv_layers=None, norm_layer=None): + super(Encoder, self).__init__() + self.attn_layers = nn.ModuleList(attn_layers) + self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None + self.norm = norm_layer + + def forward(self, x, attn_mask=None): + attns = [] + if self.conv_layers is not None: + for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): + x, attn = attn_layer(x, attn_mask=attn_mask) + x = conv_layer(x) + attns.append(attn) + x, attn = self.attn_layers[-1](x) + attns.append(attn) + else: + for attn_layer in self.attn_layers: + x, attn = attn_layer(x, attn_mask=attn_mask) + attns.append(attn) + + if self.norm is not None: + x = self.norm(x) + + return x, attns + + +class DecoderLayer(nn.Module): + """ + Autoformer decoder layer with the progressive decomposition architecture + """ + + def __init__( + self, self_attention, cross_attention, d_model, c_out, d_ff=None, moving_avg=25, dropout=0.1, activation="relu" + ): + super(DecoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.self_attention = self_attention + self.cross_attention = cross_attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) + self.decomp1 = SeriesDecomposeBlock(moving_avg) + self.decomp2 = SeriesDecomposeBlock(moving_avg) + self.decomp3 = SeriesDecomposeBlock(moving_avg) + self.dropout = nn.Dropout(dropout) + self.projection = nn.Conv1d( + in_channels=d_model, + out_channels=c_out, + kernel_size=3, + stride=1, + padding=1, + padding_mode="circular", + bias=False, + ) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, cross, x_mask=None, cross_mask=None): + x = x + self.dropout(self.self_attention(x, x, x, attn_mask=x_mask)[0]) + x, trend1 = self.decomp1(x) + x = x + self.dropout(self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]) + x, trend2 = self.decomp2(x) + y = x + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + x, trend3 = self.decomp3(x + y) + + residual_trend = trend1 + trend2 + trend3 + residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2) + return 
x, residual_trend + + +class Decoder(nn.Module): + """ + Autoformer decoder + """ + + def __init__(self, layers, norm_layer=None, projection=None): + super(Decoder, self).__init__() + self.layers = nn.ModuleList(layers) + self.norm = norm_layer + self.projection = projection + + def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None): + for layer in self.layers: + x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) + trend = trend + residual_trend + + if self.norm is not None: + x = self.norm(x) + + if self.projection is not None: + x = self.projection(x) + return x, trend diff --git a/merlion/models/utils/nn_modules/enc_dec_etsformer.py b/merlion/models/utils/nn_modules/enc_dec_etsformer.py new file mode 100644 index 000000000..b52ede317 --- /dev/null +++ b/merlion/models/utils/nn_modules/enc_dec_etsformer.py @@ -0,0 +1,150 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +import os +import math +import random +import numpy as np + +try: + import torch + import torch.nn as nn + import torch.fft as fft + import torch.nn.functional as F + from einops import rearrange, reduce, repeat +except ImportError as e: + err = ( + "Try installing Merlion with optional dependencies using `pip install salesforce-merlion[deep-learning]` or " + "`pip install `salesforce-merlion[all]`" + ) + raise ImportError(str(e) + ". " + err) + +from merlion.models.utils.nn_modules.layers import GrowthLayer, FourierLayer, LevelLayer, DampingLayer, MLPLayer + + +class EncoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + c_out, + seq_len, + pred_len, + k, + dim_feedforward=None, + dropout=0.1, + activation="sigmoid", + layer_norm_eps=1e-5, + output_attention=False, + ): + super().__init__() + self.d_model = d_model + self.nhead = nhead + self.c_out = c_out + self.seq_len = seq_len + self.pred_len = pred_len + dim_feedforward = dim_feedforward or 4 * d_model + self.dim_feedforward = dim_feedforward + + self.growth_layer = GrowthLayer(d_model, nhead, dropout=dropout, output_attention=output_attention) + self.seasonal_layer = FourierLayer(d_model, pred_len, k=k, output_attention=output_attention) + self.level_layer = LevelLayer(d_model, c_out, dropout=dropout) + + # Implementation of Feedforward model + self.ff = MLPLayer(d_model, dim_feedforward, dropout=dropout, activation=activation) + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + def forward(self, res, level, attn_mask=None): + season, season_attn = self._season_block(res) + res = res - season[:, : -self.pred_len] + growth, growth_attn = self._growth_block(res) + res = self.norm1(res - growth[:, 1:]) + res = self.norm2(res + self.ff(res)) + + level = self.level_layer(level, growth[:, :-1], season[:, : -self.pred_len]) + + return res, level, growth, season, season_attn, growth_attn + + def _growth_block(self, x): + x, growth_attn = self.growth_layer(x) + return self.dropout1(x), growth_attn + + def _season_block(self, x): + x, season_attn = self.seasonal_layer(x) + return self.dropout2(x), season_attn + + +class Encoder(nn.Module): + def __init__(self, layers): + super().__init__() + self.layers = nn.ModuleList(layers) + + def forward(self, res, level, attn_mask=None): + growths = [] + seasons = 
[] + season_attns = [] + growth_attns = [] + for layer in self.layers: + res, level, growth, season, season_attn, growth_attn = layer(res, level, attn_mask=None) + growths.append(growth) + seasons.append(season) + season_attns.append(season_attn) + growth_attns.append(growth_attn) + + return level, growths, seasons, season_attns, growth_attns + + +class DecoderLayer(nn.Module): + def __init__(self, d_model, nhead, c_out, pred_len, dropout=0.1, output_attention=False): + super().__init__() + self.d_model = d_model + self.nhead = nhead + self.c_out = c_out + self.pred_len = pred_len + self.output_attention = output_attention + + self.growth_damping = DampingLayer(pred_len, nhead, dropout=dropout, output_attention=output_attention) + self.dropout1 = nn.Dropout(dropout) + + def forward(self, growth, season): + growth_horizon, growth_damping = self.growth_damping(growth[:, -1:]) + growth_horizon = self.dropout1(growth_horizon) + + seasonal_horizon = season[:, -self.pred_len :] + + if self.output_attention: + return growth_horizon, seasonal_horizon, growth_damping + return growth_horizon, seasonal_horizon, None + + +class Decoder(nn.Module): + def __init__(self, layers): + super().__init__() + self.d_model = layers[0].d_model + self.c_out = layers[0].c_out + self.pred_len = layers[0].pred_len + self.nhead = layers[0].nhead + + self.layers = nn.ModuleList(layers) + self.pred = nn.Linear(self.d_model, self.c_out) + + def forward(self, growths, seasons): + growth_repr = [] + season_repr = [] + growth_dampings = [] + + for idx, layer in enumerate(self.layers): + growth_horizon, season_horizon, growth_damping = layer(growths[idx], seasons[idx]) + growth_repr.append(growth_horizon) + season_repr.append(season_horizon) + growth_dampings.append(growth_damping) + growth_repr = sum(growth_repr) + season_repr = sum(season_repr) + return self.pred(growth_repr), self.pred(season_repr), growth_dampings diff --git a/merlion/models/utils/nn_modules/enc_dec_transformer.py b/merlion/models/utils/nn_modules/enc_dec_transformer.py new file mode 100644 index 000000000..758de5ce5 --- /dev/null +++ b/merlion/models/utils/nn_modules/enc_dec_transformer.py @@ -0,0 +1,115 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +try: + import torch + import torch.nn as nn + import torch.nn.functional as F +except ImportError as e: + err = ( + "Try installing Merlion with optional dependencies using `pip install salesforce-merlion[deep-learning]` or " + "`pip install `salesforce-merlion[all]`" + ) + raise ImportError(str(e) + ". 
" + err) + +import math + + +class EncoderLayer(nn.Module): + def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): + super(EncoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.attention = attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, attn_mask=None): + new_x, attn = self.attention(x, x, x, attn_mask=attn_mask) + x = x + self.dropout(new_x) + + y = x = self.norm1(x) + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm2(x + y), attn + + +class Encoder(nn.Module): + def __init__(self, attn_layers, conv_layers=None, norm_layer=None): + super(Encoder, self).__init__() + self.attn_layers = nn.ModuleList(attn_layers) + self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None + self.norm = norm_layer + + def forward(self, x, attn_mask=None): + # x [B, L, D] + attns = [] + if self.conv_layers is not None: + for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): + x, attn = attn_layer(x, attn_mask=attn_mask) + x = conv_layer(x) + attns.append(attn) + x, attn = self.attn_layers[-1](x) + attns.append(attn) + else: + for attn_layer in self.attn_layers: + x, attn = attn_layer(x, attn_mask=attn_mask) + attns.append(attn) + + if self.norm is not None: + x = self.norm(x) + + return x, attns + + +class DecoderLayer(nn.Module): + def __init__(self, self_attention, cross_attention, d_model, d_ff=None, dropout=0.1, activation="relu"): + super(DecoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.self_attention = self_attention + self.cross_attention = cross_attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, cross, x_mask=None, cross_mask=None): + x = x + self.dropout(self.self_attention(x, x, x, attn_mask=x_mask)[0]) + x = self.norm1(x) + + x = x + self.dropout(self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]) + + y = x = self.norm2(x) + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm3(x + y) + + +class Decoder(nn.Module): + def __init__(self, layers, norm_layer=None, projection=None): + super(Decoder, self).__init__() + self.layers = nn.ModuleList(layers) + self.norm = norm_layer + self.projection = projection + + def forward(self, x, cross, x_mask=None, cross_mask=None): + for layer in self.layers: + x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) + + if self.norm is not None: + x = self.norm(x) + + if self.projection is not None: + x = self.projection(x) + return x diff --git a/merlion/models/utils/nn_modules/layers.py b/merlion/models/utils/nn_modules/layers.py new file mode 100644 index 000000000..044013875 --- /dev/null +++ b/merlion/models/utils/nn_modules/layers.py @@ -0,0 +1,292 @@ +# +# Copyright (c) 2022 salesforce.com, inc. 
+# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +try: + import torch + import torch.nn as nn + import torch.fft as fft + import torch.nn.functional as F + from einops import rearrange, reduce, repeat +except ImportError as e: + err = ( + "Try installing Merlion with optional dependencies using `pip install salesforce-merlion[deep-learning]` or " + "`pip install `salesforce-merlion[all]`" + ) + raise ImportError(str(e) + ". " + err) + +import numpy as np +import math +from math import sqrt +import os + +from scipy.fftpack import next_fast_len +from merlion.models.utils.nn_modules.blocks import ExponentialSmoothing, conv1d_fft + + +class ConvLayer(nn.Module): + def __init__(self, c_in): + super(ConvLayer, self).__init__() + self.downConv = nn.Conv1d( + in_channels=c_in, out_channels=c_in, kernel_size=3, padding=2, padding_mode="circular" + ) + self.norm = nn.BatchNorm1d(c_in) + self.activation = nn.ELU() + self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.downConv(x.permute(0, 2, 1)) + x = self.norm(x) + x = self.activation(x) + x = self.maxPool(x) + x = x.transpose(1, 2) + return x + + +class MLPLayer(nn.Module): + def __init__(self, d_model, dim_feedforward, dropout=0.1, activation="sigmoid"): + # Implementation of Feedforward model + super().__init__() + self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False) + self.dropout1 = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False) + self.dropout2 = nn.Dropout(dropout) + self.activation = getattr(torch, activation) + + def forward(self, x): + x = self.linear2(self.dropout1(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class AutoCorrelationLayer(nn.Module): + def __init__(self, correlation, d_model, n_heads, d_keys=None, d_values=None): + super(AutoCorrelationLayer, self).__init__() + + d_keys = d_keys or (d_model // n_heads) + d_values = d_values or (d_model // n_heads) + + self.inner_correlation = correlation + self.query_projection = nn.Linear(d_model, d_keys * n_heads) + self.key_projection = nn.Linear(d_model, d_keys * n_heads) + self.value_projection = nn.Linear(d_model, d_values * n_heads) + self.out_projection = nn.Linear(d_values * n_heads, d_model) + self.n_heads = n_heads + + def forward(self, queries, keys, values, attn_mask): + B, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + queries = self.query_projection(queries).view(B, L, H, -1) + keys = self.key_projection(keys).view(B, S, H, -1) + values = self.value_projection(values).view(B, S, H, -1) + + out, attn = self.inner_correlation(queries, keys, values, attn_mask) + out = out.view(B, L, -1) + + return self.out_projection(out), attn + + +# Attention Layers +class AttentionLayer(nn.Module): + def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None): + super(AttentionLayer, self).__init__() + + d_keys = d_keys or (d_model // n_heads) + d_values = d_values or (d_model // n_heads) + + self.inner_attention = attention + self.query_projection = nn.Linear(d_model, d_keys * n_heads) + self.key_projection = nn.Linear(d_model, d_keys * n_heads) + self.value_projection = nn.Linear(d_model, d_values * n_heads) + self.out_projection = nn.Linear(d_values * n_heads, d_model) + self.n_heads = n_heads + + def forward(self, queries, keys, values, attn_mask): + B, L, _ = queries.shape + _, S, _ = keys.shape + H = 
self.n_heads + + queries = self.query_projection(queries).view(B, L, H, -1) + keys = self.key_projection(keys).view(B, S, H, -1) + values = self.value_projection(values).view(B, S, H, -1) + + out, attn = self.inner_attention(queries, keys, values, attn_mask) + out = out.view(B, L, -1) + + return self.out_projection(out), attn + + +# layers from ETS +class GrowthLayer(nn.Module): + def __init__(self, d_model, nhead, d_head=None, dropout=0.1, output_attention=False): + super().__init__() + self.d_head = d_head or (d_model // nhead) + self.d_model = d_model + self.nhead = nhead + self.output_attention = output_attention + + self.z0 = nn.Parameter(torch.randn(self.nhead, self.d_head)) + self.in_proj = nn.Linear(self.d_model, self.d_head * self.nhead) + self.es = ExponentialSmoothing(self.d_head, self.nhead, dropout=dropout) + self.out_proj = nn.Linear(self.d_head * self.nhead, self.d_model) + + assert self.d_head * self.nhead == self.d_model, "d_model must be divisible by nhead" + + def forward(self, inputs): + """ + :param inputs: shape: (batch, seq_len, dim) + :return: shape: (batch, seq_len, dim) + """ + b, t, d = inputs.shape + values = self.in_proj(inputs).view(b, t, self.nhead, -1) + values = torch.cat([repeat(self.z0, "h d -> b 1 h d", b=b), values], dim=1) + values = values[:, 1:] - values[:, :-1] + out = self.es(values) + out = torch.cat([repeat(self.es.v0, "1 1 h d -> b 1 h d", b=b), out], dim=1) + out = rearrange(out, "b t h d -> b t (h d)") + out = self.out_proj(out) + + if self.output_attention: + return out, self.es.get_exponential_weight(t)[1] + return out, None + + +class FourierLayer(nn.Module): + def __init__(self, d_model, pred_len, k=None, low_freq=1, output_attention=False): + super().__init__() + self.d_model = d_model + self.pred_len = pred_len + self.k = k + self.low_freq = low_freq + self.output_attention = output_attention + + def forward(self, x): + """x: (b, t, d)""" + + if self.output_attention: + return self.dft_forward(x) + + b, t, d = x.shape + x_freq = fft.rfft(x, dim=1) + + if t % 2 == 0: + x_freq = x_freq[:, self.low_freq : -1] + f = fft.rfftfreq(t)[self.low_freq : -1] + else: + x_freq = x_freq[:, self.low_freq :] + f = fft.rfftfreq(t)[self.low_freq :] + + x_freq, index_tuple = self.topk_freq(x_freq) + f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2)) + f = rearrange(f[index_tuple], "b f d -> b f () d").to(x_freq.device) + + return self.extrapolate(x_freq, f, t), None + + def extrapolate(self, x_freq, f, t): + x_freq = torch.cat([x_freq, x_freq.conj()], dim=1) + f = torch.cat([f, -f], dim=1) + t_val = rearrange(torch.arange(t + self.pred_len, dtype=torch.float), "t -> () () t ()").to(x_freq.device) + + amp = rearrange(x_freq.abs() / t, "b f d -> b f () d") + phase = rearrange(x_freq.angle(), "b f d -> b f () d") + + x_time = amp * torch.cos(2 * math.pi * f * t_val + phase) + + return reduce(x_time, "b f t d -> b t d", "sum") + + def topk_freq(self, x_freq): + values, indices = torch.topk(x_freq.abs(), self.k, dim=1, largest=True, sorted=True) + mesh_a, mesh_b = torch.meshgrid(torch.arange(x_freq.size(0)), torch.arange(x_freq.size(2))) + index_tuple = (mesh_a.unsqueeze(1), indices, mesh_b.unsqueeze(1)) + x_freq = x_freq[index_tuple] + + return x_freq, index_tuple + + def dft_forward(self, x): + T = x.size(1) + + dft_mat = fft.fft(torch.eye(T)) + i, j = torch.meshgrid(torch.arange(self.pred_len + T), torch.arange(T)) + omega = np.exp(2 * math.pi * 1j / T) + idft_mat = (np.power(omega, i * j) / T).cfloat() + + x_freq = torch.einsum("ft,btd->bfd", 
[dft_mat, x.cfloat()]) + + if T % 2 == 0: + x_freq = x_freq[:, self.low_freq : T // 2] + else: + x_freq = x_freq[:, self.low_freq : T // 2 + 1] + + _, indices = torch.topk(x_freq.abs(), self.k, dim=1, largest=True, sorted=True) + indices = indices + self.low_freq + indices = torch.cat([indices, -indices], dim=1) + + dft_mat = repeat(dft_mat, "f t -> b f t d", b=x.shape[0], d=x.shape[-1]) + idft_mat = repeat(idft_mat, "t f -> b t f d", b=x.shape[0], d=x.shape[-1]) + + mesh_a, mesh_b = torch.meshgrid(torch.arange(x.size(0)), torch.arange(x.size(2))) + + dft_mask = torch.zeros_like(dft_mat) + dft_mask[mesh_a, indices, :, mesh_b] = 1 + dft_mat = dft_mat * dft_mask + + idft_mask = torch.zeros_like(idft_mat) + idft_mask[mesh_a, :, indices, mesh_b] = 1 + idft_mat = idft_mat * idft_mask + + attn = torch.einsum("bofd,bftd->botd", [idft_mat, dft_mat]).real + return torch.einsum("botd,btd->bod", [attn, x]), rearrange(attn, "b o t d -> b d o t") + + +class LevelLayer(nn.Module): + def __init__(self, d_model, c_out, dropout=0.1): + super().__init__() + self.d_model = d_model + self.c_out = c_out + + self.es = ExponentialSmoothing(1, self.c_out, dropout=dropout, aux=True) + self.growth_pred = nn.Linear(self.d_model, self.c_out) + self.season_pred = nn.Linear(self.d_model, self.c_out) + + def forward(self, level, growth, season): + b, t, _ = level.shape + growth = self.growth_pred(growth).view(b, t, self.c_out, 1) + season = self.season_pred(season).view(b, t, self.c_out, 1) + growth = growth.view(b, t, self.c_out, 1) + season = season.view(b, t, self.c_out, 1) + level = level.view(b, t, self.c_out, 1) + out = self.es(level - season, aux_values=growth) + out = rearrange(out, "b t h d -> b t (h d)") + return out + + +class DampingLayer(nn.Module): + def __init__(self, pred_len, nhead, dropout=0.1, output_attention=False): + super().__init__() + self.pred_len = pred_len + self.nhead = nhead + self.output_attention = output_attention + self._damping_factor = nn.Parameter(torch.randn(1, nhead)) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = repeat(x, "b 1 d -> b t d", t=self.pred_len) + b, t, d = x.shape + + powers = torch.arange(self.pred_len).to(self._damping_factor.device) + 1 + powers = powers.view(self.pred_len, 1) + damping_factors = self.damping_factor**powers + damping_factors = damping_factors.cumsum(dim=0) + x = x.view(b, t, self.nhead, -1) + x = self.dropout(x) * damping_factors.unsqueeze(-1) + x = x.view(b, t, d) + if self.output_attention: + return x, damping_factors + return x, None + + @property + def damping_factor(self): + return torch.sigmoid(self._damping_factor) diff --git a/merlion/models/utils/rolling_window_dataset.py b/merlion/models/utils/rolling_window_dataset.py index 914f3303f..962dea494 100644 --- a/merlion/models/utils/rolling_window_dataset.py +++ b/merlion/models/utils/rolling_window_dataset.py @@ -4,11 +4,16 @@ # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause # +""" +A rolling window dataset +""" import logging +import math import numpy as np from typing import Optional, Union import pandas as pd from merlion.utils.time_series import TimeSeries, to_pd_datetime +from merlion.models.utils.time_features import get_time_features logger = logging.getLogger(__name__) @@ -25,6 +30,9 @@ def __init__( ts_index: bool = False, batch_size: Optional[int] = 1, flatten: bool = True, + ts_encoding: Union[None, str] = None, + valid_fraction: float = 0.0, + validation: 
Union[bool, None] = False,
         seed: int = 0,
     ):
         """
@@ -55,6 +63,14 @@ def __init__(
             by default, Numpy array will handle the internal data workflow and Numpy array will be the output.
         :param batch_size: the number of windows to return in parallel. If ``None``, return the whole dataset.
         :param flatten: whether the output time series arrays should be flattened to 2 dimensions.
+        :param ts_encoding: whether the timestamp should be encoded as a float vector, which can be used
+            for training deep-learning-based time series models; if ``None``, the timestamp is not encoded.
+            If not ``None``, it specifies the frequency for time feature encoding. Options: [s: secondly, t: minutely,
+            h: hourly, d: daily, b: business days, w: weekly, m: monthly].
+        :param valid_fraction: Fraction of the training windows to split off as a validation set. If ``valid_fraction = 0``
+            or ``valid_fraction = 1``, we iterate over the entire dataset.
+        :param validation: Whether to iterate over the validation windows or the training windows.
+            If ``validation = None``, we iterate over the entire dataset.
         """
         assert isinstance(
             data, (TimeSeries, pd.DataFrame)
@@ -81,22 +97,17 @@ def __init__(
         self.n_past = n_past
         self.shuffle = shuffle
         self.flatten = flatten
+        self.ts_encoding = ts_encoding
         self.target_seq_index = target_seq_index
-        if target_seq_index is None:
-            if n_future not in [0, 1]:
-                logger.warning(
-                    "Since target_seq_index is None, we predict all univariates for this dataset. Currently, this is "
-                    "only valid with 1-step lookahead (autoregressive forecasting) or 0-step lookahead (autoencoding). "
-                    "Setting n_future = 1. If you are not expecting this behavior, set target_seq_index appropriately."
-                )
-                n_future = 1
         self.n_future = n_future
         self.ts_index = ts_index
         if ts_index:
             self.data = data.concat(exog_data, axis=1) if exog_data is not None else data
-            self.target = data if self.autoregressive else data.univariates[data.names[target_seq_index]].to_ts()
+            self.target = (
+                data if self.target_seq_index is None else data.univariates[data.names[target_seq_index]].to_ts()
+            )
             self.timestamp = to_pd_datetime(data.np_time_stamps)
         else:
             df = data.to_pd() if isinstance(data, TimeSeries) else data
@@ -109,27 +120,100 @@ def __init__(
             self.data = np.concatenate((df.values, exog_vals), axis=1)
             self.data = np.concatenate((df.values, exog_df.values), axis=1) if exog_data is not None else df.values
             self.timestamp = df.index
-            self.target = df.values if self.autoregressive else df.values[:, target_seq_index]
+            self.target = df.values if self.target_seq_index is None else df.values[:, target_seq_index]

-        self.seed = seed
+        if self.ts_encoding:
+            self.timestamp = get_time_features(self.timestamp, self.ts_encoding)
+
+        self._seed = seed
+
+        self._valid = validation
+        self.valid_fraction = valid_fraction
+
+        if valid_fraction <= 0.0 or valid_fraction >= 1.0 or (self.validation is None):
+            n_train = self.n_windows
+        else:
+            n_train = self.n_windows - math.ceil(self.n_windows * self.valid_fraction)
+
+        data_indices = np.arange(self.n_windows)
+
+        # shuffle the window indices with the given random seed before splitting off the validation set
+        if shuffle:
+            data_indices = np.random.RandomState(seed).permutation(data_indices)
+
+        self.train_indices = data_indices[:n_train]
+        self.valid_indices = data_indices[n_train:]
+
+    @property
+    def validation(self):
+        """
+        If set ``False``, we only provide access to the training windows; if set ``True``,
+        we only provide access to the validation windows. If set ``None``, we iterate over
+        the entire dataset.
+        """
+        return self._valid
+
+    @validation.setter
+    def validation(self, valid: bool):
+        self._valid = valid
+
+    @property
+    def seed(self):
+        """
+        Random seed used to shuffle the training windows
+        """
+        return self._seed
+
+    @seed.setter
+    def seed(self, seed: int):
+        """
+        Set the random seed used to shuffle the training windows
+        """
+        self._seed = seed
+
+    @property
+    def n_windows(self):
+        """
+        Total number of sliding windows
+        """
+        return len(self.data) - self.n_past - self.n_future + 1
+
+    @property
+    def n_valid(self):
+        """
+        Number of sliding windows in the validation set
+        """
+        return len(self.valid_indices)

     @property
-    def autoregressive(self):
-        return self.target_seq_index is None
+    def n_train(self):
+        """
+        Number of sliding windows in the training set
+        """
+        return len(self.train_indices)

     @property
     def n_points(self):
-        return len(self.data) - self.n_past + 1 - self.n_future
+        n_train, n_valid = self.n_train, self.n_valid
+        return n_train + n_valid if self.validation is None else n_valid if self.validation else n_train

     def __len__(self):
         return int(np.ceil(self.n_points / self.batch_size)) if self.batch_size is not None else 1

     def __iter__(self):
         batch = []
-        if self.shuffle and self.batch_size is not None:
-            order = np.random.RandomState(self.seed).permutation(self.n_points)
+
+        if self.validation is None:
+            order = sorted(np.concatenate((self.train_indices, self.valid_indices)))
+            if self.shuffle:
+                order = np.random.RandomState(self.seed).permutation(order)
+        elif self.validation:
+            order = self.valid_indices
+        elif self.shuffle and self.batch_size is not None:
+            order = np.random.RandomState(self.seed).permutation(self.train_indices)
         else:
-            order = range(self.n_points)
+            order = self.train_indices
+
         for i in order:
             batch.append(self[i])
             if self.batch_size is not None and len(batch) >= self.batch_size:
@@ -144,21 +228,39 @@ def collate_batch(self, batch):
         # TODO: allow output shape to be specified as class parameter
         past, past_ts, future, future_ts = zip(*batch)
         past = np.stack(past)
+        past_ts = np.stack(past_ts)
+
         if self.flatten:
             past = past.reshape((len(batch), -1))
-        past_ts = np.stack(past_ts)
+            past_ts = past_ts.reshape((len(batch), -1)) if self.ts_encoding else past_ts
+
         if future is not None:
-            future = np.stack(future).reshape((len(batch), -1))
+            future = np.stack(future)
+            future = future.reshape((len(batch), -1)) if self.flatten else future
+            future_ts = np.stack(future_ts)
+            if self.flatten and self.ts_encoding:
+                future_ts = future_ts.reshape((len(batch), -1))
         else:
             future, future_ts = None, None
         return past, past_ts, future, future_ts

     def __getitem__(self, idx):
-        assert 0 <= idx < self.n_points
-        idx_end = idx + self.n_past
-        past = self.data[idx:idx_end]
-        past_timestamp = self.timestamp[idx:idx_end]
-        future = self.target[idx_end : idx_end + self.n_future]
-        future_timestamp = self.timestamp[idx_end : idx_end + self.n_future]
+        if self.validation is None:
+            assert 0 <= idx < self.n_points
+        elif self.validation:
+            assert idx in self.valid_indices
+        else:
+            assert idx in self.train_indices
+
+        past_start = idx
+        past_end = past_start + self.n_past
+        future_start = past_end
+        future_end = future_start + self.n_future
+
+        past = self.data[past_start:past_end]
+        past_timestamp = self.timestamp[past_start:past_end]
+        future = self.target[future_start:future_end]
+        future_timestamp = self.timestamp[future_start:future_end]
+
         return (past, future) if self.ts_index else (past, past_timestamp, future, future_timestamp)
diff --git a/merlion/models/utils/time_features.py 
b/merlion/models/utils/time_features.py new file mode 100644 index 000000000..b663d8cd2 --- /dev/null +++ b/merlion/models/utils/time_features.py @@ -0,0 +1,162 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +""" +Utils for converting pandas datetime to numerical vectors +""" +from typing import List + +import numpy as np +import pandas as pd +from pandas.tseries import offsets +from pandas.tseries.frequencies import to_offset + + +class TimeFeature: + def __init__(self): + pass + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + pass + + def __repr__(self): + return self.__class__.__name__ + "()" + + +class SecondOfMinute(TimeFeature): + """ + Second of minute encoded as value between [-0.5, 0.5] + """ + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return index.second / 59.0 - 0.5 + + +class MinuteOfHour(TimeFeature): + """ + Minute of hour encoded as value between [-0.5, 0.5] + """ + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return index.minute / 59.0 - 0.5 + + +class HourOfDay(TimeFeature): + """ + Hour of day encoded as value between [-0.5, 0.5] + """ + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return index.hour / 23.0 - 0.5 + + +class DayOfWeek(TimeFeature): + """ + Day of week encoded as value between [-0.5, 0.5] + """ + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return index.dayofweek / 6.0 - 0.5 + + +class DayOfMonth(TimeFeature): + """ + Day of month encoded as value between [-0.5, 0.5] + """ + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return (index.day - 1) / 30.0 - 0.5 + + +class DayOfYear(TimeFeature): + """ + Day of year encoded as value between [-0.5, 0.5] + """ + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return (index.dayofyear - 1) / 365.0 - 0.5 + + +class MonthOfYear(TimeFeature): + """ + Month of year encoded as value between [-0.5, 0.5] + """ + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return (index.month - 1) / 11.0 - 0.5 + + +class WeekOfYear(TimeFeature): + """ + Week of year encoded as value between [-0.5, 0.5] + """ + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return (index.isocalendar().week - 1) / 52.0 - 0.5 + + +def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]: + """ + :param freq_str: Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc. + + :return: a list of time features that will be appropriate for the given frequency string. 
+ """ + + features_by_offsets = { + offsets.YearEnd: [], + offsets.QuarterEnd: [MonthOfYear], + offsets.MonthEnd: [MonthOfYear], + offsets.Week: [DayOfMonth, WeekOfYear], + offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear], + offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear], + offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], + offsets.Minute: [ + MinuteOfHour, + HourOfDay, + DayOfWeek, + DayOfMonth, + DayOfYear, + ], + offsets.Second: [ + SecondOfMinute, + MinuteOfHour, + HourOfDay, + DayOfWeek, + DayOfMonth, + DayOfYear, + ], + } + + offset = to_offset(freq_str) + + for offset_type, feature_classes in features_by_offsets.items(): + if isinstance(offset, offset_type): + return [cls() for cls in feature_classes] + + supported_freq_msg = f""" + Unsupported frequency {freq_str} + The following frequencies are supported: + Y - yearly + alias: A + M - monthly + W - weekly + D - daily + B - business days + H - hourly + T - minutely + alias: min + S - secondly + """ + raise RuntimeError(supported_freq_msg) + + +def get_time_features(dates: pd.DatetimeIndex, ts_encoding: str = "h"): + """ + Convert pandas Datetime to numerical vectors that can be used for training + """ + + features = np.vstack([feat(dates) for feat in time_features_from_frequency_str(ts_encoding)]) + return features.transpose(1, 0) diff --git a/setup.py b/setup.py index af7f8bafe..b785a070e 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ # optional dependencies extra_require = { "dashboard": ["dash[diskcache]>=2.4", "dash_bootstrap_components>=1.0", "diskcache"], - "deep-learning": ["torch>=1.1.0"], + "deep-learning": ["torch>=1.9.0", "einops>=0.4.0"], "spark": ["pyspark[sql]>=3"], } extra_require["all"] = sum(extra_require.values(), []) diff --git a/tests/anomaly/forecast_based/test_lstm.py b/tests/anomaly/forecast_based/test_lstm.py deleted file mode 100644 index 16e2b203a..000000000 --- a/tests/anomaly/forecast_based/test_lstm.py +++ /dev/null @@ -1,106 +0,0 @@ -# -# Copyright (c) 2022 salesforce.com, inc. -# All rights reserved. 
-# SPDX-License-Identifier: BSD-3-Clause -# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -# -import datetime -import logging -import math -from os.path import abspath, dirname, join -import sys -import unittest - -import numpy as np - -from merlion.transform.resample import TemporalResample -from merlion.models.anomaly.forecast_based.lstm import LSTMDetector, LSTMTrainConfig, LSTMDetectorConfig -from merlion.models.forecast.lstm import auto_stride -from merlion.post_process.threshold import AggregateAlarms -from merlion.utils.time_series import TimeSeries -from merlion.utils.data_io import csv_to_time_series - -logger = logging.getLogger(__name__) -rootdir = dirname(dirname(dirname(dirname(abspath(__file__))))) - - -class TestLSTM(unittest.TestCase): - def test_full(self): - file_name = join(rootdir, "data", "example.csv") - - sequence = TemporalResample("15min")(csv_to_time_series(file_name, timestamp_unit="ms", data_cols=["kpi"])) - logger.info(f"Data looks like:\n{sequence[:5]}") - - time_stamps = sequence.univariates[sequence.names[0]].time_stamps - stride = auto_stride(time_stamps, resolution=12) - logger.info("stride = " + str(stride)) - - # 2 days of data for testing - test_delta = datetime.timedelta(days=2).total_seconds() - ts_train, ts_test = sequence.bisect(time_stamps[-1] - test_delta) - forecast_steps = math.ceil(len(ts_test) / stride) - - self.assertGreater(forecast_steps, 1, "sequence is not long enough") - - model = LSTMDetector( - LSTMDetectorConfig(max_forecast_steps=forecast_steps, nhid=256, threshold=AggregateAlarms(2, 1, 60, 300)) - ) - train_config = LSTMTrainConfig( - data_stride=stride, - epochs=1, - seq_len=forecast_steps * 2, - checkpoint_file=join(rootdir, "tmp", "lstm", "checkpoint.pt"), - ) - train_scores = model.train(train_data=ts_train, train_config=train_config) - - self.assertIsInstance( - train_scores, - TimeSeries, - msg="Expected output of train() to be a TimeSeries of anomaly " - "scores, but this seems to be a forecast. Check inheritance " - "order of this forecasting detector.", - ) - train_scores = train_scores.univariates[train_scores.names[0]] - train_vals = ts_train.univariates[ts_train.names[0]] - self.assertNotAlmostEqual( - train_scores.values[-1], - train_vals.values[-1], - delta=100, - msg="Expected output of train() to be a TimeSeries of anomaly " - "scores, but this seems to be a forecast. Check inheritance " - "order of this forecasting detector.", - ) - - ############## - scores = model.get_anomaly_score(ts_test) - logger.info("Scores look like:\n" + str(scores[:5])) - alarms = model.get_anomaly_label(ts_test) - logger.info("Alarms look like:\n" + str(alarms[:5])) - n_alarms = np.sum(alarms.to_pd().values != 0) - logger.info("# of alarms = " + str(n_alarms)) - self.assertLess(n_alarms, 20) - - ############## - # Note: we compare scores vs scoresv2[1:] because scoresv2 has one - # extra time step included. This is because when `time_series_prev` is - # given, we compute `self.model.transform(ts_train + ts_test)` and take - # the first time step in the transformed FULL time series which matches - # with `ts_test`. This is different from the first time step of - # `self.model.transform(ts_test)` due to the difference transform. 
- scoresv2 = model.get_anomaly_score(ts_test, ts_train)[1:] - self.assertLess(np.max((scores.to_pd() - scoresv2.to_pd()).abs().values), 0.1) - - ############## - model.save(join(rootdir, "tmp", "lstm")) - model = LSTMDetector.load(join(rootdir, "tmp", "lstm")) - loaded_scores = model.get_anomaly_score(ts_test) - self.assertSequenceEqual(list(scores), list(loaded_scores)) - loaded_alarms = model.get_anomaly_label(ts_test) - self.assertSequenceEqual(list(alarms), list(loaded_alarms)) - - -if __name__ == "__main__": - logging.basicConfig( - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", stream=sys.stdout, level=logging.DEBUG - ) - unittest.main() diff --git a/tests/anomaly/forecast_based/test_sarima.py b/tests/anomaly/forecast_based/test_sarima.py index 26f7fd22a..4877d9e65 100644 --- a/tests/anomaly/forecast_based/test_sarima.py +++ b/tests/anomaly/forecast_based/test_sarima.py @@ -92,7 +92,8 @@ def test_full(self): n_alarms = np.sum(alarms.to_pd().values != 0) logger.info(f"Alarms look like:\n{alarms[:5]}") logger.info(f"Number of alarms: {n_alarms}\n") - self.assertEqual(n_alarms, 2) + self.assertLessEqual(n_alarms, 3) + self.assertGreaterEqual(n_alarms, 2) loaded_model_alarms = loaded_model.get_anomaly_label(self.vals_test) self.assertSequenceEqual(list(alarms), list(loaded_model_alarms)) diff --git a/tests/anomaly/multivariate/test_vae.py b/tests/anomaly/multivariate/test_vae.py index 857e8ffb7..b346acd79 100644 --- a/tests/anomaly/multivariate/test_vae.py +++ b/tests/anomaly/multivariate/test_vae.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause diff --git a/tests/forecast/test_autosarima.py b/tests/forecast/test_autosarima.py index a5ea8e116..3f06d23c6 100644 --- a/tests/forecast/test_autosarima.py +++ b/tests/forecast/test_autosarima.py @@ -847,12 +847,12 @@ def run_test(self, auto_pqPQ: bool, seasonality_layer: bool, expected_sMAPE: flo def test_autosarima(self): print("-" * 80) logger.info("TestAutoSarima.test_autosarima\n" + "-" * 80 + "\n") - self.run_test(auto_pqPQ=False, seasonality_layer=False, expected_sMAPE=3.806) + self.run_test(auto_pqPQ=False, seasonality_layer=False, expected_sMAPE=3.413) def test_seasonality_layer(self): print("-" * 80) logger.info("TestAutoSarima.test_seasonality_layer\n" + "-" * 80 + "\n") - self.run_test(auto_pqPQ=False, seasonality_layer=True, expected_sMAPE=3.806) + self.run_test(auto_pqPQ=False, seasonality_layer=True, expected_sMAPE=3.413) if __name__ == "__main__": diff --git a/tests/forecast/test_deep_model.py b/tests/forecast/test_deep_model.py new file mode 100644 index 000000000..9118dda21 --- /dev/null +++ b/tests/forecast/test_deep_model.py @@ -0,0 +1,261 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +import logging +import os +import sys +import shutil +import unittest + +import gdown +import pandas as pd +from os.path import abspath, dirname, join, exists + +from merlion.evaluate.forecast import ForecastMetric +from merlion.models.forecast.autoformer import AutoformerConfig, AutoformerForecaster +from merlion.models.forecast.transformer import TransformerConfig, TransformerForecaster +from merlion.models.forecast.informer import InformerConfig, InformerForecaster +from merlion.models.forecast.etsformer import ETSformerConfig, ETSformerForecaster +from merlion.models.forecast.deep_ar import DeepARConfig, DeepARForecaster + + +from merlion.models.utils.rolling_window_dataset import RollingWindowDataset +from merlion.transform.bound import LowerUpperClip +from merlion.transform.normalize import MinMaxNormalize +from merlion.transform.resample import TemporalResample +from merlion.transform.sequence import TransformSequence +from merlion.utils.time_series import TimeSeries, to_pd_datetime +from ts_datasets.forecast import SeattleTrail +from ts_datasets.forecast.custom import CustomDataset + +logger = logging.getLogger(__name__) +rootdir = dirname(dirname(dirname(abspath(__file__)))) + + +class TestDeepModels(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.n_past = 16 + self.max_forecast_steps = 8 + self.early_stop_patience = 4 + self.num_epochs = 2 + self.use_gpu = True + self.batch_size = 32 + + df = self._obtain_df("weather") + bound = 16 * 20 + train_df = df[0:bound] + test_df = df[bound : 2 * bound] + + self.train_df = train_df + self.test_df = test_df + + self.train_data = TimeSeries.from_pd(self.train_df) + self.test_data = TimeSeries.from_pd(self.test_df) + + def test_deep_ar_predict_univariate(self): + print("-" * 80) + logger.info("test_deep_ar_predict_univariate\n" + "-" * 80) + self._test_deep_ar(20) + + def test_deep_ar_predict_multivariate(self): + print("-" * 80) + logger.info("test_deep_ar_predict_multivariate\n" + "-" * 80) + self._test_deep_ar(None) + + def test_autoformer_predict_univariate(self): + print("-" * 80) + logger.info("test_autoformer_predict_univariate\n" + "-" * 80) + self._test_autoformer(9) + + def test_autoformer_predict_multivariate(self): + print("-" * 80) + logger.info("test_autoformer_predict_multivariate\n" + "-" * 80) + self._test_autoformer(None) + + def test_informer_predict_univariate(self): + print("-" * 80) + logger.info("test_informer_predict_univariate\n" + "-" * 80) + self._test_informer(3) + + def test_informer_predict_multivariate(self): + print("-" * 80) + logger.info("test_informer_predict_multivariate\n" + "-" * 80) + self._test_informer(None) + + def test_etsformer_predict_univariate(self): + print("-" * 80) + logger.info("test_etsformer_predict_univariate\n" + "-" * 80) + self._test_etsformer(15) + + def test_etsformer_predict_multivariate(self): + print("-" * 80) + logger.info("test_etsformer_predict_multivariate\n" + "-" * 80) + self._test_etsformer(None) + + def test_transformer_predict_univariate(self): + print("-" * 80) + logger.info("test_transformer_predict_univariate\n" + "-" * 80) + self._test_transformer(0) + + def test_transformer_predict_multivariate(self): + print("-" * 80) + logger.info("test_transformer_predict_multivariate\n" + "-" * 80) + self._test_transformer(None) + + def _test_deep_ar(self, 
target_seq_index): + + logger.info("Testing Deep AR forecasting") + config = DeepARConfig( + n_past=self.n_past, + max_forecast_steps=self.max_forecast_steps, + early_stop_patience=self.early_stop_patience, + num_epochs=self.num_epochs, + use_gpu=self.use_gpu, + batch_size=self.batch_size, + target_seq_index=target_seq_index, + ) + + forecaster = DeepARForecaster(config) + + self._test_model(forecaster, self.train_data, self.test_data) + + def _test_autoformer(self, target_seq_index): + + logger.info("Testing Autoformer forecasting") + start_token_len = 3 + config = AutoformerConfig( + n_past=self.n_past, + max_forecast_steps=self.max_forecast_steps, + start_token_len=start_token_len, + early_stop_patience=self.early_stop_patience, + num_epochs=self.num_epochs, + use_gpu=self.use_gpu, + batch_size=self.batch_size, + target_seq_index=target_seq_index, + ) + + forecaster = AutoformerForecaster(config) + + self._test_model(forecaster, self.train_data, self.test_data) + + def _test_transformer(self, target_seq_index): + logger.info("Testing Transformer forecasting") + start_token_len = 3 + config = TransformerConfig( + n_past=self.n_past, + max_forecast_steps=self.max_forecast_steps, + start_token_len=start_token_len, + early_stop_patience=self.early_stop_patience, + num_epochs=self.num_epochs, + use_gpu=self.use_gpu, + batch_size=self.batch_size, + target_seq_index=target_seq_index, + ) + + forecaster = TransformerForecaster(config) + + self._test_model(forecaster, self.train_data, self.test_data) + + def _test_informer(self, target_seq_index): + logger.info("Testing Informer forecasting") + start_token_len = 3 + + config = InformerConfig( + n_past=self.n_past, + max_forecast_steps=self.max_forecast_steps, + start_token_len=start_token_len, + early_stop_patience=self.early_stop_patience, + num_epochs=self.num_epochs, + use_gpu=self.use_gpu, + batch_size=self.batch_size, + target_seq_index=target_seq_index, + ) + + forecaster = InformerForecaster(config) + + self._test_model(forecaster, self.train_data, self.test_data) + + def _test_etsformer(self, target_seq_index): + logger.info("Testing ETSformer forecasting") + + config = ETSformerConfig( + n_past=self.n_past, + max_forecast_steps=self.max_forecast_steps, + top_K=3, # top fourier basis + early_stop_patience=self.early_stop_patience, + num_epochs=self.num_epochs, + use_gpu=self.use_gpu, + batch_size=self.batch_size, + target_seq_index=target_seq_index, + ) + + forecaster = ETSformerForecaster(config) + + self._test_model(forecaster, self.train_data, self.test_data) + + def _obtain_df(self, dataset_name="weather"): + data_dir = join(rootdir, "data") + if dataset_name == "weather": + data_url = "https://drive.google.com/drive/folders/1Xz84ci5YKWL6O2I-58ZsVe42lYIfqui1" + data_folder = join(data_dir, "weather") + data_file_path = join(data_folder, "weather.csv") + else: + raise NotImplementedError + + if not exists(data_file_path): + while True: + try: + gdown.download_folder(data_url, quiet=False, use_cookies=False) + except TimeoutError: + logger.error("Timeout Error, try downloading again...") + else: + logger.info("Successfully downloaded %s!" 
% (dataset_name)) + break + + shutil.move("./%s" % (dataset_name), data_folder) + + weather_ds = CustomDataset(data_folder) + df, metadata = weather_ds[0] + + return df + + def _test_model(self, forecaster, train_data, test_data): + config = forecaster.config + model_name = forecaster.deep_model_class.__name__ + model_save_path = join("./models", model_name.lower()) + + logger.info(model_name) + + # training & saving + forecaster.train(train_data) + forecaster.save(model_save_path) + + # Single data forecasting testing + dataset = RollingWindowDataset( + test_data, + target_seq_index=config.target_seq_index, + n_past=config.n_past, + n_future=config.max_forecast_steps, + ts_index=True, + ) + test_prev, test = dataset[0] + forecaster.load(model_save_path) + pred, _ = forecaster.forecast(test.time_stamps, time_series_prev=test_prev) + assert pred.dim == 1 if forecaster.target_seq_index is not None else train_data.dim + + try: + shutil.rmtree(model_save_path) + except OSError as e: + logger.error(f"Error: {e.filename} - {e.strerror}.") + + +if __name__ == "__main__": + logging.basicConfig( + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", stream=sys.stdout, level=logging.INFO + ) + unittest.main() diff --git a/tests/forecast/test_forecast_ensemble.py b/tests/forecast/test_forecast_ensemble.py index 0b83386b9..f5409896e 100644 --- a/tests/forecast/test_forecast_ensemble.py +++ b/tests/forecast/test_forecast_ensemble.py @@ -106,7 +106,7 @@ def test_selector_small_train(self): logger.info("test_selector_small_train\n" + "-" * 80 + "\n") self.vals_train = self.vals_train[-8:] self.expected_smape = 194 - self._test_selector(test_name="test_selector_small_train", expected_smapes=[np.inf, 50.64, 6.16]) + self._test_selector(test_name="test_selector_small_train", expected_smapes=[np.inf, 7.27, 6.16]) def run_test(self, test_name): logger.info("Training model...") @@ -138,7 +138,7 @@ def run_test(self, test_name): y = self.vals_test.np_values smape = np.mean(200.0 * np.abs((y - yhat) / (np.abs(y) + np.abs(yhat)))) logger.info(f"sMAPE = {smape:.4f}") - self.assertAlmostEqual(smape, self.expected_smape, delta=2) + self.assertAlmostEqual(smape, self.expected_smape, delta=2 if self.expected_smape < 100 else 10) if __name__ == "__main__": diff --git a/ts_datasets/setup.py b/ts_datasets/setup.py index 71eacd070..48b99bb89 100644 --- a/ts_datasets/setup.py +++ b/ts_datasets/setup.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -16,5 +16,5 @@ long_description_content_type="text/markdown", license="Apache 2.0", packages=find_packages(include=["ts_datasets*"]), - install_requires=["cython", "numpy", "pandas", "requests", "sklearn", "tqdm", "wheel"], + install_requires=["cython", "numpy", "pandas", "requests", "tqdm", "wheel", "gdown"], ) diff --git a/ts_datasets/ts_datasets/forecast/m4.py b/ts_datasets/ts_datasets/forecast/m4.py index 79717781a..79904ac35 100644 --- a/ts_datasets/ts_datasets/forecast/m4.py +++ b/ts_datasets/ts_datasets/forecast/m4.py @@ -51,16 +51,8 @@ def __init__(self, subset="Hourly", rootdir=None): download(rootdir, self.url, "M4-info") # extract starting date from meta-information of dataset - info_dataset = pd.read_csv(os.path.join(rootdir, "M4-info.csv"), delimiter=",").set_index("M4id") - - if subset == "Yearly": - logger.warning( - "the max length of yearly data is 841 which is too big to convert to " - "timestamps, we fallback to quarterly frequency" - ) - self.freq = "Q" - else: - self.freq = subset[0] + self.freq = subset[0] + self.info_dataset = pd.read_csv(os.path.join(rootdir, "M4-info.csv"), parse_dates=True).set_index("M4id") train_csv = os.path.join(rootdir, f"train/{subset}-train.csv") if not os.path.isfile(train_csv): @@ -73,10 +65,19 @@ def __init__(self, subset="Hourly", rootdir=None): self.test_set = pd.read_csv(test_csv).set_index("V1") def __getitem__(self, i): - train, test = self.train_set.iloc[i].dropna(), self.test_set.iloc[i].dropna() + id = self.train_set.index[i] + train, test = self.train_set.loc[id].dropna(), self.test_set.loc[id].dropna() ts = pd.concat((train, test)).to_frame() # raw data do not follow consistent timestamp format - ts.index = pd.date_range(start=0, periods=ts.shape[0], freq=self.freq) + t0 = self.info_dataset.loc[id, "StartingDate"] + try: + ts.index = pd.date_range(start=t0, periods=len(ts), freq=self.freq) + except Exception as e: + if self.freq == "Y": + logger.warning(f"Time series {i} too long for yearly granularity. Using quarterly instead.") + ts.index = pd.date_range(start=t0, periods=len(ts), freq="Q") + else: + raise e md = pd.DataFrame({"trainval": ts.index < ts.index[len(train)]}, index=ts.index) return ts, md
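
Usage sketch for the updated RollingWindowDataset: the diff above adds timestamp encoding and a train/validation split over the rolling windows. A minimal sketch of how the new parameters might be combined, assuming a small synthetic hourly series (the data and window sizes are illustrative, not from the diff):

import numpy as np
import pandas as pd

from merlion.models.utils.rolling_window_dataset import RollingWindowDataset
from merlion.utils.time_series import TimeSeries

# Hypothetical hourly data; any TimeSeries or DataFrame with a DatetimeIndex works.
index = pd.date_range("2022-01-01", periods=200, freq="1h")
ts = TimeSeries.from_pd(pd.DataFrame({"x": np.sin(np.arange(200) / 10.0)}, index=index))

# Hold out 20% of the rolling windows for validation. ts_encoding="h" converts each
# timestamp into the hourly time-feature vector defined in merlion/models/utils/time_features.py.
dataset = RollingWindowDataset(
    ts, target_seq_index=0, n_past=16, n_future=8, batch_size=32,
    ts_encoding="h", valid_fraction=0.2, validation=False, shuffle=True,
)
for past, past_ts, future, future_ts in dataset:  # iterates over the training windows only
    pass

dataset.validation = True  # switch the same object over to the validation windows
print(dataset.n_train, dataset.n_valid)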
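
Usage sketch for the new time_features module: it can also be used on its own to turn a pandas DatetimeIndex into the float features consumed by the deep models. A small sketch, assuming an invented hourly index:

import pandas as pd
from merlion.models.utils.time_features import get_time_features

index = pd.date_range("2022-01-01", periods=48, freq="1h")
feats = get_time_features(index, ts_encoding="h")
# "h" maps to [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], each scaled to [-0.5, 0.5],
# so feats has shape (48, 4).
print(feats.shape)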
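
Usage sketch for the new deep forecasters: they follow the train/forecast flow exercised in tests/forecast/test_deep_model.py. A hedged sketch of training an Informer model on a toy multivariate series; the data and hyperparameter values are illustrative stand-ins for the weather dataset used in the test:

import numpy as np
import pandas as pd

from merlion.models.forecast.informer import InformerConfig, InformerForecaster
from merlion.models.utils.rolling_window_dataset import RollingWindowDataset
from merlion.utils.time_series import TimeSeries

# Toy two-variable series standing in for a real dataset.
index = pd.date_range("2022-01-01", periods=320, freq="1h")
df = pd.DataFrame({"x0": np.sin(np.arange(320) / 24.0), "x1": np.cos(np.arange(320) / 24.0)}, index=index)
train_data, test_data = TimeSeries.from_pd(df[:256]), TimeSeries.from_pd(df[256:])

config = InformerConfig(
    n_past=16,
    max_forecast_steps=8,
    start_token_len=3,
    num_epochs=2,
    early_stop_patience=4,
    batch_size=32,
    use_gpu=False,          # the unit test enables the GPU; CPU also works, just more slowly
    target_seq_index=None,  # None forecasts every variable, as in the multivariate tests
)
model = InformerForecaster(config)
model.train(train_data)

# Forecast one window of the test split, conditioning on the preceding n_past points.
dataset = RollingWindowDataset(test_data, target_seq_index=None, n_past=16, n_future=8, ts_index=True)
test_prev, test = dataset[0]
pred, stderr = model.forecast(test.time_stamps, time_series_prev=test_prev)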