Commit

Merge pull request #112 from point-cloud-radar/implement-plotting-of-iterations
lyashevska authored Jul 3, 2023
2 parents d9a6758 + ec42087 commit cc78cce
Showing 5 changed files with 184 additions and 0 deletions.
44 changes: 44 additions & 0 deletions bird_cloud_gnn/callback.py
@@ -0,0 +1,44 @@
from torch.utils.tensorboard import SummaryWriter
from bird_cloud_gnn.early_stopper import EarlyStopper


class TensorboardCallback:
"""Callback to populate Tensorboard"""

def __init__(self):
self.writer = SummaryWriter()

def __call__(self, epoch_values):
epoch = epoch_values["epoch"]
for field in ["Loss/train", "Loss/test", "Accuracy/train", "Accuracy/test"]:
self.writer.add_scalar(field, epoch_values[field], epoch)
return False


class EarlyStopperCallback:
"""Callback to check early stopping."""

def __init__(self, **kwargs):
"""Input arguments are passed to EarlyStopper."""
self.early_stopper = EarlyStopper(**kwargs)

def __call__(self, epoch_values):
return self.early_stopper.early_stop(epoch_values["Loss/test"])


class CombinedCallback:
"""Helper to combine multiple callbacks."""

def __init__(self, callbacks):
"""
Args:
callbacks (iterable): List of callbacks. These are called in the given sequence, and
once one of them returns True, the subsequent ones are not called.
"""
self.callbacks = callbacks

def __call__(self, epoch_values):
return_value = False
for callback in self.callbacks:
return_value = return_value or callback(epoch_values)
return return_value
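
For orientation, here is a minimal usage sketch of the callbacks introduced in callback.py. The epoch_values dictionary is hand-written to mimic what fit_and_evaluate passes each epoch, and the numbers are illustrative only. TensorboardCallback's SummaryWriter logs to ./runs by default, so the curves can then be viewed with "tensorboard --logdir runs".

from bird_cloud_gnn.callback import (
    CombinedCallback,
    EarlyStopperCallback,
    TensorboardCallback,
)

# Callbacks are called in order; once one returns True, the rest are skipped.
callback = CombinedCallback(
    [
        TensorboardCallback(),
        EarlyStopperCallback(patience=3),
    ]
)

# Hand-crafted epoch summary using the keys documented in fit_and_evaluate.
epoch_values = {
    "epoch": 0,
    "Loss/train": 0.71,
    "Accuracy/train": 0.55,
    "Loss/test": 0.80,
    "Accuracy/test": 0.52,
}
stop = callback(epoch_values)  # False: the first epoch only sets the loss baseline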
60 changes: 60 additions & 0 deletions bird_cloud_gnn/early_stopper.py
@@ -0,0 +1,60 @@
"""Module for early stopping class
"""
import numpy as np


class EarlyStopper:
"""Early stopper check."""

def __init__(self, patience=3, min_abs_delta=1e-2, min_rel_delta=0.0):
"""EarlyStopper. Use to stop if the validation loss starts increasing.
The validation loss is increasing if
L > Lmin + abs_delta + rel_delta * |Lmin|,
where `L` is the current validation loss, `Lmin` is the minimum validation loss found so
far, and `abs_delta` and `rel_delta` are absolute and relative tolerances to the increase,
respectively.
Args:
patience (int, optional): How many consecutive iterations to wait before stopping.
Defaults to 3.
min_abs_delta (float, optional): Absolute tolerance to the increase. Defaults to 1e-2.
min_rel_delta (float, optional): Relative tolerance to the increase. Defaults to 0.0.
"""
self.patience = patience
self.min_abs_delta = min_abs_delta
self.min_rel_delta = min_rel_delta
self.counter = 0
self.min_validation_loss = np.inf

def early_stop(self, validation_loss):
"""Check whether it is time to stop, and update the internal of EarlyStopper.
Args:
validation_loss (float): Current validation loss
Returns:
stop (boolean): Whether it is time to stop (True) or not (False).
"""

if validation_loss < self.min_validation_loss:
self.min_validation_loss = validation_loss
self.counter = 0
return False

if self.min_validation_loss is np.inf:
return False

loss_threshold = (
self.min_validation_loss
+ self.min_abs_delta
+ self.min_rel_delta * np.abs(self.min_validation_loss)
)

if validation_loss > loss_threshold:
self.counter += 1
if self.counter >= self.patience:
return True
return False
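
As a quick worked illustration of the stopping criterion above (with patience=3, min_abs_delta=1e-2, min_rel_delta=0.0, and a made-up loss sequence):

from bird_cloud_gnn.early_stopper import EarlyStopper

stopper = EarlyStopper(patience=3, min_abs_delta=1e-2)

# The first call sets Lmin = 0.50, so the threshold becomes 0.50 + 0.01 = 0.51.
assert stopper.early_stop(0.50) is False
# Each subsequent loss above 0.51 increments the counter; patience is reached on the third.
assert stopper.early_stop(0.52) is False  # counter = 1
assert stopper.early_stop(0.55) is False  # counter = 2
assert stopper.early_stop(0.56) is True   # counter = 3 -> stop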
68 changes: 68 additions & 0 deletions bird_cloud_gnn/gnn_model.py
@@ -8,6 +8,7 @@
from dgl.nn.pytorch.conv import GraphConv
from torch import nn
from torch import optim
from tqdm import tqdm


os.environ["DGLBACKEND"] = "pytorch"
@@ -102,6 +103,73 @@ def evaluate(self, test_dataloader):
accuracy = num_correct / num_tests
return accuracy

# pylint: disable=too-many-arguments
def fit_and_evaluate(
self,
train_dataloader,
test_dataloader,
callback=None,
learning_rate=0.01,
num_epochs=20,
):
"""Fit the model while evaluating every iteraction.
Args:
train_dataloader (GraphDataLoader): Data loader for the training set.
test_dataloader (GraphDataLoader): Data loader for the test set.
callback (callable, optional): Callback function. If defined, should receive a dict
that stores "Loss/train", "Accuracy/train", "Loss/test", "Accuracy/test", and
"epoch" of a single epoch. To send a stop signal, return True.
Defaults to None.
learning_rate (float, optional): Learning rate. Defaults to 0.01.
num_epochs (int, optional): Number of training epochs. Defaults to 20.
"""
progress_bar = tqdm(total=num_epochs)
optimizer = optim.Adam(self.parameters(), lr=learning_rate)
epoch_values = {}
for epoch in range(num_epochs):
epoch_values["epoch"] = epoch
train_loss = 0.0
num_correct = 0
num_total = 0
self.train()
for batched_graph, labels in train_dataloader:
pred = self(batched_graph, batched_graph.ndata["x"].float())
loss = nn.functional.cross_entropy(pred, labels)

train_loss += loss.item()
num_correct += (pred.argmax(1) == labels).sum().item()
num_total += len(labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()

epoch_values["Loss/train"] = train_loss
epoch_values["Accuracy/train"] = num_correct / num_total

test_loss = 0.0
num_correct = 0
num_total = 0
self.eval()
for batched_graph, labels in test_dataloader:
pred = self(batched_graph, batched_graph.ndata["x"].float())

test_loss += nn.functional.cross_entropy(pred, labels).item()
num_correct += (pred.argmax(1) == labels).sum().item()
num_total += len(labels)

epoch_values["Loss/test"] = test_loss
epoch_values["Accuracy/test"] = num_correct / num_total

progress_bar.set_postfix({"Epoch": epoch})
progress_bar.update(1)

if callback is not None:
user_request_stop = callback(epoch_values)
if user_request_stop is True: # Check for explicit True
break

def infer(self, dataset, batch_size=1024):
"""
Use the model to do inference on a dataset.
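
Because the callback contract in fit_and_evaluate is just a callable that receives the per-epoch dictionary and returns True to request a stop, custom stopping rules can be plugged in without touching the training loop. A minimal sketch, assuming a hypothetical accuracy target of 0.9 (this class is not part of the package):

class AccuracyTargetCallback:
    """Stop training once the test accuracy reaches a target value."""

    def __init__(self, target=0.9):
        self.target = target

    def __call__(self, epoch_values):
        # Returning True signals fit_and_evaluate to break out of the epoch loop.
        return epoch_values["Accuracy/test"] >= self.target

# model.fit_and_evaluate(train_dataloader, test_dataloader,
#                        callback=AccuracyTargetCallback(target=0.9))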
1 change: 1 addition & 0 deletions setup.cfg
@@ -40,6 +40,7 @@ install_requires =
pandas
scipy
torch
tensorboard

[options.data_files]
# This section requires setuptools>=40.6.0
11 changes: 11 additions & 0 deletions tests/test_gnn_model.py
@@ -2,6 +2,9 @@
import torch
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from bird_cloud_gnn.callback import CombinedCallback
from bird_cloud_gnn.callback import EarlyStopperCallback
from bird_cloud_gnn.callback import TensorboardCallback
from bird_cloud_gnn.gnn_model import GCN


@@ -31,6 +34,14 @@ def test_gnn_model(dataset_fixture):
model.fit(train_dataloader)
model.evaluate(test_dataloader)

callback = CombinedCallback(
[
TensorboardCallback(),
EarlyStopperCallback(patience=3),
]
)
model.fit_and_evaluate(train_dataloader, test_dataloader, callback)

assert len(model.infer(dataset_fixture, batch_size=30)) == len(dataset_fixture)
assert (
(model.infer(dataset_fixture) == 1) | (model.infer(dataset_fixture) == 0)
