
Commit

Merge branch 'main' into dev-aimnet2-new
wiederm authored Nov 7, 2024
2 parents 65d2b43 + fa4814f, commit 9115792
Showing 1 changed file with 31 additions and 33 deletions.
64 changes: 31 additions & 33 deletions modelforge/train/training.py
@@ -664,10 +664,8 @@ def training_step(
             if self.training_parameter.log_norm:
                 if key == "total_loss":
                     continue  # Skip total loss for gradient norm logging
-
                 grad_norm = compute_grad_norm(metric.mean(), self)
-                log.info(f"grad_norm/{key}: {grad_norm}")
-                self.log(f"grad_norm/{key}", grad_norm)
+                self.log(f"grad_norm/{key}", grad_norm, sync_dist=True)

         # Save energy predictions and targets
         self._update_predictions(
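Aside on the helper used above: compute_grad_norm(metric.mean(), self) reduces one loss component to a scalar and measures how large a gradient it induces in the model parameters. Below is a minimal sketch of such a helper, assuming it returns the global L2 norm of that gradient; the actual modelforge implementation may differ.

import torch

def compute_grad_norm(scalar: torch.Tensor, module: torch.nn.Module) -> float:
    """Sketch: global L2 norm of d(scalar)/d(params). Not modelforge's code."""
    params = [p for p in module.parameters() if p.requires_grad]
    # retain_graph=True keeps the autograd graph alive for the real backward
    # pass; allow_unused=True tolerates parameters this term never touches.
    grads = torch.autograd.grad(scalar, params, retain_graph=True, allow_unused=True)
    norms = [g.norm(2) for g in grads if g is not None]
    return torch.stack(norms).norm(2).item() if norms else 0.0

Logging the result via self.log(..., sync_dist=True), as the new line does, tells Lightning to reduce the per-rank values across processes before they reach the logger.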
@@ -1159,7 +1157,9 @@ def _log_figures_for_each_phase(
         # Log outlier error counts for non-training phases
         if phase != "train":
            self._identify__and_log_top_k_errors(errors, gathered_indices, phase)
-            self.log_dict(self.outlier_errors_over_epochs, on_epoch=True)
+            self.log_dict(
+                self.outlier_errors_over_epochs, on_epoch=True, rank_zero_only=True
+            )

    def _identify__and_log_top_k_errors(
        self,
Expand Down Expand Up @@ -1199,7 +1199,9 @@ def _identify__and_log_top_k_errors(
if key not in self.outlier_errors_over_epochs:
self.outlier_errors_over_epochs[key] = 0
self.outlier_errors_over_epochs[key] += 1
log.info(f"{phase} : Outlier error {error} at index {idx}.")
log.info(
f"{self.current_epoch}: {phase} : Outlier error {error} at index {idx}."
)

def _clear_error_tracking(self, preds, targets, incides):
"""
Expand Down Expand Up @@ -1252,19 +1254,6 @@ def on_validation_epoch_end(self):
self.val_indices,
)

def on_train_start(self):
"""Log the GPU name to Weights & Biases at the start of training."""
if isinstance(self.logger, pL.loggers.WandbLogger) and self.global_rank == 0:
if torch.cuda.is_available():
gpu_name = torch.cuda.get_device_name(0)
else:
gpu_name = "CPU"
# Log GPU name to W&B
self.logger.experiment.config.update({"GPU": gpu_name})
self.logger.experiment.log({"GPU Name": gpu_name})
else:
log.warning("Weights & Biases logger not found; GPU name not logged.")

def on_train_epoch_start(self):
"""Start the epoch timer."""
self.epoch_start_time = time.time()
@@ -1282,15 +1271,7 @@ def _log_time(self):

    def on_train_epoch_end(self):
        """Logs metrics, learning rate, histograms, and figures at the end of the training epoch."""
-        if self.global_rank == 0:
-            self._log_metrics(self.loss_metrics, "loss")
-            self._log_learning_rate()
-            self._log_time()
-            self._log_histograms()
-            # log the weights of the different loss components
-            for key, weight in self.loss.weights_scheduling.items():
-                self.log(f"loss/{key}/weight", weight[self.current_epoch])
-
+        self._log_metrics(self.loss_metrics, "loss")
        # this performs gather operations and logs only at rank == 0
        self._log_figures_for_each_phase(
            self.train_preds,
@@ -1312,15 +1293,32 @@
            self.train_indices,
        )

+        self._log_learning_rate()
+        self._log_time()
+        self._log_histograms()
+        # log the weights of the different loss components
+        if self.trainer.is_global_zero:
+            for key, weight in self.loss.weights_scheduling.items():
+                self.log(
+                    f"loss/{key}/weight",
+                    weight[self.current_epoch],
+                    rank_zero_only=True,
+                )
+
    def _log_learning_rate(self):
        """Logs the current learning rate."""
        sch = self.lr_schedulers()
-        try:
-            self.log(
-                "lr", sch.get_last_lr()[0], on_epoch=True, prog_bar=True, sync_dist=True
-            )
-        except AttributeError:
-            pass
+        if self.trainer.is_global_zero:
+            try:
+                self.log(
+                    "lr",
+                    sch.get_last_lr()[0],
+                    on_epoch=True,
+                    prog_bar=True,
+                    rank_zero_only=True,
+                )
+            except AttributeError:
+                pass

    def _log_metrics(self, metrics: ModuleDict, phase: str):
        """
