Skip to content

Commit

Permalink
mstgcncommon (#145)
Browse files Browse the repository at this point in the history
  • Loading branch information
aptx1231 authored Aug 1, 2021
1 parent 03fa1c9 commit 69a98ee
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 2 deletions.
22 changes: 22 additions & 0 deletions libtraffic/config/model/traffic_state_pred/MSTGCNCommon.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"nb_block": 2,
"K": 3,
"nb_chev_filter": 64,
"nb_time_filter": 64,

"scaler": "standard",
"load_external": false,
"normal_external": false,
"ext_scaler": "none",
"add_time_in_day": false,
"add_day_in_week": false,
"train_rate": 0.6,
"eval_rate": 0.2,

"max_epoch": 100,
"learner": "adam",
"learning_rate": 0.0001,
"lr_decay": false,
"clip_grad_norm": false,
"use_early_stop": false
}
9 changes: 7 additions & 2 deletions libtraffic/config/task_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@
"traffic_state_pred": {
"allowed_model": ["DCRNN", "STGCN", "GWNET", "AGCRN", "TGCLSTM", "TGCN", "TemplateTSP",
"ASTGCN", "MSTGCN", "MTGNN", "ACFM", "STResNet", "RNN", "LSTM", "GRU", "AutoEncoder", "Seq2Seq",
"STResNetCommon", "ACFMCommon", "ASTGCNCommon", "ToGCN", "CONVGCN", "STG2Seq", "DMVSTNet",
"ATDM", "GMAN", "GTS", "STDN", "HGCN", "STSGCN", "STAGGCN", "STNN", "ResLSTM", "DGCN",
"STResNetCommon", "ACFMCommon", "ASTGCNCommon", "MSTGCNCommon","ToGCN", "CONVGCN", "STG2Seq",
"DMVSTNet", "ATDM", "GMAN", "GTS", "STDN", "HGCN", "STSGCN", "STAGGCN", "STNN", "ResLSTM", "DGCN",
"MultiSTGCnet", "STMGAT", "CRANN", "STTN", "CONVGCNCommon", "DSAN", "DKFN", "CCRNN"],
"allowed_dataset": ["METR_LA", "PEMS_BAY", "PEMSD3", "PEMSD4", "PEMSD7", "PEMSD8", "PEMSD7(M)",
"LOOP_SEATTLE", "LOS_LOOP", "LOS_LOOP_SMALL", "Q_TRAFFIC", "SZ_TAXI",
Expand Down Expand Up @@ -196,6 +196,11 @@
"executor": "TrafficStateExecutor",
"evaluator": "TrafficStateEvaluator"
},
"MSTGCNCommon": {
"dataset_class": "TrafficStatePointDataset",
"executor": "TrafficStateExecutor",
"evaluator": "TrafficStateEvaluator"
},
"ToGCN": {
"dataset_class": "TrafficStatePointDataset",
"executor": "TrafficStateExecutor",
Expand Down
215 changes: 215 additions & 0 deletions libtraffic/model/traffic_flow_prediction/MSTGCNCommon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from logging import getLogger
from libtraffic.model.abstract_traffic_state_model import AbstractTrafficStateModel
from libtraffic.model import loss
from scipy.sparse.linalg import eigs


def scaled_laplacian(weight):
    """
    Build the rescaled graph Laplacian (L-tilde) from an adjacency matrix.

    The combinatorial Laplacian ``D - W`` is divided by half of its
    largest-real-part eigenvalue and shifted by the identity, i.e.
    ``2 * L / lambda_max - I``.

    Args:
        weight(np.ndarray): shape is (N, N), N is the num of vertices

    Returns:
        np.ndarray: shape (N, N)
    """
    assert weight.shape[0] == weight.shape[1]
    num_vertices = weight.shape[0]
    # Degree of each vertex = row sum of the weight matrix.
    degree = np.sum(weight, axis=1)
    lap = np.diag(degree) - weight
    # eigs returns (eigenvalues, eigenvectors); only the eigenvalue with the
    # largest real part ('LR') is requested, and only its real part is kept.
    eigenvalues = eigs(lap, k=1, which='LR')[0]
    lambda_max = eigenvalues.real
    return 2 * lap / lambda_max - np.identity(num_vertices)


def cheb_polynomial(l_tilde, k):
    """
    Compute the Chebyshev polynomials of the scaled Laplacian, T_0 .. T_{K-1}.

    Uses the recurrence ``T_0 = I``, ``T_1 = L``,
    ``T_i = 2 * L @ T_{i-1} - T_{i-2}`` where ``@`` is the *matrix* product.

    Args:
        l_tilde(np.ndarray): scaled Laplacian, shape (N, N)
        k(int): the maximum order of chebyshev polynomials

    Returns:
        list(np.ndarray): cheb_polynomials from T_0 to T_{K-1}. The list always
            contains at least T_0 and T_1, so it has length ``max(k, 2)``.
    """
    num = l_tilde.shape[0]
    # T_1 is a copy so callers mutating the result never touch l_tilde itself.
    cheb_polynomials = [np.identity(num), l_tilde.copy()]
    for i in range(2, k):
        # The recurrence needs a matrix product; `*` on ndarrays is
        # element-wise and would yield a wrong T_i for every i >= 2.
        cheb_polynomials.append(
            2 * (l_tilde @ cheb_polynomials[i - 1]) - cheb_polynomials[i - 2])
    return cheb_polynomials


class ChebConv(nn.Module):
    """
    K-order Chebyshev graph convolution applied independently at every time step.
    """

    def __init__(self, k, cheb_polynomials, in_channels, out_channels):
        """
        Args:
            k(int): number of Chebyshev terms (one weight matrix per term)
            cheb_polynomials: indexable collection of tensors; entry ``i`` is
                used as T_i and entry 0 determines the device
            in_channels(int): num of channels in the input sequence
            out_channels(int): num of channels in the output sequence
        """
        super().__init__()
        self.K = k
        self.cheb_polynomials = cheb_polynomials
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.DEVICE = cheb_polynomials[0].device
        # One (in_channels, out_channels) weight per polynomial order. The
        # tensors are allocated uninitialised; nothing in this class fills them.
        weights = []
        for _ in range(k):
            raw = torch.FloatTensor(in_channels, out_channels).to(self.DEVICE)
            weights.append(nn.Parameter(raw))
        self.Theta = nn.ParameterList(weights)

    def forward(self, x):
        """
        Chebyshev graph convolution operation

        Args:
            x: (batch_size, N, F_in, T)

        Returns:
            torch.tensor: (batch_size, N, F_out, T)
        """
        batch_size, num_of_vertices, _, num_of_timesteps = x.shape

        per_step = []
        for step in range(num_of_timesteps):
            signal = x[..., step]  # (b, N, F_in)
            acc = torch.zeros(batch_size, num_of_vertices, self.out_channels).to(self.DEVICE)  # (b, N, F_out)
            for order in range(self.K):
                poly = self.cheb_polynomials[order]  # (N, N)
                theta = self.Theta[order]  # (F_in, F_out)
                # Mix across vertices: (b, F_in, N) @ (N, N), then back to (b, N, F_in).
                mixed = signal.transpose(1, 2).matmul(poly).transpose(1, 2)
                acc = acc + mixed.matmul(theta)
            per_step.append(acc)

        # Stack the per-time-step results on a new trailing axis, then activate.
        return F.relu(torch.stack(per_step, dim=-1))


class MSTGCNBlock(nn.Module):
    """
    One spatial-temporal block: Chebyshev graph convolution, a convolution
    along the time axis, a 1x1 residual shortcut, ReLU and layer normalisation.
    """

    def __init__(self, in_channels, k, nb_chev_filter, nb_time_filter, time_strides, cheb_polynomials):
        """
        Args:
            in_channels(int): num of input channels per vertex
            k(int): number of Chebyshev terms passed to ChebConv
            nb_chev_filter(int): output channels of the graph convolution
            nb_time_filter(int): output channels of the time convolution and shortcut
            time_strides(int): stride used along the last axis by both convolutions
            cheb_polynomials: Chebyshev polynomial tensors passed to ChebConv
        """
        super().__init__()
        temporal_stride = (1, time_strides)
        self.ChebConv = ChebConv(k, cheb_polynomials, in_channels, nb_chev_filter)
        self.time_conv = nn.Conv2d(
            nb_chev_filter, nb_time_filter,
            kernel_size=(1, 3), stride=temporal_stride, padding=(0, 1))
        # The shortcut uses the same stride so both branches stay shape-compatible.
        self.residual_conv = nn.Conv2d(
            in_channels, nb_time_filter, kernel_size=(1, 1), stride=temporal_stride)
        self.ln = nn.LayerNorm(nb_time_filter)

    def forward(self, x):
        """
        Args:
            x: (batch_size, N, F_in, T)

        Returns:
            torch.tensor: (batch_size, N, nb_time_filter, T'), where T' is the
                length left after the strided time convolution
        """
        # Graph convolution over the vertex axis.
        gcn_out = self.ChebConv(x)  # (b, N, F, T)

        # Channels first for Conv2d, which then slides along the time axis.
        temporal = self.time_conv(gcn_out.permute(0, 2, 1, 3))  # (b, F, N, T')

        # Residual shortcut projected to the same channel count.
        shortcut = self.residual_conv(x.permute(0, 2, 1, 3))  # (b, F, N, T')

        activated = F.relu(shortcut + temporal)
        # LayerNorm acts on the last axis, so move channels there first.
        normed = self.ln(activated.permute(0, 3, 2, 1))  # (b, T', N, F)
        return normed.permute(0, 2, 3, 1)  # (b, N, F, T')


class MSTGCNSubmodule(nn.Module):
    """
    Stack of MSTGCNBlock layers followed by a final convolution that produces
    the prediction.
    """

    def __init__(self, device, nb_block, in_channels, k, nb_chev_filter, nb_time_filter,
                 input_window, cheb_polynomials, output_window, output_dim, num_of_vertices):
        """
        Args:
            device: accepted but not used in this class
            nb_block(int): total number of MSTGCNBlock layers
            in_channels(int): num of input feature channels
            k(int): number of Chebyshev terms
            nb_chev_filter(int): graph-convolution output channels
            nb_time_filter(int): time-convolution output channels
            input_window(int): num of input time steps
            cheb_polynomials: Chebyshev polynomial tensors
            output_window(int): num of predicted time steps
            output_dim(int): num of predicted features
            num_of_vertices(int): accepted but not used in this class
        """
        super().__init__()

        # Only the first block is strided; it reduces the time axis by
        # input_window // output_window.
        first_stride = input_window // output_window
        blocks = [MSTGCNBlock(in_channels, k, nb_chev_filter, nb_time_filter,
                              first_stride, cheb_polynomials)]
        for _ in range(nb_block - 1):
            blocks.append(MSTGCNBlock(nb_time_filter, k, nb_chev_filter,
                                      nb_time_filter, 1, cheb_polynomials))
        self.BlockList = nn.ModuleList(blocks)

        # A kernel of this width shrinks an axis of length nb_time_filter to output_dim.
        kernel_width = nb_time_filter - output_dim + 1
        self.final_conv = nn.Conv2d(output_window, output_window,
                                    kernel_size=(1, kernel_width))

    def forward(self, x):
        """
        Args:
            x: (B, T_in, N_nodes, F_in)

        Returns:
            torch.tensor: (B, T_out, N_nodes, out_dim)
        """
        hidden = x.permute(0, 2, 3, 1)  # (B, N, F_in(feature_dim), T_in)
        for block in self.BlockList:
            hidden = block(hidden)
        # Move the time axis into the Conv2d channel position.
        return self.final_conv(hidden.permute(0, 3, 1, 2))


# Adapted to the most general TrafficStateGridDataset and TrafficStatePointDataset
class MSTGCNCommon(AbstractTrafficStateModel):
    """
    MSTGCN traffic-state model: builds Chebyshev polynomials from the adjacency
    matrix in ``data_feature`` and wraps a single MSTGCNSubmodule.
    """

    def __init__(self, config, data_feature):
        """
        Args:
            config: mapping-like object read with ``.get`` for hyper-parameters
            data_feature: passed to the base class; read back via
                ``self.data_feature`` for 'num_nodes', 'feature_dim',
                'output_dim', 'adj_mx' and 'scaler'
        """
        super().__init__(config, data_feature)

        feature = self.data_feature
        self.num_nodes = feature.get('num_nodes', 1)
        self.feature_dim = feature.get('feature_dim', 1)
        self.output_dim = feature.get('output_dim', 1)

        self.input_window = config.get('input_window', 1)
        self.output_window = config.get('output_window', 1)
        self.device = config.get('device', torch.device('cpu'))
        self.nb_block = config.get('nb_block', 2)
        self.K = config.get('K', 3)
        self.nb_chev_filter = config.get('nb_chev_filter', 64)
        self.nb_time_filter = config.get('nb_time_filter', 64)

        # Adjacency -> scaled Laplacian -> Chebyshev polynomials, each converted
        # to a float tensor on the configured device.
        laplacian = scaled_laplacian(feature.get('adj_mx'))
        cheb_arrays = cheb_polynomial(laplacian, self.K)
        self.cheb_polynomials = [
            torch.from_numpy(arr).type(torch.FloatTensor).to(self.device)
            for arr in cheb_arrays
        ]
        self._logger = getLogger()
        self._scaler = feature.get('scaler')

        self.MSTGCN_submodule = MSTGCNSubmodule(
            device=self.device,
            nb_block=self.nb_block,
            in_channels=self.feature_dim,
            k=self.K,
            nb_chev_filter=self.nb_chev_filter,
            nb_time_filter=self.nb_time_filter,
            input_window=self.input_window,
            cheb_polynomials=self.cheb_polynomials,
            output_window=self.output_window,
            output_dim=self.output_dim,
            num_of_vertices=self.num_nodes,
        )
        self._init_parameters()

    def _init_parameters(self):
        """Re-initialise every parameter: Xavier-uniform for tensors with more
        than one dimension, plain uniform otherwise."""
        for param in self.parameters():
            if param.dim() <= 1:
                nn.init.uniform_(param)
            else:
                nn.init.xavier_uniform_(param)

    def forward(self, batch):
        """Run the submodule on ``batch['X']`` after moving it to the model device."""
        inputs = batch['X'].to(self.device)  # (B, T, N_nodes, F_in)
        return self.MSTGCN_submodule(inputs)  # (B, T', N_nodes, F_out)

    def calculate_loss(self, batch):
        """Return the masked MSE between prediction and target, both
        inverse-transformed by the scaler and cut to the first ``output_dim``
        features."""
        target = batch['y']
        prediction = self.predict(batch)
        target = self._scaler.inverse_transform(target[..., :self.output_dim])
        prediction = self._scaler.inverse_transform(prediction[..., :self.output_dim])
        return loss.masked_mse_torch(prediction, target)

    def predict(self, batch):
        """Alias for ``forward``."""
        return self.forward(batch)
2 changes: 2 additions & 0 deletions libtraffic/model/traffic_flow_prediction/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from libtraffic.model.traffic_flow_prediction.STResNetCommon import STResNetCommon
from libtraffic.model.traffic_flow_prediction.ACFMCommon import ACFMCommon
from libtraffic.model.traffic_flow_prediction.ASTGCNCommon import ASTGCNCommon
from libtraffic.model.traffic_flow_prediction.MSTGCNCommon import MSTGCNCommon
from libtraffic.model.traffic_flow_prediction.ToGCN import ToGCN
from libtraffic.model.traffic_flow_prediction.CONVGCN import CONVGCN
from libtraffic.model.traffic_flow_prediction.STDN import STDN
Expand All @@ -27,6 +28,7 @@
"STResNetCommon",
"ACFMCommon",
"ASTGCNCommon",
"MSTGCNCommon",
"ToGCN",
"CONVGCN",
"STDN",
Expand Down

0 comments on commit 69a98ee

Please sign in to comment.