diff --git a/src/nn_core/model_logging.py b/src/nn_core/model_logging.py index ba843f0..b5d9a4f 100644 --- a/src/nn_core/model_logging.py +++ b/src/nn_core/model_logging.py @@ -2,14 +2,15 @@ import logging import os from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, TypeVar, Union import hydra import pytorch_lightning from omegaconf import DictConfig, OmegaConf from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.loggers import LightningLoggerBase, WandbLogger +from pytorch_lightning.utilities import rank_zero_only from nn_core.common import PROJECT_ROOT @@ -19,6 +20,25 @@ _STATS_KEY: str = "stats" +T = TypeVar("T") + + +class MetricTracker: + def __init__(self, choice_fn: Callable[[T, T], T]): + self.choice_fn = choice_fn + self.best_values: Dict[str, T] = {} + + def __call__(self, name: str, value: Optional[T]) -> Optional[T]: + old_value = self.best_values.get(name, None) + + if value is None: + return old_value + + self.best_values[name] = self.choice_fn(old_value, value) if name in self.best_values else value + + return self.best_values[name] + + class NNLogger(LightningLoggerBase): __doc__ = LightningLoggerBase.__doc__ @@ -38,7 +58,12 @@ def __init__(self, logging_cfg: DictConfig, cfg: DictConfig, resume_id: Optional self.logging_cfg.logger.mode = "offline" pylogger.info(f"Instantiating <{self.logging_cfg.logger['_target_'].split('.')[-1]}>") - self.wrapped: LightningLoggerBase = hydra.utils.instantiate(self.logging_cfg.logger, version=self.resume_id) + self.wrapped: WandbLogger = hydra.utils.instantiate(self.logging_cfg.logger, version=self.resume_id) + + self.metric_trackers = { + "max": MetricTracker(choice_fn=max), + "min": MetricTracker(choice_fn=min), + } # force experiment lazy initialization _ = self.wrapped.experiment @@ -53,6 +78,7 @@ def watch_model(self, pl_module: LightningModule): pylogger.info("Starting to 'watch' the module") self.wrapped.watch(pl_module, **self.logging_cfg["wandb_watch"]) + @rank_zero_only def upload_source(self) -> None: if self.logging_cfg.upload.source and self.wandb: pylogger.info("Uploading source code to wandb") @@ -83,6 +109,7 @@ def on_save_checkpoint(self, trainer: Trainer, pl_module: LightningModule, check "run_path" ] = f"{trainer.logger.experiment.entity}/{trainer.logger.experiment.project_name()}/{trainer.logger.version}" + @rank_zero_only def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: # Log the checkpoint meta information self.add_path(obj_id="checkpoints/best", obj_path=checkpoint_callback.best_model_path) @@ -105,6 +132,7 @@ def experiment(self) -> Any: """Return the experiment object associated with this logger.""" return self.wrapped.experiment + @rank_zero_only def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): """Records metrics. @@ -116,8 +144,15 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): metrics: Dictionary with metric names as keys and measured quantities as values step: Step number at which the metrics should be recorded """ + tracked_metrics = {} + for key, value in metrics.items(): + for tracker_name, tracker in self.metric_trackers.items(): + tracked_metrics[f"{tracker_name}/{key}"] = tracker(name=f"{tracker_name}/{key}", value=value) + + self.wrapped.log_metrics(metrics=tracked_metrics, step=step) return self.wrapped.log_metrics(metrics=metrics, step=step) + @rank_zero_only def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs): """Record hyperparameters. @@ -131,6 +166,7 @@ def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs): "The whole configuration is already logged by logger.log_configuration, set logger=False" ) + @rank_zero_only def log_text(self, *args, **kwargs) -> None: """Log text. @@ -138,6 +174,7 @@ def log_text(self, *args, **kwargs) -> None: """ return self.wrapped.log_text(*args, **kwargs) + @rank_zero_only def log_image(self, *args, **kwargs) -> None: """Log image. @@ -160,6 +197,7 @@ def run_dir(self) -> str: # TODO: verify remote URLs handling return os.path.join(*map(str, (self.storage_dir, self.name, self.version))) + @rank_zero_only def log_configuration( self, model: pytorch_lightning.LightningModule,