Skip to content
/ LitMNIST Public template

Deep learning project template based on PyTorch Lightning and Hydra.

License

Notifications You must be signed in to change notification settings

XavierJiezou/LitMNIST

Folders and files

NameName
Last commit message
Last commit date

Latest commit

9241f35 · May 26, 2022
No commit message
May 20, 2022
May 26, 2022
No commit message
May 26, 2022
May 26, 2022
No commit message
May 23, 2022
No commit message
May 20, 2022
No commit message
May 20, 2022
May 26, 2022
No commit message
May 21, 2022
No commit message
May 20, 2022
No commit message
May 21, 2022
No commit message
May 23, 2022
No commit message
May 6, 2022
May 6, 2022
No commit message
May 26, 2022
No commit message
May 25, 2022
No commit message
May 20, 2022
No commit message
May 22, 2022
No commit message
May 26, 2022
No commit message
May 21, 2022
No commit message
May 21, 2022

Repository files navigation

logo

LitMNIST

基于 PyTorch Lightning + Hydra 的深度学习项目模板。

(以 MNIST 分类任务为例)

点击 Use this template 即可使用该模板来初始化你的新仓库。

GitHub Workflow Release Status GitHub Workflow Test Status GitHub Workflow Lint Status Codacy Badge codecov PyPI PyPI - Downloads GitHub stars GitHub forks GitHub issues GitHub license

Python PyTorch Lightning Config: hydra Code style: black

观看演示报告错误功能需求

English 简体中文

喜欢这个项目吗?请考虑捐赠(微信 | 支付宝),以帮助它改善!

目录

LitMNIST

演示

demo

安装

开始之前,你必须熟练使用 PyTorch Lightning,并对 Hydra 有一定的了解。

  1. 克隆仓库到本地
git clone https://github.com/XavierJiezou/LitMNIST.git
cd LitMNIST
  1. 创建并激活 conda 虚拟环境
conda create -n myenv python=3.8
conda activate myenv
  1. 安装项目依赖包(如需安装 GPU 版 PyTorch,请参考官网安装教程
pip install -r requirements.txt

运行

CPU

python train.py

GPU

python train.py \
trainer.gpus=4 \
+trainer.strategy=ddp_find_unused_parameters_false \
datamodule.num_workers=16 \
datamodule.pin_memory=True \
datamodule.persistent_workers=True

结构

项目的主要目录结构如下:

├── configs # 存放 Hydra 配置文件
│   ├── callbacks # Callbacks 配置(例如 EarlyStopping、ModelCheckpoint 等)
│   ├── datamodule # Datamodule 配置(例如 batch_size、num_workers 等)
│   ├── debug # 调试配置
│   ├── experiment # 实验配置
│   ├── hparams_search # 超参数搜索配置
│   ├── local # 本地配置(暂时可以忽略)
│   ├── log_dir # 日志存放目录配置
│   ├── logger # 日志配置
│   ├── model # 模型配置
│   ├── trainer # Trainer 配置
│   │
│   ├── test.yaml # 测试的主要配置
│   └── train.yaml # 训练的主要配置
│
├── data # 存放项目数据
│
├── logs # 存放项目日志(Hydra 日志 和 PyTorch Lightning loggers 生成的日志)
│
├── src # 项目源代码
│   ├── datamodules # LightningDataModule
│   ├── models # 存放基于原生 PyTorch 框架编写的模型
│   ├── litmodules # LightningModule
│   ├── utils # 存放一些实用的脚本(例如数据预处理的脚本)
│   │
│   ├── testing_pipeline.py # 测试流水线(实例化对象)
│   └── training_pipeline.py # 训练流水线(实例化对象)
│
├── tests # 单元测试(可选)
│
├── test.py # 开始测试(加载配置文件)
├── train.py # 开始训练(加载配置文件)
│
├── .env # 存储私有环境变量(例如 wandb 的 API_KEY)【注意:该文件不受版本控制】
├── .gitignore # 设置版本控制需要排除的文件或目录(例如 .env 文件)
├── requirements.txt # 项目依赖环境(pip install -r requirements.txt)
└── README.md # 项目概述文档

用法

本仓库是一个基于 PyTorch Lightning + Hydra 的深度学习项目模板。因此你仅需要套用该模板,并作出如下修改:

  1. 编写你自己的 PyTorch nn.Module 模型(参见 src/models/simple_densenet.py
  2. 编写你自己的 PyTorch Lightning LightningModule(参见 src/litmodules/mnist_litmodule.py
  3. 编写你自己的 PyTorch Lightning LightningDataModule(参见 src/datamodules/mnist_datamodule.py
  4. 编写你自己的实验配置文件(参见 configs/experiment/example.yaml
  5. 使用选定的实验配置运行训练代码:python train.py experiment=<EXPERIMENT_NAME>

基础

train.py 集成了模型训练验证测试的一整套工作流,安装好环境后,运行即可:

python train.py

test.py 仅包含测试步骤,允许你单独加载预训练模型进行测试(但要指定模型路径):

python test.py ckpt_path=checkpoints/last.ckpt

进阶

  • 从命令行覆盖任何配置参数

Hydra 允许你轻松覆盖配置文件中定义的任何参数。

train.py 默认从 configs/train.yaml 中获取参数。因此,你可以先修改 yaml 配置文件中的参数,然后再运行。

或者,你也可以在命令行中直接指定参数。命令行中参数的优先级要大于 yaml 配置文件中参数的优先级。

python train.py trainer.max_epochs=3

对于某些不太重要的参数,它们没有在 yaml 配置文件中定义,因此你在命令行中指定的时候必须添加 +

python train.py +trainer.precision=16
  • 在 CPU、GPU、多 GPU 和 TPU 上训练

PyTorch Lightning 使得在不同硬件上训练模型变得容易。

在 CPU 上训练

python train.py trainer.gpus=0

在 GPU 上训练

python train.py trainer.gpus=1

在 TPU 上训练

python train.py +trainer.tpu_cores=8

基于 DDP(Distributed Data Parallel,分布式数据并行)的训练【4 个 GPU】

python train.py trainer.gpus=4 +trainer.strategy=ddp

基于 DDP(Distributed Data Parallel,分布式数据并行)的训练【8 个 GPU,两个节点】

python train.py trainer.gpus=4 +trainer.num_nodes=2 +trainer.strategy=ddp
  • 混合精度训练

PyTorch Lightning 允许你使用半精度或混合精度以减少训练期间的内存占用。(在 GPU 上能够实现 3 倍的加速效果,但可能损失精度)

python train.py trainer.gpus=1 +trainer.precision=16
  • 使用 PyTorch Lightning 中的日志记录器来记录训练日志

PyTorch Lightning 集成了多种主流日志记录框架,包括 TensorBoard 和 Weights&Biases等。

这里以 wandb 为例,展示如何使用:

  1. 安装 wandb
pip install wandb
  1. 转到 wandb.ai/authorize 获取 API key

  2. 执行 login 命令(需要用到上一步获取的 API key

wandb login
  1. configs/logger/ 目录下新建一个名为 wandb.yaml 的文件,并写入以下内容
wandb:
  _target_: pytorch_lightning.loggers.wandb.WandbLogger
  project: "mnist"
  1. 执行训练代码的时候指定 loggerwandb
python train.py logger=wandb
  • 根据自定义实验配置来训练模型

配置文件见 configs/experiment/

python train.py experiment=example
  • 带回调函数的训练

配置文件见 configs/callbacks/

python train.py callbacks=default
  • 使用 Pytorch Lightning 中的训练策略

点击这里了解 Pytorch Lightning 中的各种训练策略

梯度裁剪来避免梯度爆炸

python train.py +trainer.gradient_clip_val=0.5

随机加权平均可以使您的模型更好地泛化

python train.py +trainer.stochastic_weight_avg=true

梯度累积

python train.py +trainer.accumulate_grad_batches=10
  • 轻松调试

配置文件见 configs/debug/

默认调试模式(运行 1 个 epoch)

python train.py debug=default

仅对 test epoch 进行调试

python train.py debug=test_only

执行一次 train,val 和 test 步骤(仅使用 1 个 batch)

python train.py +trainer.fast_dev_run=true

训练完成后打印各个阶段的执行时间(用于快速发现训练瓶颈)

python train.py +trainer.profiler="simple"
  • 断点续训
python train.py trainer.resume_from_checkpoint="/path/to/name.ckpt"
  • 一次执行多个实验

例如,下方代码将按顺序运行所有参数组合(共 6 个)的实验。

python train.py -m datamodule.batch_size=32,64,128 litmodule.lr=0.001,0.0005

此外,你也可以执行 /configs/experiment/ 目录下的的所有实验

python train.py -m 'experiment=glob(*)'
  • 使用 Optuna 进行超参数搜索

Optuna Sweeper plugin | Hydra

  1. 安装 hydra-optuna-sweeper 插件
pip install hydra-optuna-sweeper
  1. 修改 configs/hparams_search/ 目录下的配置文件

  2. 执行训练代码的时候指定 hparams_search

python train.py -m hparams_search=mnist_optuna
  • 使用 Tab 键智能提示可选配置参数

Tab completion | Hydra

$ eval "$(python train.py -sc install=bash)" # 安装
$ python train.py logger= # 按下 Tab 键后会智能提示有哪些可选参数
logger=comet         logger=csv           logger=many_loggers  logger=mlflow        logger=neptune       logger=tensorboard   logger=wandb

提示

  • .env 文件中设置私有环境变量
  1. 例如,你可以将 cometAPI Key 添加到 .env 文件中
COMET_API_KEY="xxx"
  1. 并在配置文件 configs/logger/comet.yaml 中进行调用
comet:
  _target_: pytorch_lightning.loggers.comet.CometLogger
  api_key: ${oc.env:COMET_API_KEY}
  1. 在训练的时候指定 logger 参数为 comet
python train.py logger=comet

注意:.env 文件不应受版本控制,因此我们已将其添加到 .gitignore 文件中了。

推荐使用 PytorchLightning 官方提供的 torchmetrics 库来计算指标(像准确率,F1 score 和混淆矩阵等)。这对于多 GPU 训练尤为重要!并且,推荐对每个步骤使用不同的指标实例,以确保所有 GPU 进程都有正确的累积值。下面给出了一个简单示例。

from torchmetrics.classification.accuracy import Accuracy


class LitModel(LightningModule):
    def __init__(self)
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

    def training_step(self, batch, batch_idx):
        ...
        acc = self.train_acc(predictions, targets)
        self.log("train/acc", acc)
        ...

    def validation_step(self, batch, batch_idx):
        ...
        acc = self.val_acc(predictions, targets)
        self.log("val/acc", acc)
        ...
  • 可以使用 DVC 对数据和模型这些大文件进行版本控制
dvc init
dvc add data/MNIST
git add data/MNIST.dvc data/.gitignore
git commit -m "Add raw data"

更新

CHANGELOG.md

证书

MIT License

参考

此模板引用了以下仓库并进行了一些细微的修改。

Readme Card