This repository contains the code for the paper, "Long-Range Transformers for Dynamic Spatiotemporal Forecasting", Grigsby, Wang and Qi, 2021. (arXiv).
Spacetimeformer is a Transformer that learns temporal patterns like a time series model and spatial patterns like a Graph Neural Network.
June 2022 disclaimer: the updated implementation no longer matches the arXiv pre-prints. We are working on a new version of the paper. GitHub releases mark the paper versions.
Below we give a brief explanation of the problem and method with installation instructions. We provide training commands for high-performance results on several datasets.
We deal with multivariate sequence-to-sequence problems that have continuous inputs. The most common example is time series forecasting, where we predict future ("target") values given recent history ("context").
Every model and dataset uses this x_context, y_context, x_target, y_target format. X values are time covariates like the calendar datetime, while Ys are variable values. There can be additional context variables that are not predicted.
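As a rough illustration of this format, the sketch below cuts one window of a series into the four pieces. It is a minimal sketch, not the repo's data loader: split_context_target is a hypothetical helper, the data is made up, and the context_points / target_points names are borrowed from the training flags.

```python
from typing import Sequence, Tuple


def split_context_target(
    x: Sequence,  # time covariates, one entry per timestep
    y: Sequence,  # variable values, one row per timestep
    context_points: int,
    target_points: int,
) -> Tuple[list, list, list, list]:
    """Split one window into (x_context, y_context, x_target, y_target)."""
    total = context_points + target_points
    if len(x) != len(y) or len(x) < total:
        raise ValueError("x and y must be aligned and cover context + target steps")
    x_context = list(x[:context_points])
    y_context = list(y[:context_points])
    x_target = list(x[context_points:total])
    y_target = list(y[context_points:total])
    return x_context, y_context, x_target, y_target


# Hypothetical data: 6 hourly timesteps, 2 variables per timestep.
hours = [0, 1, 2, 3, 4, 5]
values = [[10.0, 1.0], [11.0, 1.5], [12.0, 2.0], [13.0, 2.5], [14.0, 3.0], [15.0, 3.5]]
x_c, y_c, x_t, y_t = split_context_target(hours, values, context_points=4, target_points=2)
```

Here the model would see x_c, y_c and x_t as inputs and be trained to predict y_t.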
Typical deep learning time series models group Y values by timestep and learn patterns across time. When using Transformer-based models, this results in "temporal" attention networks that can ignore spatial relationships between variables.
In contrast, Graph Neural Networks and similar methods model spatial relationships with explicit graphs - sharing information across space and time in alternating layers.
Spacetimeformer learns full spatiotemporal patterns between all variables at every timestep.
We implement spatiotemporal attention with a custom Transformer architecture and embedding that flattens multivariate sequences so that each token contains the value of a single variable at a given timestep.
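The flattening idea can be sketched in a few lines. This is an illustration only: flatten_spatiotemporal is a hypothetical helper, the time-major token order is an assumption, and the real embedding also encodes the time covariates rather than passing raw tuples.

```python
from typing import List, Tuple


def flatten_spatiotemporal(y: List[List[float]]) -> List[Tuple[int, int, float]]:
    """Turn a (timesteps x variables) grid into one token per (time, variable) pair."""
    tokens = []
    for t, row in enumerate(y):
        for var_idx, value in enumerate(row):
            # Each token carries when it was observed, which variable it is, and its value.
            tokens.append((t, var_idx, value))
    return tokens


# Hypothetical context: 2 timesteps, 3 variables.
y_context = [[10.0, 1.0, 0.3], [11.0, 1.5, 0.4]]
tokens = flatten_spatiotemporal(y_context)
```

A sequence of 2 timesteps and 3 variables becomes 6 tokens, so attention can relate any variable at any timestep to any other.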
Spacetimeformer processes these longer sequences with a mix of efficient attention mechanisms and Vision-style "windowed" attention.
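To see why the flattened sequences call for efficient attention, it helps to count query-key pairs. The sketch below assumes full (dense) self-attention and uses made-up sizes; attention_pairs is an illustrative helper, not part of this codebase.

```python
def attention_pairs(timesteps: int, variables: int, flatten: bool) -> int:
    """Number of query-key pairs scored by one dense self-attention layer."""
    # Temporal attention has one token per timestep; the flattened
    # spatiotemporal sequence has one token per (timestep, variable) pair.
    seq_len = timesteps * variables if flatten else timesteps
    return seq_len * seq_len


# Hypothetical problem size: 168 timesteps, 10 variables.
temporal = attention_pairs(168, 10, flatten=False)
spatiotemporal = attention_pairs(168, 10, flatten=True)
ratio = spatiotemporal // temporal
```

Under these assumptions the cost grows with the square of the number of variables, which is the motivation for windowed and other efficient attention variants.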
This repo contains the code for our model as well as several high-quality baselines for common benchmarks and toy datasets.
This repository was written and tested for Python 3.8 and PyTorch 1.11.0.
git clone https://github.com/QData/spacetimeformer.git
cd spacetimeformer
conda create -n spacetimeformer python==3.8
source activate spacetimeformer
pip install -r requirements.txt
pip install -e .
This installs a Python package called spacetimeformer.
Command-line instructions for each experiment can be found using the format: python train.py *model* *dataset* -h
- linear: a basic autoregressive linear model. New June 2022: expanded to allow for seasonal decomposition and independent params for each variable (inspired by DLinear).
- lstnet: a more typical RNN/Conv1D model for multivariate forecasting. Based on the attention-free implementation of LSTNet.
- lstm: a typical encoder-decoder LSTM without attention. We use scheduled sampling to anneal teacher forcing throughout training.
- mtgnn: a hybrid GNN that learns its graph structure from data. For more information refer to the paper. We use the implementation from pytorch_geometric_temporal (requires some extra installation).
- s4: long-sequence state-space model (paper) (requires some extra installation).
- heuristic: simple heuristics like "repeat the last value in the context sequence" as a sanity-check.
- spacetimeformer: the multivariate long-range transformer architecture discussed in our paper.
  - Note that the "Temporal" ablation discussed in the paper is a special case of the spacetimeformer model. It is conceptually similar to Informer. Set embed_method = temporal. Spacetimeformer has many configurable options and we try to provide a thorough explanation with the command-line -h instructions.
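As an example of what a sanity-check baseline looks like, here is a minimal sketch of the "repeat the last value in the context sequence" heuristic named in the model list. The function name and data are hypothetical; the repo's heuristic model may be organized differently.

```python
from typing import List


def repeat_last_value(y_context: List[List[float]], target_points: int) -> List[List[float]]:
    """Forecast every target step as a copy of the final context step."""
    last = y_context[-1]
    return [list(last) for _ in range(target_points)]


# Hypothetical context: 3 timesteps, 2 variables.
y_context = [[11.0, 1.5], [12.0, 2.0], [13.0, 2.5]]
forecast = repeat_last_value(y_context, target_points=2)
```

A learned model that cannot beat this kind of forecast on a dataset is usually a sign of a data or training problem.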
- metr-la and pems-bay: traffic forecasting datasets. We use a very similar setup to DCRNN.
- precip: daily precipitation data from a latitude-longitude grid over the Continental United States.
- toy2: the toy dataset mentioned at the beginning of our experiments section. It is heavily based on the toy dataset in TPA-LSTM.
- asos: the codebase's name for what the paper calls "NY-TX Weather."
- solar_energy: the codebase's name for the time series benchmark more commonly called "AL Solar."
- exchange: a common time series benchmark dataset of exchange rates.
- weather: a common time series benchmark dataset of 21 weather indicators.
- ettm1: a common time series benchmark dataset of "electricity transformer temperatures" and related variables.
- mnist: highlights the similarity between multivariate forecasting and vision models by completing the right side of an MNIST digit given the left side, where each row is a different variable.
- cifar: a harder image completion task where the variables are color channels and the sequence is flattened across rows.
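The MNIST framing can be sketched as follows: image columns act as timesteps and image rows act as variables, so the left columns form the context and the right columns form the target. This is an illustrative sketch with a synthetic image; image_to_forecasting_task is a hypothetical helper and context_points=10 is only an example value.

```python
from typing import List, Tuple


def image_to_forecasting_task(
    image: List[List[int]], context_points: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (y_context, y_target), each shaped (timesteps, variables)."""
    n_cols = len(image[0])
    # Transpose so that columns become timesteps and rows become variables.
    columns = [[row[c] for row in image] for c in range(n_cols)]
    return columns[:context_points], columns[context_points:]


# Synthetic 28x28 "image" whose pixel value encodes its (row, column) position.
image = [[r * 100 + c for c in range(28)] for r in range(28)]
y_context, y_target = image_to_forecasting_task(image, context_points=10)
```

With a 28-column image and 10 context columns, the model is asked to complete the remaining 18 columns for all 28 rows.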
- copy: copy binary input sequences with rows shifted by varying amounts. An example of a hard task for Temporal attention that is easy for Spatiotemporal attention.
- cont_copy: a continuous version of the copy task with additional settings to study distribution shift.
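One possible reading of the copy task is sketched below: each variable (row) is a random binary sequence, and the target is the same row circularly shifted by its own amount. The generation details here are assumptions made for illustration, and make_copy_example is a hypothetical name; the dataset in the codebase may differ.

```python
import random
from typing import List, Tuple


def make_copy_example(
    n_vars: int, length: int, rng: random.Random
) -> Tuple[List[List[int]], List[List[int]], List[int]]:
    """Build (context, target, shifts) where each target row is a shifted copy."""
    context = [[rng.randint(0, 1) for _ in range(length)] for _ in range(n_vars)]
    shifts = [rng.randrange(length) for _ in range(n_vars)]
    # Rotate each row to the right by its own shift (a shift of 0 leaves it unchanged).
    target = [row[-s:] + row[:-s] for row, s in zip(context, shifts)]
    return context, target, shifts


rng = random.Random(0)
context, target, shifts = make_copy_example(n_vars=4, length=8, rng=rng)
```

Because every row moves by a different amount, a model that only sees all variables bundled per timestep has to untangle the per-row offsets, while per-variable tokens can attend directly to their own shifted positions.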
- m4: the M4 competition dataset (overview). A collection of 100k univariate series at various resolutions.
- wiki: the Wikipedia web traffic dataset from the Kaggle competition. 145k univariate high-entropy series at a single resolution.
- monash: loads the Monash Time Series Forecasting Archive. Up to ~400k univariate time series.

(We load these benchmarks in an unusual format where the context sequence is all data up until the current time, leading to variable-length sequences with padding.)
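Variable-length contexts are typically batched by padding to a common length and keeping a mask of the real entries. The sketch below shows that general pattern; left-padding, the pad value of 0.0, and the pad_contexts name are assumptions, not a description of this repo's loader.

```python
from typing import List, Tuple


def pad_contexts(
    contexts: List[List[float]], pad_value: float = 0.0
) -> Tuple[List[List[float]], List[List[bool]]]:
    """Left-pad each context to the longest length; mask is True for real values."""
    max_len = max(len(c) for c in contexts)
    padded, mask = [], []
    for c in contexts:
        n_pad = max_len - len(c)
        padded.append([pad_value] * n_pad + list(c))
        mask.append([False] * n_pad + [True] * len(c))
    return padded, mask


# Hypothetical batch: one series with 3 observed steps, one with a single step.
padded, mask = pad_contexts([[1.0, 2.0, 3.0], [4.0]])
```

The mask lets attention layers and loss functions ignore the padded positions.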
We used wandb to track all of our results during development, and you can do the same by providing your username and project as environment variables:
export STF_WANDB_ACCT="your_username"
export STF_WANDB_PROJ="your_project_title"
# optionally: change wandb logging directory (defaults to ./data/STF_LOG_DIR)
export STF_LOG_DIR="/somewhere/with/more/disk/space"
wandb logging can then be enabled with the --wandb flag.
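Before launching a long run it can be useful to confirm that these variables are set. The following is a small hypothetical check, not code from this repo; it only relies on the variable names and the ./data/STF_LOG_DIR default stated above.

```python
import os
from typing import Dict, Mapping, Optional


def wandb_settings(env: Optional[Mapping[str, str]] = None) -> Dict[str, object]:
    """Collect the STF_* settings and report which required ones are missing."""
    env = os.environ if env is None else env
    account = env.get("STF_WANDB_ACCT")
    project = env.get("STF_WANDB_PROJ")
    # The log directory is optional and falls back to the documented default.
    log_dir = env.get("STF_LOG_DIR", "./data/STF_LOG_DIR")
    required = (("STF_WANDB_ACCT", account), ("STF_WANDB_PROJ", project))
    missing = [name for name, value in required if not value]
    return {"account": account, "project": project, "log_dir": log_dir, "missing": missing}


# Example with an explicit mapping instead of the real environment.
settings = wandb_settings({"STF_WANDB_ACCT": "your_username"})
```

In this example settings["missing"] lists STF_WANDB_PROJ, which signals that the project name still needs to be exported.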
There are several figures that can be saved to wandb between epochs. These vary by dataset but can be enabled with --attn_plot (for Transformer attention diagrams) and --plot (for prediction plotting and image completion).
- Commands are listed without GPU counts. For one GPU, add --gpus 0; for three GPUs, add --gpus 0 1 2, etc. Some of these models require significant GPU memory (A100 80GBs). Other hyperparameter settings were used in older versions of the paper with more limited compute resources. If I have time I will try to update with competitive alternatives on smaller GPUs.
- Some datasets require a --data_path to the dataset location on disk. Others are included with the source code or downloaded automatically.
Linear autoregressive model with independent weights and seasonal decomposition (DLinear-style) on ETTm1:
python train.py linear ettm1 --context_points 288 --target_points 96 --run_name linear_ettm1_regression --gpus 0 --use_seasonal_decomp --linear_window 288 --data_path /path/to/ETTm1.csv
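For intuition about DLinear-style seasonal decomposition: the series is split into a moving-average trend and a remainder, and each part is modeled separately. The sketch below shows only that split, assuming an odd window and edge-value padding; it is not this repo's implementation, and moving_average_decomp is a hypothetical name.

```python
from typing import List, Tuple


def moving_average_decomp(
    series: List[float], kernel_size: int
) -> Tuple[List[float], List[float]]:
    """Split a series into (trend, seasonal) with trend + seasonal == series."""
    if kernel_size % 2 == 0:
        raise ValueError("this sketch assumes an odd kernel_size")
    half = kernel_size // 2
    # Repeat the edge values so the trend has the same length as the input.
    padded = [series[0]] * half + list(series) + [series[-1]] * half
    trend = [sum(padded[i:i + kernel_size]) / kernel_size for i in range(len(series))]
    seasonal = [value - t for value, t in zip(series, trend)]
    return trend, seasonal


# Hypothetical series and window size.
series = [1.0, 2.0, 3.0, 4.0, 5.0]
trend, seasonal = moving_average_decomp(series, kernel_size=3)
```

On a purely linear series like this one, the interior of the remainder is zero and only the padded edges deviate.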
Spacetimeformer on Pems-Bay (MAE: ~1.61):
python train.py spacetimeformer pems-bay --batch_size 32 --warmup_steps 1000 --d_model 200 --d_ff 700 --enc_layers 5 --dec_layers 6 --dropout_emb .1 --dropout_ff .3 --run_name pems-bay-spatiotemporal --base_lr 1e-3 --l2_coeff 1e-3 --loss mae --data_path /path/to/pems_bay/ --d_qk 30 --d_v 30 --n_heads 10 --patience 10 --decay_factor .8
Spacetimeformer on MNIST completion:
python train.py spacetimeformer mnist --embed_method spatio-temporal --local_self_attn full --local_cross_attn full --global_self_attn full --global_cross_attn full --run_name mnist_spatiotemporal --context_points 10
Spacetimeformer on AL Solar (MSE: ~7.75):
python train.py spacetimeformer solar_energy --context_points 168 --target_points 24 --d_model 100 --d_ff 400 --enc_layers 5 --dec_layers 5 --l2_coeff 1e-3 --dropout_ff .2 --dropout_emb .1 --d_qk 20 --d_v 20 --n_heads 6 --run_name spatiotemporal_al_solar --batch_size 32 --class_loss_imp 0 --initial_downsample_convs 1 --decay_factor .8 --warmup_steps 1000
More Coming Soon...
If you use this model in academic work, please feel free to cite our paper:
@misc{grigsby2021longrange,
title={Long-Range Transformers for Dynamic Spatiotemporal Forecasting},
author={Jake Grigsby and Zhe Wang and Yanjun Qi},
year={2021},
eprint={2109.12218},
archivePrefix={arXiv},
primaryClass={cs.LG}
}