-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #112 from point-cloud-radar/implement-plotting-of-…
…iterations
- Loading branch information
Showing
5 changed files
with
184 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from torch.utils.tensorboard import SummaryWriter | ||
from bird_cloud_gnn.early_stopper import EarlyStopper | ||
|
||
|
||
class TensorboardCallback: | ||
"""Callback to populate Tensorboard""" | ||
|
||
def __init__(self): | ||
self.writer = SummaryWriter() | ||
|
||
def __call__(self, epoch_values): | ||
epoch = epoch_values["epoch"] | ||
for field in ["Loss/train", "Loss/test", "Accuracy/train", "Accuracy/test"]: | ||
self.writer.add_scalar(field, epoch_values[field], epoch) | ||
return False | ||
|
||
|
||
class EarlyStopperCallback: | ||
"""Callback to check early stopping.""" | ||
|
||
def __init__(self, **kwargs): | ||
"""Input arguments are passed to EarlyStopper.""" | ||
self.early_stopper = EarlyStopper(**kwargs) | ||
|
||
def __call__(self, epoch_values): | ||
return self.early_stopper.early_stop(epoch_values["Loss/test"]) | ||
|
||
|
||
class CombinedCallback: | ||
"""Helper to combine multiple callbacks.""" | ||
|
||
def __init__(self, callbacks): | ||
""" | ||
Args: | ||
callbacks (iterable): List of callbacks. These are called in the given sequence and | ||
if one of them returns True, the subsequents are not called. | ||
""" | ||
self.callbacks = callbacks | ||
|
||
def __call__(self, epoch_values): | ||
return_value = False | ||
for callback in self.callbacks: | ||
return_value = return_value or callback(epoch_values) | ||
return return_value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
"""Module for early stopping class | ||
""" | ||
import numpy as np | ||
|
||
|
||
class EarlyStopper: | ||
"""Early stopper check.""" | ||
|
||
def __init__(self, patience=3, min_abs_delta=1e-2, min_rel_delta=0.0): | ||
"""EarlyStopper. Use to stop if the validation loss starts increasing. | ||
The validation loss is increasing if | ||
L > Lmin + abs_delta + rel_delta * |Lmin|, | ||
where `L` is the current validation loss, `Lmin` is the minimum validation loss found so | ||
far, and `abs_delta` and `rel_delta` are absolute and relative tolerances to the increase, | ||
respectively. | ||
Args: | ||
patience (int, optional): How many consecutive iterations to wait before stopping. | ||
Defaults to 3. | ||
min_abs_delta (float, optional): Absolute tolerance to the increase. Defaults to 1e-2. | ||
min_rel_delta (float, optional): Relative tolerance to the increase. Defaults to 0.0. | ||
""" | ||
self.patience = patience | ||
self.min_abs_delta = min_abs_delta | ||
self.min_rel_delta = min_rel_delta | ||
self.counter = 0 | ||
self.min_validation_loss = np.inf | ||
|
||
def early_stop(self, validation_loss): | ||
"""Check whether it is time to stop, and update the internal of EarlyStopper. | ||
Args: | ||
validation_loss (float): Current validation loss | ||
Returns: | ||
stop (boolean): Whether it is time to stop (True) or not (False). | ||
""" | ||
|
||
if validation_loss < self.min_validation_loss: | ||
self.min_validation_loss = validation_loss | ||
self.counter = 0 | ||
return False | ||
|
||
if self.min_validation_loss is np.inf: | ||
return False | ||
|
||
loss_threshold = ( | ||
self.min_validation_loss | ||
+ self.min_abs_delta | ||
+ self.min_rel_delta * np.abs(self.min_validation_loss) | ||
) | ||
|
||
if validation_loss > loss_threshold: | ||
self.counter += 1 | ||
if self.counter >= self.patience: | ||
return True | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters