diff --git a/.gitignore b/.gitignore
index 032f61b4c..04a064844 100644
--- a/.gitignore
+++ b/.gitignore
@@ -149,3 +149,6 @@ configs/local/default.yaml
/data/
/logs/
.env
+
+# Aim logging
+.aim
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b3ba143a6..4c42c8d86 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -3,7 +3,7 @@ default_language_version:
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.3.0
+ rev: v4.4.0
hooks:
# list of supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace
@@ -19,7 +19,7 @@ repos:
# python code formatting
- repo: https://github.com/psf/black
- rev: 22.6.0
+ rev: 23.1.0
hooks:
- id: black
args: [--line-length, "99"]
@@ -33,21 +33,21 @@ repos:
# python upgrading syntax to newer version
- repo: https://github.com/asottile/pyupgrade
- rev: v2.32.1
+ rev: v3.3.1
hooks:
- id: pyupgrade
args: [--py38-plus]
# python docstring formatting
- repo: https://github.com/myint/docformatter
- rev: v1.4
+ rev: v1.5.1
hooks:
- id: docformatter
args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99]
# python check (PEP8), programming errors and code complexity
- repo: https://github.com/PyCQA/flake8
- rev: 4.0.1
+ rev: 6.0.0
hooks:
- id: flake8
args:
@@ -60,14 +60,14 @@ repos:
# python security linter
- repo: https://github.com/PyCQA/bandit
- rev: "1.7.1"
+ rev: "1.7.5"
hooks:
- id: bandit
args: ["-s", "B101"]
# yaml formatting
- repo: https://github.com/pre-commit/mirrors-prettier
- rev: v2.7.1
+ rev: v3.0.0-alpha.6
hooks:
- id: prettier
types: [yaml]
@@ -75,13 +75,13 @@ repos:
# shell scripts linter
- repo: https://github.com/shellcheck-py/shellcheck-py
- rev: v0.8.0.4
+ rev: v0.9.0.2
hooks:
- id: shellcheck
# md formatting
- repo: https://github.com/executablebooks/mdformat
- rev: 0.7.14
+ rev: 0.7.16
hooks:
- id: mdformat
args: ["--number"]
@@ -94,7 +94,7 @@ repos:
# word spelling linter
- repo: https://github.com/codespell-project/codespell
- rev: v2.1.0
+ rev: v2.2.4
hooks:
- id: codespell
args:
@@ -103,13 +103,13 @@ repos:
# jupyter notebook cell output clearing
- repo: https://github.com/kynan/nbstripout
- rev: 0.5.0
+ rev: 0.6.1
hooks:
- id: nbstripout
# jupyter notebook linting
- repo: https://github.com/nbQA-dev/nbQA
- rev: 1.4.0
+ rev: 1.6.3
hooks:
- id: nbqa-black
args: ["--line-length=99"]
diff --git a/README.md b/README.md
index 0e0e0a430..c42306f77 100644
--- a/README.md
+++ b/README.md
@@ -28,8 +28,8 @@ _Suggestions are always welcome!_
**Why you might want to use it:**
-✅ Speed
-Rapidly iterate over models, datasets, tasks and experiments on different accelerators like multi-GPUs or TPUs.
+✅ Save on boilerplate
+Easily add new models, datasets, tasks, experiments, and train on different accelerators, like multi-GPU, TPU or SLURM clusters.
✅ Education
Thoroughly commented. You can use this repo as a learning resource.
@@ -46,7 +46,10 @@ Lightning and Hydra are still evolving and integrate many libraries, which means
Template is not really adjusted for building data pipelines that depend on each other. It's more efficient to use it for model prototyping on ready-to-use data.
❌ Overfitted to simple use case
-The configuration setup is built with simple lightning training in mind. You might need to put some effort to adjust it for different use cases, e.g. lightning lite.
+The configuration setup is built with simple lightning training in mind. You might need to put some effort to adjust it for different use cases, e.g. lightning fabric.
+
+❌ Might not support your workflow
+For example, you can't resume hydra-based multirun or hyperparameter search.
> **Note**: _Keep in mind this is unofficial community project._
@@ -319,9 +322,6 @@ python train.py debug=overfit
# raise exception if there are any numerical anomalies in tensors, like NaN or +/-inf
python train.py +trainer.detect_anomaly=true
-# log second gradient norm of the model
-python train.py +trainer.track_grad_norm=2
-
# use only 20% of the data
python train.py +trainer.limit_train_batches=0.2 \
+trainer.limit_val_batches=0.2 +trainer.limit_test_batches=0.2
@@ -435,6 +435,12 @@ pre-commit run -a
> **Note**: Apply pre-commit hooks to do things like auto-formatting code and configs, performing code analysis or removing output from jupyter notebooks. See [# Best Practices](#best-practices) for more.
+Update pre-commit hook versions in `.pre-commit-config.yaml` with:
+
+```bash
+pre-commit autoupdate
+```
+
@@ -818,7 +824,7 @@ You can use different optimization frameworks integrated with Hydra, like [Optun
The `optimization_results.yaml` will be available under `logs/task_name/multirun` folder.
-This approach doesn't support advanced techniques like prunning - for more sophisticated search, you should probably write a dedicated optimization task (without multirun feature).
+This approach doesn't support resuming interrupted search and advanced techniques like prunning - for more sophisticated search and workflows, you should probably write a dedicated optimization task (without multirun feature).
@@ -889,10 +895,13 @@ def on_train_start(self):
## Best Practices
-Use Miniconda for GPU environments
+Use Miniconda
+
+It's usually unnecessary to install full anaconda environment, miniconda should be enough (weights around 80MB).
+
+Big advantage of conda is that it allows for installing packages without requiring certain compilers or libraries to be available in the system (since it installs precompiled binaries), so it often makes it easier to install some dependencies e.g. cudatoolkit for GPU support.
-It's usually unnecessary to install full anaconda environment, miniconda should be enough.
-It often makes it easier to install some dependencies, like cudatoolkit for GPU support. It also allows you to access your environments globally.
+It also allows you to access your environments globally which might be more convenient than creating new local environment for every project.
Example installation:
@@ -901,6 +910,12 @@ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh
```
+Update conda:
+
+```bash
+conda update -n base -c defaults conda
+```
+
Create new conda environment:
```bash
@@ -934,6 +949,12 @@ To reformat all files in the project use command:
pre-commit run -a
```
+To update hook versions in [.pre-commit-config.yaml](.pre-commit-config.yaml) use:
+
+```bash
+pre-commit autoupdate
+```
+
@@ -1035,7 +1056,7 @@ The style guide is available [here](https://pytorch-lightning.readthedocs.io/en/
def training_step_end():
...
- def training_epoch_end():
+ def on_train_epoch_end():
...
def validation_step():
@@ -1044,7 +1065,7 @@ The style guide is available [here](https://pytorch-lightning.readthedocs.io/en/
def validation_step_end():
...
- def validation_epoch_end():
+ def on_validation_epoch_end():
...
def test_step():
@@ -1053,7 +1074,7 @@ The style guide is available [here](https://pytorch-lightning.readthedocs.io/en/
def test_step_end():
...
- def test_epoch_end():
+ def on_test_epoch_end():
...
def configure_optimizers():
@@ -1245,7 +1266,7 @@ git clone https://github.com/YourGithubName/your-repo-name
cd your-repo-name
# create conda environment and install dependencies
-conda env create -f environment.yaml
+conda env create -f environment.yaml -n myenv
# activate conda environment
conda activate myenv
diff --git a/configs/callbacks/early_stopping.yaml b/configs/callbacks/early_stopping.yaml
index 20ed26710..59958b1e3 100644
--- a/configs/callbacks/early_stopping.yaml
+++ b/configs/callbacks/early_stopping.yaml
@@ -1,9 +1,9 @@
-# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.EarlyStopping.html
+# https://pytorch-lightning.readthedocs.io/en/latest/api/lightning.callbacks.EarlyStopping.html
# Monitor a metric and stop training when it stops improving.
# Look at the above link for more detailed information.
early_stopping:
- _target_: pytorch_lightning.callbacks.EarlyStopping
+ _target_: lightning.pytorch.callbacks.EarlyStopping
monitor: ??? # quantity to be monitored, must be specified !!!
min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
patience: 3 # number of checks with no improvement after which training will be stopped
diff --git a/configs/callbacks/model_checkpoint.yaml b/configs/callbacks/model_checkpoint.yaml
index eae62933d..9aaf7b780 100644
--- a/configs/callbacks/model_checkpoint.yaml
+++ b/configs/callbacks/model_checkpoint.yaml
@@ -1,9 +1,9 @@
-# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ModelCheckpoint.html
+# https://pytorch-lightning.readthedocs.io/en/latest/api/lightning.callbacks.ModelCheckpoint.html
# Save the model periodically by monitoring a quantity.
# Look at the above link for more detailed information.
model_checkpoint:
- _target_: pytorch_lightning.callbacks.ModelCheckpoint
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
dirpath: null # directory to save the model file
filename: null # checkpoint filename
monitor: null # name of the logged metric which determines when model is improving
diff --git a/configs/callbacks/model_summary.yaml b/configs/callbacks/model_summary.yaml
index 04da98d3a..b1fa2ada8 100644
--- a/configs/callbacks/model_summary.yaml
+++ b/configs/callbacks/model_summary.yaml
@@ -1,7 +1,7 @@
-# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichModelSummary.html
+# https://pytorch-lightning.readthedocs.io/en/latest/api/lightning.callbacks.RichModelSummary.html
# Generates a summary of all layers in a LightningModule with rich text formatting.
# Look at the above link for more detailed information.
model_summary:
- _target_: pytorch_lightning.callbacks.RichModelSummary
+ _target_: lightning.pytorch.callbacks.RichModelSummary
max_depth: 1 # the maximum depth of layer nesting that the summary will include
diff --git a/configs/callbacks/rich_progress_bar.yaml b/configs/callbacks/rich_progress_bar.yaml
index b6be5b459..bd58cde10 100644
--- a/configs/callbacks/rich_progress_bar.yaml
+++ b/configs/callbacks/rich_progress_bar.yaml
@@ -1,6 +1,6 @@
-# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichProgressBar.html
+# https://pytorch-lightning.readthedocs.io/en/latest/api/lightning.callbacks.RichProgressBar.html
# Create a progress bar with rich text formatting.
# Look at the above link for more detailed information.
rich_progress_bar:
- _target_: pytorch_lightning.callbacks.RichProgressBar
+ _target_: lightning.pytorch.callbacks.RichProgressBar
diff --git a/configs/experiment/example.yaml b/configs/experiment/example.yaml
index 3aad80f27..b68c913b9 100644
--- a/configs/experiment/example.yaml
+++ b/configs/experiment/example.yaml
@@ -36,3 +36,5 @@ logger:
wandb:
tags: ${tags}
group: "mnist"
+ aim:
+ experiment: "mnist"
diff --git a/configs/logger/aim.yaml b/configs/logger/aim.yaml
new file mode 100644
index 000000000..8f9f6adad
--- /dev/null
+++ b/configs/logger/aim.yaml
@@ -0,0 +1,28 @@
+# https://aimstack.io/
+
+# example usage in lightning module:
+# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
+
+# open the Aim UI with the following command (run in the folder containing the `.aim` folder):
+# `aim up`
+
+aim:
+ _target_: aim.pytorch_lightning.AimLogger
+ repo: ${paths.root_dir} # .aim folder will be created here
+ # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
+
+ # aim allows to group runs under experiment name
+ experiment: null # any string, set to "default" if not specified
+
+ train_metric_prefix: "train/"
+ val_metric_prefix: "val/"
+ test_metric_prefix: "test/"
+
+ # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
+ system_tracking_interval: 10 # set to null to disable system metrics tracking
+
+ # enable/disable logging of system params such as installed packages, git info, env vars, etc.
+ log_system_params: true
+
+ # enable/disable tracking console logs (default value is true)
+ capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550
diff --git a/configs/logger/comet.yaml b/configs/logger/comet.yaml
index 423f41f66..e0789274e 100644
--- a/configs/logger/comet.yaml
+++ b/configs/logger/comet.yaml
@@ -1,7 +1,7 @@
# https://www.comet.ml
comet:
- _target_: pytorch_lightning.loggers.comet.CometLogger
+ _target_: lightning.pytorch.loggers.comet.CometLogger
api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
save_dir: "${paths.output_dir}"
project_name: "lightning-hydra-template"
diff --git a/configs/logger/csv.yaml b/configs/logger/csv.yaml
index 844ec6718..fa028e9c1 100644
--- a/configs/logger/csv.yaml
+++ b/configs/logger/csv.yaml
@@ -1,7 +1,7 @@
# csv logger built in lightning
csv:
- _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
+ _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
save_dir: "${paths.output_dir}"
name: "csv/"
prefix: ""
diff --git a/configs/logger/mlflow.yaml b/configs/logger/mlflow.yaml
index 3b441a901..f8fb7e685 100644
--- a/configs/logger/mlflow.yaml
+++ b/configs/logger/mlflow.yaml
@@ -1,7 +1,7 @@
# https://mlflow.org
mlflow:
- _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger
+ _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
# experiment_name: ""
# run_name: ""
tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
diff --git a/configs/logger/neptune.yaml b/configs/logger/neptune.yaml
index 5df1e3427..8233c1400 100644
--- a/configs/logger/neptune.yaml
+++ b/configs/logger/neptune.yaml
@@ -1,7 +1,7 @@
# https://neptune.ai
neptune:
- _target_: pytorch_lightning.loggers.neptune.NeptuneLogger
+ _target_: lightning.pytorch.loggers.neptune.NeptuneLogger
api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
project: username/lightning-hydra-template
# name: ""
diff --git a/configs/logger/tensorboard.yaml b/configs/logger/tensorboard.yaml
index 29067c90f..2bd31f6d8 100644
--- a/configs/logger/tensorboard.yaml
+++ b/configs/logger/tensorboard.yaml
@@ -1,7 +1,7 @@
# https://www.tensorflow.org/tensorboard/
tensorboard:
- _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
save_dir: "${paths.output_dir}/tensorboard/"
name: null
log_graph: False
diff --git a/configs/logger/wandb.yaml b/configs/logger/wandb.yaml
index d6d20e0ba..ece165889 100644
--- a/configs/logger/wandb.yaml
+++ b/configs/logger/wandb.yaml
@@ -1,7 +1,7 @@
# https://wandb.ai
wandb:
- _target_: pytorch_lightning.loggers.wandb.WandbLogger
+ _target_: lightning.pytorch.loggers.wandb.WandbLogger
# name: "" # name of the run (normally generated by wandb)
save_dir: "${paths.output_dir}"
offline: False
diff --git a/configs/train.yaml b/configs/train.yaml
index 166f29a70..59b264a0a 100644
--- a/configs/train.yaml
+++ b/configs/train.yaml
@@ -33,8 +33,6 @@ task_name: "train"
# tags to help you identify your experiments
# you can overwrite this in experiment configs
# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
-# appending lists from command line is currently not supported :(
-# https://github.com/facebookresearch/hydra/issues/1547
tags: ["dev"]
# set False to skip model training
@@ -44,6 +42,9 @@ train: True
# lightning chooses best weights based on the metric specified in checkpoint callback
test: True
+# compile model for faster training with pytorch 2.0
+compile: False
+
# simply provide checkpoint path to resume training
ckpt_path: null
diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml
index 1a336e8b5..50905e7fd 100644
--- a/configs/trainer/default.yaml
+++ b/configs/trainer/default.yaml
@@ -1,4 +1,4 @@
-_target_: pytorch_lightning.Trainer
+_target_: lightning.pytorch.trainer.Trainer
default_root_dir: ${paths.output_dir}
diff --git a/environment.yaml b/environment.yaml
index cd964a32c..f02b80161 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -21,9 +21,9 @@ channels:
# compatibility is usually guaranteed
dependencies:
- - pytorch>=1.10
- - torchvision>=0.11
- - pytorch-lightning=1.*
+ - pytorch=2.*
+ - torchvision=0.*
+ - pytorch-lightning=2.*
- torchmetrics=0.*
- hydra-core=1.*
- rich=13.*
diff --git a/requirements.txt b/requirements.txt
index bad1c748e..668a0ed06 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
# --------- pytorch --------- #
-torch>=1.10.0
-torchvision>=0.11.0
-pytorch-lightning==1.9.1
-torchmetrics==0.11.0
+torch>=2.0.0
+torchvision>=0.15.0
+lightning>=2.0.0
+torchmetrics>=0.11.4
# --------- hydra --------- #
hydra-core==1.3.2
@@ -14,6 +14,7 @@ hydra-optuna-sweeper==1.2.0
# neptune-client
# mlflow
# comet-ml
+# aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550
# --------- others --------- #
pyrootutils # standardizing the project root setup
diff --git a/src/data/mnist_datamodule.py b/src/data/mnist_datamodule.py
index b20ca4d5e..6dd176ec6 100644
--- a/src/data/mnist_datamodule.py
+++ b/src/data/mnist_datamodule.py
@@ -1,7 +1,7 @@
from typing import Any, Dict, Optional, Tuple
import torch
-from pytorch_lightning import LightningDataModule
+from lightning import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
@@ -10,8 +10,7 @@
class MNISTDataModule(LightningDataModule):
"""Example of LightningDataModule for MNIST dataset.
- A DataModule implements 5 key methods:
-
+ A DataModule implements 6 key methods:
def prepare_data(self):
# things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
# download data, pre-process, split, save to disk, etc...
@@ -32,7 +31,7 @@ def teardown(self):
split, transform and process the data.
Read the docs:
- https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html
+ https://lightning.ai/docs/pytorch/latest/data/datamodule.html
"""
def __init__(
diff --git a/src/eval.py b/src/eval.py
index 8e59f3d5c..763dbb65c 100644
--- a/src/eval.py
+++ b/src/eval.py
@@ -2,9 +2,9 @@
import hydra
import pyrootutils
+from lightning import LightningDataModule, LightningModule, Trainer
+from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
-from pytorch_lightning import LightningDataModule, LightningModule, Trainer
-from pytorch_lightning.loggers import Logger
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
@@ -33,8 +33,8 @@
def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
"""Evaluates given checkpoint on a datamodule testset.
- This method is wrapped in optional @task_wrapper decorator which applies extra utilities
- before and after the call.
+ This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+ failure. Useful for multiruns, saving info about the crash, etc.
Args:
cfg (DictConfig): Configuration composed by Hydra.
@@ -82,6 +82,10 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
def main(cfg: DictConfig) -> None:
+ # apply extra utilities
+ # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
+ utils.extras(cfg)
+
evaluate(cfg)
diff --git a/src/models/mnist_module.py b/src/models/mnist_module.py
index 8827f05ce..e1bb76d93 100644
--- a/src/models/mnist_module.py
+++ b/src/models/mnist_module.py
@@ -1,7 +1,7 @@
-from typing import Any, List
+from typing import Any
import torch
-from pytorch_lightning import LightningModule
+from lightning import LightningModule
from torchmetrics import MaxMetric, MeanMetric
from torchmetrics.classification.accuracy import Accuracy
@@ -10,15 +10,15 @@ class MNISTLitModule(LightningModule):
"""Example of LightningModule for MNIST classification.
A LightningModule organizes your PyTorch code into 6 sections:
- - Computations (init)
- - Train loop (training_step)
+ - Initialization (__init__)
+ - Train Loop (training_step)
- Validation loop (validation_step)
- Test loop (test_step)
- Prediction Loop (predict_step)
- Optimizers and LR Schedulers (configure_optimizers)
Docs:
- https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
+ https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
"""
def __init__(
@@ -77,21 +77,10 @@ def training_step(self, batch: Any, batch_idx: int):
self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
- # we can return here dict with any tensors
- # and then read it in some callback or in `training_epoch_end()` below
- # remember to always return loss from `training_step()` or backpropagation will fail!
- return {"loss": loss, "preds": preds, "targets": targets}
-
- def training_epoch_end(self, outputs: List[Any]):
- # `outputs` is a list of dicts returned from `training_step()`
-
- # Warning: when overriding `training_epoch_end()`, lightning accumulates outputs from all batches of the epoch
- # this may not be an issue when training on mnist
- # but on larger datasets/models it's easy to run into out-of-memory errors
-
- # consider detaching tensors before returning them from `training_step()`
- # or using `on_train_epoch_end()` instead which doesn't accumulate outputs
+ # return loss or backpropagation will fail
+ return loss
+ def on_train_epoch_end(self):
pass
def validation_step(self, batch: Any, batch_idx: int):
@@ -103,9 +92,7 @@ def validation_step(self, batch: Any, batch_idx: int):
self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
- return {"loss": loss, "preds": preds, "targets": targets}
-
- def validation_epoch_end(self, outputs: List[Any]):
+ def on_validation_epoch_end(self):
acc = self.val_acc.compute() # get current val acc
self.val_acc_best(acc) # update best so far val acc
# log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
@@ -121,9 +108,7 @@ def test_step(self, batch: Any, batch_idx: int):
self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)
- return {"loss": loss, "preds": preds, "targets": targets}
-
- def test_epoch_end(self, outputs: List[Any]):
+ def on_test_epoch_end(self):
pass
def configure_optimizers(self):
@@ -131,7 +116,7 @@ def configure_optimizers(self):
Normally you'd need one. But in the case of GANs or similar you might have multiple.
Examples:
- https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
+ https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
"""
optimizer = self.hparams.optimizer(params=self.parameters())
if self.hparams.scheduler is not None:
diff --git a/src/train.py b/src/train.py
index 418a0b469..dad481dd2 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,11 +1,12 @@
from typing import List, Optional, Tuple
import hydra
+import lightning as L
import pyrootutils
-import pytorch_lightning as pl
+import torch
+from lightning import Callback, LightningDataModule, LightningModule, Trainer
+from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
-from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
-from pytorch_lightning.loggers import Logger
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
@@ -35,8 +36,8 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
"""Trains the model. Can additionally evaluate on a testset, using best weights obtained during
training.
- This method is wrapped in optional @task_wrapper decorator which applies extra utilities
- before and after the call.
+ This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+ failure. Useful for multiruns, saving info about the crash, etc.
Args:
cfg (DictConfig): Configuration composed by Hydra.
@@ -47,7 +48,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
# set seed for random number generators in pytorch, numpy and python.random
if cfg.get("seed"):
- pl.seed_everything(cfg.seed, workers=True)
+ L.seed_everything(cfg.seed, workers=True)
log.info(f"Instantiating datamodule <{cfg.data._target_}>")
datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
@@ -77,6 +78,10 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
log.info("Logging hyperparameters!")
utils.log_hyperparameters(object_dict)
+ if cfg.get("compile"):
+ log.info("Compiling model!")
+ model = torch.compile(model)
+
if cfg.get("train"):
log.info("Starting training!")
trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
@@ -102,6 +107,9 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
+ # apply extra utilities
+ # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
+ utils.extras(cfg)
# train the model
metric_dict, _ = train(cfg)
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
index e816613bd..62ef73c24 100644
--- a/src/utils/__init__.py
+++ b/src/utils/__init__.py
@@ -1,12 +1,5 @@
+from src.utils.instantiatiators import instantiate_callbacks, instantiate_loggers
+from src.utils.logging_utils import log_hyperparameters
from src.utils.pylogger import get_pylogger
from src.utils.rich_utils import enforce_tags, print_config_tree
-from src.utils.utils import (
- close_loggers,
- extras,
- get_metric_value,
- instantiate_callbacks,
- instantiate_loggers,
- log_hyperparameters,
- save_file,
- task_wrapper,
-)
+from src.utils.utils import extras, get_metric_value, task_wrapper
diff --git a/src/utils/instantiatiators.py b/src/utils/instantiatiators.py
new file mode 100644
index 000000000..9fe57510a
--- /dev/null
+++ b/src/utils/instantiatiators.py
@@ -0,0 +1,50 @@
+from typing import List
+
+import hydra
+from omegaconf import DictConfig
+from pytorch_lightning import Callback
+from pytorch_lightning.loggers import Logger
+
+from src.utils import pylogger
+
+log = pylogger.get_pylogger(__name__)
+
+
+def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+ """Instantiates callbacks from config."""
+
+ callbacks: List[Callback] = []
+
+ if not callbacks_cfg:
+ log.warning("No callback configs found! Skipping..")
+ return callbacks
+
+ if not isinstance(callbacks_cfg, DictConfig):
+ raise TypeError("Callbacks config must be a DictConfig!")
+
+ for _, cb_conf in callbacks_cfg.items():
+ if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+ log.info(f"Instantiating callback <{cb_conf._target_}>")
+ callbacks.append(hydra.utils.instantiate(cb_conf))
+
+ return callbacks
+
+
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+ """Instantiates loggers from config."""
+
+ logger: List[Logger] = []
+
+ if not logger_cfg:
+ log.warning("No logger configs found! Skipping...")
+ return logger
+
+ if not isinstance(logger_cfg, DictConfig):
+ raise TypeError("Logger config must be a DictConfig!")
+
+ for _, lg_conf in logger_cfg.items():
+ if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+ log.info(f"Instantiating logger <{lg_conf._target_}>")
+ logger.append(hydra.utils.instantiate(lg_conf))
+
+ return logger
diff --git a/src/utils/logging_utils.py b/src/utils/logging_utils.py
new file mode 100644
index 000000000..276323495
--- /dev/null
+++ b/src/utils/logging_utils.py
@@ -0,0 +1,50 @@
+from lightning.pytorch.utilities import rank_zero_only
+
+from src.utils import pylogger
+
+log = pylogger.get_pylogger(__name__)
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: dict) -> None:
+ """Controls which config parts are saved by lightning loggers.
+
+ Additionally saves:
+ - Number of model parameters
+ """
+
+ hparams = {}
+
+ cfg = object_dict["cfg"]
+ model = object_dict["model"]
+ trainer = object_dict["trainer"]
+
+ if not trainer.logger:
+ log.warning("Logger not found! Skipping hyperparameter logging...")
+ return
+
+ hparams["model"] = cfg["model"]
+
+ # save number of model parameters
+ hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+ hparams["model/params/trainable"] = sum(
+ p.numel() for p in model.parameters() if p.requires_grad
+ )
+ hparams["model/params/non_trainable"] = sum(
+ p.numel() for p in model.parameters() if not p.requires_grad
+ )
+
+ hparams["data"] = cfg["data"]
+ hparams["trainer"] = cfg["trainer"]
+
+ hparams["callbacks"] = cfg.get("callbacks")
+ hparams["extras"] = cfg.get("extras")
+
+ hparams["task_name"] = cfg.get("task_name")
+ hparams["tags"] = cfg.get("tags")
+ hparams["ckpt_path"] = cfg.get("ckpt_path")
+ hparams["seed"] = cfg.get("seed")
+
+ # send hparams to all loggers
+ for logger in trainer.loggers:
+ logger.log_hyperparams(hparams)
diff --git a/src/utils/pylogger.py b/src/utils/pylogger.py
index 92ffa7189..62176cad9 100644
--- a/src/utils/pylogger.py
+++ b/src/utils/pylogger.py
@@ -1,6 +1,6 @@
import logging
-from pytorch_lightning.utilities import rank_zero_only
+from lightning.pytorch.utilities import rank_zero_only
def get_pylogger(name=__name__) -> logging.Logger:
diff --git a/src/utils/rich_utils.py b/src/utils/rich_utils.py
index 916056db4..6df129aae 100644
--- a/src/utils/rich_utils.py
+++ b/src/utils/rich_utils.py
@@ -5,8 +5,8 @@
import rich.syntax
import rich.tree
from hydra.core.hydra_config import HydraConfig
+from lightning.pytorch.utilities import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
-from pytorch_lightning.utilities import rank_zero_only
from rich.prompt import Prompt
from src.utils import pylogger
diff --git a/src/utils/utils.py b/src/utils/utils.py
index 06fa33910..b0a81c3f8 100644
--- a/src/utils/utils.py
+++ b/src/utils/utils.py
@@ -1,66 +1,14 @@
-import time
import warnings
from importlib.util import find_spec
-from pathlib import Path
-from typing import Any, Callable, Dict, List
+from typing import Callable
-import hydra
from omegaconf import DictConfig
-from pytorch_lightning import Callback
-from pytorch_lightning.loggers import Logger
-from pytorch_lightning.utilities import rank_zero_only
from src.utils import pylogger, rich_utils
log = pylogger.get_pylogger(__name__)
-def task_wrapper(task_func: Callable) -> Callable:
- """Optional decorator that wraps the task function in extra utilities.
-
- Makes multirun more resistant to failure.
-
- Utilities:
- - Calling the `utils.extras()` before the task is started
- - Calling the `utils.close_loggers()` after the task is finished or failed
- - Logging the exception if occurs
- - Logging the output dir
- """
-
- def wrap(cfg: DictConfig):
-
- # execute the task
- try:
-
- # apply extra utilities
- extras(cfg)
-
- metric_dict, object_dict = task_func(cfg=cfg)
-
- # things to do if exception occurs
- except Exception as ex:
-
- # save exception to `.log` file
- log.exception("")
-
- # when using hydra plugins like Optuna, you might want to disable raising exception
- # to avoid multirun failure
- raise ex
-
- # things to always do after either success or exception
- finally:
-
- # display output dir path in terminal
- log.info(f"Output dir: {cfg.paths.output_dir}")
-
- # close loggers (even if exception occurs so multirun won't fail)
- close_loggers()
-
- return metric_dict, object_dict
-
- return wrap
-
-
def extras(cfg: DictConfig) -> None:
"""Applies optional utilities before the task is started.
@@ -91,87 +39,57 @@ def extras(cfg: DictConfig) -> None:
rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
-def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
- """Instantiates callbacks from config."""
- callbacks: List[Callback] = []
-
- if not callbacks_cfg:
- log.warning("No callback configs found! Skipping..")
- return callbacks
-
- if not isinstance(callbacks_cfg, DictConfig):
- raise TypeError("Callbacks config must be a DictConfig!")
-
- for _, cb_conf in callbacks_cfg.items():
- if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
- log.info(f"Instantiating callback <{cb_conf._target_}>")
- callbacks.append(hydra.utils.instantiate(cb_conf))
-
- return callbacks
-
-
-def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
- """Instantiates loggers from config."""
- logger: List[Logger] = []
-
- if not logger_cfg:
- log.warning("No logger configs found! Skipping...")
- return logger
-
- if not isinstance(logger_cfg, DictConfig):
- raise TypeError("Logger config must be a DictConfig!")
-
- for _, lg_conf in logger_cfg.items():
- if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
- log.info(f"Instantiating logger <{lg_conf._target_}>")
- logger.append(hydra.utils.instantiate(lg_conf))
+def task_wrapper(task_func: Callable) -> Callable:
+ """Optional decorator that controls the failure behavior when executing the task function.
- return logger
+ This wrapper can be used to:
+ - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
+ - save the exception to a `.log` file
+ - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
+ - etc. (adjust depending on your needs)
+ Example:
+ ```
+ @utils.task_wrapper
+ def train(cfg: DictConfig) -> Tuple[dict, dict]:
-@rank_zero_only
-def log_hyperparameters(object_dict: dict) -> None:
- """Controls which config parts are saved by lightning loggers.
+ ...
- Additionally saves:
- - Number of model parameters
+ return metric_dict, object_dict
+ ```
"""
- hparams = {}
-
- cfg = object_dict["cfg"]
- model = object_dict["model"]
- trainer = object_dict["trainer"]
+ def wrap(cfg: DictConfig):
+ # execute the task
+ try:
+ metric_dict, object_dict = task_func(cfg=cfg)
- if not trainer.logger:
- log.warning("Logger not found! Skipping hyperparameter logging...")
- return
+ # things to do if exception occurs
+ except Exception as ex:
+ # save exception to `.log` file
+ log.exception("")
- hparams["model"] = cfg["model"]
+ # some hyperparameter combinations might be invalid or cause out-of-memory errors
+ # so when using hparam search plugins like Optuna, you might want to disable
+ # raising the below exception to avoid multirun failure
+ raise ex
- # save number of model parameters
- hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
- hparams["model/params/trainable"] = sum(
- p.numel() for p in model.parameters() if p.requires_grad
- )
- hparams["model/params/non_trainable"] = sum(
- p.numel() for p in model.parameters() if not p.requires_grad
- )
+ # things to always do after either success or exception
+ finally:
+ # display output dir path in terminal
+ log.info(f"Output dir: {cfg.paths.output_dir}")
- hparams["data"] = cfg["data"]
- hparams["trainer"] = cfg["trainer"]
+ # always close wandb run (even if exception occurs so multirun won't fail)
+ if find_spec("wandb"): # check if wandb is installed
+ import wandb
- hparams["callbacks"] = cfg.get("callbacks")
- hparams["extras"] = cfg.get("extras")
+ if wandb.run:
+ log.info("Closing wandb!")
+ wandb.finish()
- hparams["task_name"] = cfg.get("task_name")
- hparams["tags"] = cfg.get("tags")
- hparams["ckpt_path"] = cfg.get("ckpt_path")
- hparams["seed"] = cfg.get("seed")
+ return metric_dict, object_dict
- # send hparams to all loggers
- for logger in trainer.loggers:
- logger.log_hyperparams(hparams)
+ return wrap
def get_metric_value(metric_dict: dict, metric_name: str) -> float:
@@ -192,23 +110,3 @@ def get_metric_value(metric_dict: dict, metric_name: str) -> float:
log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
return metric_value
-
-
-def close_loggers() -> None:
- """Makes sure all loggers closed properly (prevents logging failure during multirun)."""
-
- log.info("Closing loggers...")
-
- if find_spec("wandb"): # if wandb is installed
- import wandb
-
- if wandb.run:
- log.info("Closing wandb!")
- wandb.finish()
-
-
-@rank_zero_only
-def save_file(path: str, content: str) -> None:
- """Save file in rank zero mode (only on one process in multi-GPU setup)."""
- with open(path, "w+") as file:
- file.write(content)
diff --git a/tests/helpers/package_available.py b/tests/helpers/package_available.py
index 5b0963fc6..614778fef 100644
--- a/tests/helpers/package_available.py
+++ b/tests/helpers/package_available.py
@@ -1,7 +1,7 @@
import platform
import pkg_resources
-from pytorch_lightning.accelerators import TPUAccelerator
+from lightning.fabric.accelerators import TPUAccelerator
def _package_available(package_name: str) -> bool: