PyTorch deep learning project made easy.
- PyTorch Template Project
- Python >= 3.5
- PyTorch >= 0.4
- tqdm (Optional for test.py)
- tensorboard >= 1.7.0 (Optional for TensorboardX)
- tensorboardX >= 1.2 (Optional for TensorboardX)
- Clear folder structure which is suitable for many deep learning projects.
- .json config file support for more convenient parameter tuning.
- Checkpoint saving and resuming.
- Abstract base classes for faster development:
  - BaseTrainer handles checkpoint saving/resuming, training process logging, and more.
  - BaseDataLoader handles batch generation, data shuffling, and validation data splitting.
  - BaseModel provides basic model summary.
pytorch-template/
│
├── train.py - main script to start training
├── test.py - evaluation of trained model
├── config.json - config file
│
├── base/ - abstract base classes
│ ├── base_data_loader.py - abstract base class for data loaders
│ ├── base_model.py - abstract base class for models
│ └── base_trainer.py - abstract base class for trainers
│
├── data_loader/ - anything about data loading goes here
│ └── data_loaders.py
│
├── data/ - default directory for storing input data
│
├── model/ - models, losses, and metrics
│ ├── loss.py
│ ├── metric.py
│ └── model.py
│
├── saved/ - default checkpoints folder
│ └── runs/ - default logdir for tensorboardX
│
├── trainer/ - trainers
│ └── trainer.py
│
└── utils/
  ├── util.py
  ├── logger.py - class for train logging
  ├── visualization.py - class for tensorboardX visualization support
  └── ...
The code in this repo is an MNIST example of the template.
Try python3 train.py -c config.json to run code.
Config files are in .json format:
{
  "name": "Mnist_LeNet", // training session name
  "n_gpu": 1, // number of GPUs to use for training.
  "arch": {
    "type": "MnistModel", // name of model architecture to train
    "args": {
    }
  },
  "data_loader": {
    "type": "MnistDataLoader", // selecting data loader
    "args": {
      "data_dir": "data/", // dataset path
      "batch_size": 64, // batch size
      "shuffle": true, // shuffle training data before splitting
      "validation_split": 0.1, // validation data ratio
      "num_workers": 2 // number of cpu processes to be used for data loading
    }
  },
  "optimizer": {
    "type": "Adam",
    "args": {
      "lr": 0.001, // learning rate
      "weight_decay": 0, // (optional) weight decay
      "amsgrad": true
    }
  },
  "loss": "nll_loss", // loss
  "metrics": [
    "my_metric", "my_metric2" // list of metrics to evaluate
  ],
  "lr_scheduler": {
    "type": "StepLR", // learning rate scheduler
    "args": {
      "step_size": 50,
      "gamma": 0.1
    }
  },
  "trainer": {
    "epochs": 100, // number of training epochs
    "save_dir": "saved/", // checkpoints are saved in save_dir/name
    "save_freq": 1, // save checkpoints every save_freq epochs
    "verbosity": 2, // 0: quiet, 1: per epoch, 2: full
    "monitor": "min val_loss", // mode and metric for model performance monitoring. set 'off' to disable.
    "early_stop": 10, // number of epochs to wait before early stop. set 0 to disable.
    "tensorboardX": true, // enable tensorboardX visualization support
    "log_dir": "saved/runs" // directory to save log files for visualization
  }
}
Add additional configurations if you need.
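The listing above is annotated with // comments, which Python's standard json module rejects. If you want to keep such annotations in your own config file, one option is to strip them before parsing. The following is a minimal sketch under that assumption; strip_line_comments is a hypothetical helper, not part of this template, and the annotated string is a shortened copy of the config above.

```python
import json
import re

# Match either a complete JSON string literal or a // comment running to end of line.
# Trying the string alternative first means "//" inside a string is never treated as a comment.
_TOKEN = re.compile(r'"(?:\\.|[^"\\])*"|//[^\n]*')


def strip_line_comments(text):
    """Hypothetical helper: remove // comments while leaving string literals untouched."""
    return _TOKEN.sub(
        lambda m: m.group(0) if m.group(0).startswith('"') else "", text
    )


annotated = '''
{
  "name": "Mnist_LeNet",   // training session name
  "n_gpu": 1,              // number of GPUs to use for training
  "trainer": {
    "save_dir": "saved/",  // checkpoints are saved in save_dir/name
    "epochs": 100
  }
}
'''

config = json.loads(strip_line_comments(annotated))
print(config["trainer"]["save_dir"])
```

A config file without comments can be passed to json.load directly and needs none of this.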
Modify the configurations in .json config files, then run:
python train.py --config config.json
You can resume from a previously saved checkpoint by:
python train.py --resume path/to/checkpoint
You can enable multi-GPU training by setting the n_gpu argument of the config file to a larger number.
If configured to use fewer GPUs than are available, the first n devices will be used by default.
Specify the indices of the GPUs to use with the --device option or the CUDA_VISIBLE_DEVICES environment variable:
python train.py --device 2,3 -c config.json
This is equivalent to
CUDA_VISIBLE_DEVICES=2,3 python train.py -c config.json
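The equivalence between the two commands can be pictured as the script copying the --device value into the environment before any GPU is touched. This is a sketch of that idea using only argparse and os; the function names parse_args and apply_device are made up for illustration, and the actual train.py may be organized differently.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse the two options used in the commands above."""
    parser = argparse.ArgumentParser(description="device selection sketch")
    parser.add_argument("-c", "--config", default=None, type=str)
    parser.add_argument("--device", default=None, type=str)
    return parser.parse_args(argv)


def apply_device(args, environ=os.environ):
    """Hypothetical helper: export --device as CUDA_VISIBLE_DEVICES if it was given."""
    if args.device:
        environ["CUDA_VISIBLE_DEVICES"] = args.device
    return environ.get("CUDA_VISIBLE_DEVICES")


# Use a plain dict instead of the real environment so the example has no side effects.
env = {}
args = parse_args(["--device", "2,3", "-c", "config.json"])
visible = apply_device(args, env)
print(visible)
```

The variable has to be set before CUDA is initialized, which is why a script would do this right after parsing its arguments.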
- Writing your own data loader
  - Inherit BaseDataLoader
    BaseDataLoader is a subclass of torch.utils.data.DataLoader, you can use either of them.
    BaseDataLoader handles:
    - Generating next batch
    - Data shuffling
    - Generating validation data loader by calling BaseDataLoader.split_validation()
  - DataLoader Usage
    BaseDataLoader is an iterator, to iterate through batches:
        for batch_idx, (x_batch, y_batch) in data_loader:
            pass
  - Example
    Please refer to data_loader/data_loaders.py for an MNIST data loading example.
- Writing your own trainer
  - Inherit BaseTrainer
    BaseTrainer handles:
    - Training process logging
    - Checkpoint saving
    - Checkpoint resuming
    - Reconfigurable performance monitoring for saving the current best model and for early stopping.
      - If config monitor is set to max val_accuracy, the trainer will save a checkpoint model_best.pth whenever the validation accuracy of an epoch exceeds the current maximum.
      - If config early_stop is set, training will be automatically terminated when model performance does not improve for the given number of epochs. This feature can be turned off by passing 0 to the early_stop option, or by deleting that line from the config.
  - Implementing abstract methods
    You need to implement _train_epoch() for your training process. If you need validation, you can also implement _valid_epoch() as in trainer/trainer.py.
  - Example
    Please refer to trainer/trainer.py for MNIST training.
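The monitoring and early-stopping behaviour described for BaseTrainer can be illustrated without PyTorch. The sketch below is a simplified assumption of how a "mode metric" string and an early_stop count might be interpreted; run_monitor is a hypothetical function, and the exact stopping condition in the real trainer may differ by one epoch.

```python
import math


def run_monitor(history, monitor="min val_loss", early_stop=10):
    """Return (best_epoch, stop_epoch) for a list of per-epoch log dicts.

    stop_epoch is None when training would run to the end.
    """
    if monitor == "off":
        return None, None
    mode, metric = monitor.split()
    assert mode in ("min", "max")
    best = math.inf if mode == "min" else -math.inf
    best_epoch, not_improved = None, 0
    for epoch, log in enumerate(history, start=1):
        value = log[metric]
        improved = value < best if mode == "min" else value > best
        if improved:
            # This is the point where a trainer would save model_best.pth.
            best, best_epoch, not_improved = value, epoch, 0
        else:
            not_improved += 1
        # early_stop == 0 disables early stopping.
        if early_stop and not_improved >= early_stop:
            return best_epoch, epoch
    return best_epoch, None


history = [{"val_loss": v} for v in (0.9, 0.7, 0.8, 0.75, 0.71)]
best_epoch, stop_epoch = run_monitor(history, "min val_loss", early_stop=2)
print(best_epoch, stop_epoch)
```

With the sample history, the best value is reached in epoch 2 and two epochs without improvement follow, so the sketch stops at epoch 4.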
- Writing your own model
  - Inherit BaseModel
    BaseModel handles:
    - Inherited from torch.nn.Module
    - __str__: Modifies the native print function to print the number of trainable parameters.
  - Implementing abstract methods
    Implement the forward pass method forward()
  - Example
    Please refer to model/model.py for a LeNet example.
Custom loss functions can be implemented in 'model/loss.py'. Use them by changing the name given in "loss" in the config file to the name of the corresponding function.
Metric functions are located in 'model/metric.py'.
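Selecting losses and metrics by name amounts to looking up a function on a module using the string from the config. The following sketch shows that lookup with getattr; the SimpleNamespace stands in for model/metric.py, and the two metric bodies are invented placeholders that operate on plain lists, not the template's real metrics.

```python
from types import SimpleNamespace


def my_metric(predictions, targets):
    """Placeholder metric: fraction of predictions that match the targets."""
    correct = sum(p == t for p, t in zip(predictions, targets))
    return correct / len(targets)


def my_metric2(predictions, targets):
    """Placeholder metric: number of mismatched predictions."""
    return sum(p != t for p, t in zip(predictions, targets))


# Stand-in for the metric module; a real project would import the module instead.
metric_module = SimpleNamespace(my_metric=my_metric, my_metric2=my_metric2)

config = {"metrics": ["my_metric", "my_metric2"]}

# Resolve each configured name to the function of the same name.
metric_fns = [getattr(metric_module, name) for name in config["metrics"]]
results = {fn.__name__: fn([1, 0, 1, 1], [1, 1, 1, 0]) for fn in metric_fns}
print(results)
```

A misspelled name in the config raises AttributeError at lookup time, which surfaces the mistake before training starts.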
You can monitor multiple metrics by providing a list in the configuration file, e.g.:
"metrics": ["my_metric", "my_metric2"],
If you have additional information to be logged, merge it with log in _train_epoch() of your trainer class before returning, as shown below:
additional_log = {"gradient_norm": g, "sensitivity": s}
log = {**log, **additional_log}
return log
You can test a trained model by running test.py, passing the path to the trained checkpoint with the --resume argument.
To split validation data from a data loader, call BaseDataLoader.split_validation(). It returns a validation data loader whose number of samples follows the ratio specified in your config file.
Note: the split_validation() method will modify the original data loader
Note: split_validation() will return None if "validation_split" is set to 0
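The effect of the validation ratio can be shown on plain indices. This is a sketch of the splitting arithmetic only, assuming the ratio is applied to the total sample count and rounded down; split_indices is a hypothetical function and does not reproduce how BaseDataLoader builds its samplers.

```python
import random


def split_indices(n_samples, validation_split, shuffle=True, seed=0):
    """Return (train_indices, valid_indices); valid_indices is None for a ratio of 0."""
    indices = list(range(n_samples))
    if validation_split == 0.0:
        return indices, None
    if shuffle:
        # Shuffle before splitting so the validation set is a random subset.
        random.Random(seed).shuffle(indices)
    n_valid = int(n_samples * validation_split)
    return indices[n_valid:], indices[:n_valid]


train_idx, valid_idx = split_indices(100, 0.1)
print(len(train_idx), len(valid_idx))
```

The training indices no longer cover the full dataset after the split, which mirrors the note that the original data loader is modified.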
You can specify the name of the training session in config files:
"name": "MNIST_LeNet",
The checkpoints will be saved in save_dir/name/timestamp/checkpoint_epoch_n, with timestamp in mmdd_HHMMSS format.
A copy of the config file will be saved in the same folder.
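The directory layout can be reproduced with datetime and pathlib. The sketch below assumes the timestamp is taken once at the start of a run and formatted with strftime; checkpoint_path is a hypothetical helper, and the file name follows the checkpoint_epoch_n pattern quoted above without guessing a file extension.

```python
from datetime import datetime
from pathlib import Path


def checkpoint_path(save_dir, name, epoch, start_time):
    """Build save_dir/name/timestamp/checkpoint_epoch_n with an mmdd_HHMMSS timestamp."""
    timestamp = start_time.strftime("%m%d_%H%M%S")
    return Path(save_dir) / name / timestamp / "checkpoint_epoch_{}".format(epoch)


# A fixed start time keeps the example reproducible; a real run would use datetime.now().
path = checkpoint_path("saved/", "MNIST_LeNet", 3, datetime(2019, 1, 25, 14, 30, 5))
print(path.as_posix())
```

Because the timestamp is part of the path, two runs with the same session name do not overwrite each other's checkpoints.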
Note: checkpoints contain:
{
  'arch': arch,
  'epoch': epoch,
  'logger': self.train_logger,
  'state_dict': self.model.state_dict(),
  'optimizer': self.optimizer.state_dict(),
  'monitor_best': self.mnt_best,
  'config': self.config
}
This template supports TensorboardX visualization.
- TensorboardX Usage
  - Install
    Follow the installation guide in TensorboardX.
  - Run training
    Set the tensorboardX option in the config file to true.
  - Open tensorboard server
    Type tensorboard --logdir saved/runs/ at the project root, then the server will open at http://localhost:6006
By default, the values of the loss and metrics specified in the config file, as well as input images, will be logged.
If you need more visualizations, use add_scalar('tag', data), add_image('tag', image), etc. in the trainer._train_epoch method.
The add_something() methods in this template are basically wrappers for those of the tensorboardX.SummaryWriter class.
Note: You don't have to specify current steps, since the WriterTensorboardX class defined at utils/visualization.py will track current steps.
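The idea of a wrapper that remembers the current step can be sketched without installing tensorboardX. In the code below, RecordingWriter is an invented stand-in that only records calls, and StepTrackingWriter is a simplified illustration of the pattern, not the template's WriterTensorboardX class.

```python
class RecordingWriter:
    """Stand-in for a summary writer; it stores calls instead of writing event files."""

    def __init__(self):
        self.calls = []

    def add_scalar(self, tag, scalar_value, global_step=None):
        self.calls.append((tag, scalar_value, global_step))


class StepTrackingWriter:
    """Wrap a writer so callers never pass the step themselves."""

    def __init__(self, writer):
        self.writer = writer
        self.step = 0

    def set_step(self, step):
        # Called once per iteration by the training loop.
        self.step = step

    def add_scalar(self, tag, value):
        # Forward the call with the remembered step filled in.
        self.writer.add_scalar(tag, value, self.step)


backend = RecordingWriter()
writer = StepTrackingWriter(backend)
for step, loss in enumerate([0.9, 0.7]):
    writer.set_step(step)
    writer.add_scalar("loss", loss)
print(backend.calls)
```

Keeping the step in one place means every value logged within an iteration shares the same x-axis position.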
Feel free to contribute any kind of function or enhancement. The coding style follows PEP8.
Code should pass the Flake8 check before committing.
- Iteration-based training (instead of epoch-based)
- Multiple optimizers
- Configurable logging layout, checkpoint naming
- visdom logger support
- tensorboardX logger support
- Adding command line option for fine-tuning
- Multi-GPU support
- Update the example to PyTorch 0.4
- Learning rate scheduler
- Deprecate BaseDataLoader, use torch.utils.data instead
- Load settings from config files
This project is licensed under the MIT License. See LICENSE for more details.
This project is inspired by the project Tensorflow-Project-Template by Mahmoud Gemy.