diff --git a/bird_cloud_gnn/cross_validation.py b/bird_cloud_gnn/cross_validation.py index b7bf797..6bdc7c9 100644 --- a/bird_cloud_gnn/cross_validation.py +++ b/bird_cloud_gnn/cross_validation.py @@ -32,7 +32,7 @@ def get_dataloaders(dataset, train_idx, test_idx, batch_size): # pylint: disable=too-many-arguments, too-many-locals def kfold_evaluate( dataset, - h_feats=16, + layers_data, n_splits=5, learning_rate=0.01, num_epochs=100, @@ -43,7 +43,7 @@ def kfold_evaluate( Args: dataset (RadarDataset): The dataset - h_feats (int, optional): The number of hidden features of the model + layers_data (list): The list of input size and activation n_splits (int, optional): Number of folds. Defaults to 5. learning_rate (float, optional): Learning rate. Defaults to 0.01. num_epochs (int, optional): Training epochs. Defaults to 20. @@ -62,8 +62,7 @@ def kfold_evaluate( model = GCN( in_feats=len(dataset.features), - h_feats=h_feats, - num_classes=2, + layers_data=layers_data, ) model.fit(train_dataloader, learning_rate=learning_rate, num_epochs=num_epochs) @@ -76,7 +75,7 @@ def kfold_evaluate( def leave_one_origin_out_evaluate( dataset, - h_feats=16, + layers_data, learning_rate=0.01, num_epochs=100, batch_size=512, @@ -88,7 +87,7 @@ def leave_one_origin_out_evaluate( Args: dataset (RadarDataset): The dataset. - h_feats (int, optional): The number of hidden features of the model + layers_data (list): The list of input size and activation n_splits (int, optional): Number of folds. Defaults to 5. learning_rate (float, optional): Learning rate. Defaults to 0.01. num_epochs (int, optional): Training epochs. Defaults to 20. @@ -110,8 +109,7 @@ def leave_one_origin_out_evaluate( model = GCN( in_feats=len(dataset.features), - h_feats=h_feats, - num_classes=2, + layers_data=layers_data, ) model.fit(train_dataloader, learning_rate=learning_rate, num_epochs=num_epochs) diff --git a/bird_cloud_gnn/gnn_model.py b/bird_cloud_gnn/gnn_model.py index e25ab2b..4698139 100644 --- a/bird_cloud_gnn/gnn_model.py +++ b/bird_cloud_gnn/gnn_model.py @@ -3,11 +3,11 @@ import os import dgl import numpy as np -import torch.nn.functional as F from dgl.dataloading import GraphDataLoader from dgl.nn.pytorch.conv import GraphConv from torch import nn from torch import optim +from torch.nn.modules import Module from tqdm import tqdm @@ -17,11 +17,11 @@ class GCN(nn.Module): """Graph Convolutional Network construction module - A two-layer GCN is constructed from input dimension, hidden dimensions and number of classes. - Each layer computes new node representations by aggregating neighbor information. + A n-layer GCN is constructed from input features and list of layers + Each layer computes new node representations by aggregating neighbour information. """ - def __init__(self, in_feats: int, h_feats: int, num_classes: int): + def __init__(self, in_feats: int, layers_data: list): """ The __init__ function is the constructor for a class. It is called when an object of that class is instantiated. It can have multiple arguments and it will always be called before __new__(). @@ -30,32 +30,32 @@ def __init__(self, in_feats: int, h_feats: int, num_classes: int): Args: self: Access variables that belongs to the class object in_feats: the number of input features - h_feats: the number of hidden features that we want to use for our first graph convolutional layer - num_classes: the number of classes that we want to predict + layers_data: is a list of tuples of size of hidden layer and activation function Returns: The self object """ super().__init__() self.in_feats = in_feats - self.h_feats = h_feats - self.num_classes = num_classes - self.conv1 = GraphConv(in_feats, h_feats) - self.conv2 = GraphConv(h_feats, num_classes) + self.layers = nn.ModuleList() + self.name = "" + for size, activation in layers_data: + self.layers.append(GraphConv(in_feats, size)) + self.name = self.name + f"{in_feats}-{size}_" + in_feats = size # For the next layer + if activation is not None: + assert isinstance( + activation, Module + ), "Each tuples should contain a size (int) and a torch.nn.modules.Module." + self.layers.append(activation) + self.name = self.name + repr(activation).split("(", 1)[0] + "_" + self.num_classes = size # the last size should correspond to the number of classes were predicting def oneline_description(self): """Description of the model to uniquely identify it in logs""" - return "-".join( - [ - "in", - f"GC_{self.h_feats}", - "RELU", - f"GC_{self.num_classes}", - "mean-out", - ] - ) + return "-".join(["in_", f"{self.name}", "mean-out"]) - def forward(self, g, in_feat): + def forward(self, g, in_feats): """ The forward function computes the output of the model. @@ -67,10 +67,13 @@ def forward(self, g, in_feat): Returns: The output of the second convolutional layer """ - h = self.conv1(g, in_feat) - h = F.relu(h) - h = self.conv2(g, h) - g.ndata["h"] = h + for layer in self.layers: + if isinstance(layer, (nn.ReLU, nn.LeakyReLU, nn.ELU)): + in_feats = layer(in_feats) + else: + in_feats = layer(g, in_feats) + + g.ndata["h"] = in_feats return dgl.mean_nodes(g, "h") def fit(self, train_dataloader, learning_rate=0.01, num_epochs=20): @@ -220,10 +223,13 @@ def fit_and_evaluate( epoch_values["Loss/test"] = test_loss epoch_values["Accuracy/test"] = num_correct / num_total - epoch_values["Layer/conv1"] = self.conv1.weight.detach() - epoch_values["Layer/conv2"] = self.conv2.weight.detach() + for i, pg in enumerate(optimizer.param_groups): epoch_values[f"LearningRate/ParGrp{i}"] = pg["lr"] + # to visualise distribution of tensors + for i, layer in enumerate(self.layers): + if not isinstance(layer, (nn.ReLU, nn.LeakyReLU, nn.ELU)): + epoch_values[f"Layer/conv{i}"] = layer.weight.detach() if self.num_classes == 2: epoch_values["FalseNegativeRate/test"] = num_false_negative / num_total epoch_values["FalsePositiveRate/test"] = num_false_positive / num_total @@ -251,10 +257,7 @@ def infer(self, dataset, batch_size=1024): """ self.eval() dataloader = GraphDataLoader( - shuffle=False, - dataset=dataset, - batch_size=batch_size, - drop_last=False, + shuffle=False, dataset=dataset, batch_size=batch_size, drop_last=False ) labels = np.array([]) for batched_graph, _ in dataloader: diff --git a/tests/test_cross_validation.py b/tests/test_cross_validation.py index 6062338..954e664 100644 --- a/tests/test_cross_validation.py +++ b/tests/test_cross_validation.py @@ -1,4 +1,5 @@ """Tests for cross_validation""" +from torch import nn from bird_cloud_gnn.cross_validation import kfold_evaluate from bird_cloud_gnn.cross_validation import leave_one_origin_out_evaluate @@ -8,7 +9,7 @@ def test_kfold_evaluate(dataset_fixture): kfold_evaluate( dataset_fixture, - h_feats=32, + layers_data=[(32, nn.ReLU()), (2, None)], ) @@ -17,5 +18,5 @@ def test_leave_one_out_evaluate(dataset_fixture): leave_one_origin_out_evaluate( dataset_fixture, - h_feats=32, + layers_data=[(32, nn.ReLU()), (2, None)], ) diff --git a/tests/test_gnn_model.py b/tests/test_gnn_model.py index b3509c0..aa95440 100644 --- a/tests/test_gnn_model.py +++ b/tests/test_gnn_model.py @@ -1,6 +1,7 @@ """Tests for gnn_model module""" import torch from dgl.dataloading import GraphDataLoader +from torch import nn from torch.utils.data.sampler import SubsetRandomSampler from bird_cloud_gnn.callback import CombinedCallback from bird_cloud_gnn.callback import EarlyStopperCallback @@ -30,7 +31,7 @@ def test_gnn_model(dataset_fixture): drop_last=False, ) - model = GCN(len(dataset_fixture.features), 16, 2) + model = GCN(len(dataset_fixture.features), [(16, nn.ReLU()), (2, None)]) model.fit(train_dataloader) model.evaluate(test_dataloader) @@ -53,13 +54,20 @@ class TestBasicBehaviour: def test_field_access(self): """Test field access""" - model = GCN(in_feats=10, h_feats=16, num_classes=2) + model = GCN(in_feats=10, layers_data=[(16, nn.ReLU()), (2, None)]) assert model.in_feats == 10 - assert model.h_feats == 16 + assert model.name == "10-16_ReLU_16-2_" assert model.num_classes == 2 def test_inequality(self): """Test inequality of created GCN classes""" - model1 = GCN(in_feats=10, h_feats=16, num_classes=2) - model2 = GCN(in_feats=15, h_feats=16, num_classes=5) + model1 = GCN(in_feats=10, layers_data=[(16, nn.ReLU()), (2, None)]) + model2 = GCN(in_feats=15, layers_data=[(16, nn.ReLU()), (2, None)]) assert model1 != model2 + + def test_inequality_activation(self): + """Test inequality of created GCN classes with different activation""" + model1 = GCN(in_feats=10, layers_data=[(16, nn.ReLU()), (2, None)]) + model2 = GCN(in_feats=10, layers_data=[(16, nn.ELU()), (2, None)]) + assert model1 != model2 + assert model2.name == "10-16_ELU_16-2_"