
Commit 780c27b

Merge pull request #19 from ranamihir/improvements
Updated README + other minor improvements.
2 parents 434b748 + 2c18eb3 · commit 780c27b

2 files changed: +123 −60 lines changed

README.md

Lines changed: 62 additions & 47 deletions
@@ -9,32 +9,40 @@
 </a>
 </p>

-`pytorch-common` is a lightweight wrapper that contains PyTorch code that is common and (hopefully) helpful to most projects built on PyTorch. It is built with 3 main principles in mind:
+
+# Overview
+
+This repository contains PyTorch code that is common and (hopefully) helpful to most projects built on PyTorch.
+
+It is a lightweight wrapper that contains PyTorch code that is common and (hopefully) helpful to most projects built on PyTorch. It is built with 3 main principles in mind:
 - Make use of PyTorch available to people without much in-depth knowledge of it while providing enormous flexibility and support for hardcore users
 - Under-the-hood optimization for fast and memory efficient performance
 - Ability to change all settings (e.g. model, loss, metrics, devices, hyperparameters, artifact directories, etc.) directly from config

+
 # Features

 In a nutshell, it has code for:
 - Training / testing models
+  - Option to retrain on all data (without performing evaluation on a separate data set)
 - Logging all common losses / eval metrics
 - `BasePyTorchDataset`, which has functions for:
   - Printing summary + useful statistics
-  - Over-/under-sampling rows
-  - Properly saving/loading/removing datasets (using appropriate pickle modules)
+  - Over- / under-sampling rows
+  - Properly saving / loading / removing datasets (using appropriate pickle modules)
 - `BasePyTorchModel`, which has:
   - `initialize_model()`:
     - Prints number of params + architecture
     - Allows initializing (all / given) weights for Conv, BatchNorm, Linear, Embedding layers
-  - Provision to freeze/unfreeze (all / given) weights of model
+  - Provision to freeze / unfreeze (all / given) weights of model
   - Sending model to device(s)
-  - Saving/loading/removing/copying state dict / model checkpoints
+  - Saving / loading / removing / copying state dict / model checkpoints
     - Disable above mentioned checkpointing from config for faster development
   - Early stopping
-- Properly sending model/optimizer/batch to device(s)
-- Defining custom train/test loss and evaluation criteria directly from config
-  - Supports most common losses/metrics for regression and binary/multi-class/multi-label classification
+- Sample weighting
+- Properly sending model / optimizer / batch to device(s)
+- Defining custom train / test loss and evaluation criteria directly from config
+  - Supports most common losses / metrics for regression and binary / multi-class / multi-label classification
   - May give as many as you like
 - Cleanly stopping training at any point without losing progress
 - Make predictions
@@ -43,73 +51,80 @@ In a nutshell, it has code for:
   - Loading back best (or any) model and printing + plotting all losses + eval metrics
 - etc.

-# Installation
-To install this package, you must have [pytorch](https://pytorch.org/) (and [transformers](https://github.com/huggingface/transformers) for accessing NLP-based functionalities) installed.
-If you don't already have it, you can create a conda environment by running:
-```bash
-conda env create -f requirements.yaml`
-pip install -e . # or ".[nlp]" if required
-```
-which will create an environment called `pytorch_common` for you with all the required dependencies.

-The package can then be installed from source:
+# Installation
+To install this package, you must have [pytorch](https://pytorch.org/) (and [transformers](https://github.com/huggingface/transformers) for accessing NLP-based functionalities) installed. Then you can simply install this package from source:
 ```bash
-git clone git@github.com:ranamihir/pytorch_common
+git clone git@github.com:ranamihir/pytorch_common.git
 cd pytorch_common
+conda env create -f requirements.yaml # If you don't already have a pytorch-enabled conda environment
+conda activate pytorch_common # <-- Replace with your environment name
 pip install .
 ```
+which will create an environment called `pytorch_common` for you with all the required dependencies and this package installed.

 If you'd like access to the NLP-related functionalities (specifically for [transformers](https://github.com/huggingface/transformers/)), make sure to install it as below instead:
 ```bash
 pip install ".[nlp]"
 ```

-# Usage

-The default [config](https://github.com/ranamihir/pytorch_common/blob/master/pytorch_common/configs/config.yaml) can be loaded, and overridden with a user-specified dictionary, as follows:
-```python
-from pytorch_common.config import load_pytorch_common_config
-
-# Create your own config (or load from a yaml file)
-config_dict = {"batch_size_per_gpu": 5, "device": "cpu", "epochs": 2, "lr": 1e-3, "disable_checkpointing": True}
+# Usage

-# Load the deault pytorch_common config, and then override it with your own custom one
-config = load_pytorch_common_config(config_dict)
-```
+Training a very simple (dummy) model is as easy as:

-Then, training a (dummy) model is as easy as:
 ```python
 from torch.utils.data import DataLoader
-from torch.optim import SGD

-from pytorch_common.additional_configs import BaseDatasetConfig, BaseModelConfig
-from pytorch_common.datasets import create_dataset
+from pytorch_common.config import load_pytorch_common_config
 from pytorch_common.metrics import get_loss_eval_criteria
-from pytorch_common.models import create_model
 from pytorch_common.train_utils import train_model
 from pytorch_common.utils import get_model_performance_trackers

-# Create your own objects here
-dataset_config = BaseDatasetConfig({"size": 5, "dim": 1, "num_classes": 2})
-model_config = BaseModelConfig({"in_dim": 1, "num_classes": 2})
-dataset = create_dataset("multi_class_dataset", dataset_config)
-train_loader = DataLoader(dataset, batch_size=config.train_batch_size)
-val_loader = DataLoader(dataset, batch_size=config.eval_batch_size)
-model = create_model("single_layer_classifier", model_config)
-optimizer = SGD(model.parameters(), lr=config.lr)
+# Load default pytorch_common config and override with your settings
+project_config_dict = ...
+config = load_pytorch_common_config(project_config_dict)
+
+# Create your own training objects here
+train_loader = ...
+val_loader = ...
+model = ...
+optimizer = ...

-# Use `pytorch_common` to get loss/eval criteria, initialize loggers, and train the model
+# Use `pytorch_common` to get loss / eval criteria, initialize loggers, and train the model
 loss_criterion_train, loss_criterion_eval, eval_criteria = get_loss_eval_criteria(config, reduction="mean")
 train_logger, val_logger = get_model_performance_trackers(config)
 return_dict = train_model(
     model, config, train_loader, val_loader, optimizer, loss_criterion_train, loss_criterion_eval, eval_criteria, train_logger, val_logger
 )
 ```
-For more details on getting started, check out the [basic usage notebook](https://github.com/ranamihir/pytorch_common/blob/master/notebooks/basic_usage.ipynb) and other examples in the [notebooks](https://github.com/ranamihir/pytorch_common/blob/master/notebooks/) folder.

-# Testing
+More detailed examples highlighting the full functionality of this package can be found in the [examples](https://github.com/ranamihir/pytorch_common/tree/master/examples) directory.
+
+## Config
+
+A powerful advantage of using this repository is the ability to change a large number of settings related to PyTorch, and more generally, deep learning, directly from YAML, instead of having to worry about making code changes.
+
+To do this, all you need to do is invoke the `load_pytorch_common_config()` function (with your project dictionary as input, if required). This will allow you to edit all `pytorch_common` supported settings in your project dictionary / YAML, or use the default ones for those not specified. E.g.:
+
+```python
+>>> from pytorch_common.config import load_pytorch_common_config
+
+>>> config = load_pytorch_common_config() # Use default settings
+>>> print(config.batch_size_per_gpu)
+32
+>>> dictionary = {"vocab_size": 10_000, "batch_size_per_gpu": 64} # Override default settings and / or add project specific settings here
+>>> config = load_pytorch_common_config(dictionary)
+>>> print(config.batch_size_per_gpu)
+64
+>>> print(config.vocab_size)
+10000
+```
+
+The list of all supported configuration settings can be found [here](https://github.com/ranamihir/pytorch_common/blob/master/pytorch_common/configs/config.yaml).
+

+# Testing
 Several unit tests are present in the [tests](https://github.com/ranamihir/pytorch_common/tree/master/tests) directory. You may manually run them with:

 ```bash
@@ -129,8 +144,8 @@ chmod +x install-hooks.sh

 In the future, I intend to move the tests to CI.

-# To-do's

+# To-do's
 I have some enhancements in mind which I haven't gotten around to adding to this repo yet:
 - Adding automatic mixed precision training (AMP) to enable it directly from config
 - Enabling distributed training across servers
@@ -140,6 +155,6 @@ I have some enhancements in mind which I haven't gotten around to adding to this

 This repo is a personal project, and as such, has not been as heavily tested. It is (and will likely always be) a work-in-progress, as I try my best to keep it current with the advancements in PyTorch.

-If you come across any bugs, or have questions/suggestions, please consider opening an issue, [reaching out to me](mailto:[email protected]), or better yet, sending across a PR. :)
+If you come across any bugs, or have questions / suggestions, please consider opening an issue, [reaching out to me](mailto:[email protected]), or better yet, sending across a PR. :)

 Author: [Mihir Rana](https://github.com/ranamihir)
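
The `## Config` section added above shows overrides passed as an in-memory dictionary. As a complementary sketch (not part of this commit), the same overrides could live in a project YAML file and be merged with the `pytorch_common` defaults; the file name and the `yaml.safe_load` step are assumptions, while `load_pytorch_common_config`, `batch_size_per_gpu`, and `disable_checkpointing` come from the diff above.

```python
# Hypothetical sketch: keep project overrides in a YAML file instead of a dict.
# Only `load_pytorch_common_config` and the two setting names are from the repo;
# the file name and its contents are made up for illustration.
import yaml  # PyYAML

from pytorch_common.config import load_pytorch_common_config

# my_project.yaml might contain:
#   batch_size_per_gpu: 64
#   disable_checkpointing: true
with open("my_project.yaml") as f:
    project_config_dict = yaml.safe_load(f)

# Merge with the pytorch_common defaults, as in the README example above
config = load_pytorch_common_config(project_config_dict)
print(config.batch_size_per_gpu)  # 64 here, overriding the default of 32
```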

pytorch_common/train_utils.py

Lines changed: 61 additions & 13 deletions
@@ -217,19 +217,17 @@ def train_model(

             # Replace model checkpoint if required
             if not config.disable_checkpointing:
-                logger.info("Replacing current best model checkpoint...")
-                best_checkpoint_file = save_model(
-                    model,
-                    config,
-                    epoch,
-                    train_logger,
-                    val_logger,
-                    optimizer,
-                    scheduler,
-                    config_info_dict,
+                replace_checkpoint(
+                    model=model,
+                    config=config,
+                    new_epoch=epoch,
+                    old_epoch=best_epoch,
+                    train_logger=train_logger,
+                    val_logger=val_logger,
+                    optimizer=optimizer,
+                    scheduler=scheduler,
+                    config_info_dict=config_info_dict,
                 )
-                remove_model(config, best_epoch, config_info_dict)
-                logger.info("Done.")

             best_epoch = epoch # Update best epoch

@@ -239,9 +237,23 @@
                 logger.info(f"Stopping early after {stop_epoch} epochs.")
                 break

+            # Replace model checkpoint if required
+            elif not config.disable_checkpointing:
+                replace_checkpoint(
+                    model=model,
+                    config=config,
+                    new_epoch=epoch,
+                    old_epoch=epoch-1,
+                    train_logger=train_logger,
+                    val_logger=val_logger,
+                    optimizer=optimizer,
+                    scheduler=scheduler,
+                    config_info_dict=config_info_dict,
+                )
+
             stop_epoch = epoch # Update last epoch trained
         except KeyboardInterrupt: # Option to quit training with keyboard interrupt
-            logger.warning("Keyboard Interrupted!")
+            logger.warning("Keyboard Interrupted! Pausing training.")
             stop_epoch = epoch - 1 # Current epoch training incomplete
             break

@@ -610,6 +622,8 @@ def _drop_unnecessary_keys(return_dict: _StringDict, all_keys: List[str], return
             loss = loss_criterion(outputs, targets)
             if sample_weighting:
                 loss = loss_reduction_fn(loss * sample_weights / sample_weights.sum())
+            if torch.isnan(loss).any().item():
+                logger.warning("NaN value encountered for loss.")
             loss_value = loss.item()
             return_dict["losses"].append(loss_value)
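
The two added lines above guard against a silently diverging loss. A minimal, self-contained illustration of the same check (the tensor value is fabricated; only the `torch.isnan(...).any().item()` pattern mirrors the diff):

```python
import torch

# A loss that has gone NaN, e.g. after an exploding update (made-up value)
loss = torch.tensor(float("nan"))

# Same pattern as the new guard: isnan() gives a boolean tensor, any() collapses
# it, and item() converts the result to a plain Python bool
if torch.isnan(loss).any().item():
    print("NaN value encountered for loss.")  # this branch fires
```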

@@ -974,6 +988,8 @@ def remove_model(
         logger.info(f"Removing {checkpoint_type} checkpoint '{checkpoint_path}'...")
         remove_object(checkpoint_path)
         logger.info("Done.")
+    elif epoch > 0:
+        logger.warning(f"Could not remove checkpoint '{checkpoint_path}' since it doesn't exist.")


 def get_checkpoint_type_from_file(checkpoint_file: str) -> str:
@@ -1001,6 +1017,38 @@ def validate_checkpoint_type(checkpoint_type: str) -> None:
         )


+def replace_checkpoint(
+    model: nn.Module,
+    config: _Config,
+    new_epoch: int,
+    old_epoch: int,
+    train_logger: Optional[ModelTracker] = None,
+    val_logger: Optional[ModelTracker] = None,
+    optimizer: Optional[Optimizer] = None,
+    scheduler: Optional[object] = None,
+    config_info_dict: Optional[_StringDict] = None,
+) -> str:
+    """
+    Save the `new` model checkpoint and
+    delete the `old` one (if it exists).
+    """
+    logger.info("Replacing current best model checkpoint...")
+    best_checkpoint_file = save_model(
+        model,
+        config,
+        new_epoch,
+        train_logger,
+        val_logger,
+        optimizer,
+        scheduler,
+        config_info_dict,
+    )
+    remove_model(config, old_epoch, config_info_dict)
+    logger.info("Done.")
+
+    return best_checkpoint_file
+
+
 class EarlyStopping:
     """
     Implements early stopping in PyTorch.
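
The new `replace_checkpoint()` helper factors the save-then-delete checkpoint pattern out of `train_model()`. A rough sketch of how a custom loop might reuse it to keep only the latest best checkpoint on disk; everything except the import path, the keyword arguments, and `config.disable_checkpointing` / `config.epochs` is a placeholder, in the same `...` style as the README example above.

```python
# Hypothetical usage sketch; model, config, loggers, optimizer, scheduler and the
# validation step are placeholders you would supply from your own project.
from pytorch_common.train_utils import replace_checkpoint

model, config, optimizer, scheduler = ..., ..., ..., ...
train_logger, val_logger = ..., ...

best_epoch, best_loss = 0, float("inf")
for epoch in range(1, config.epochs + 1):
    val_loss = ...  # run training + validation for this epoch

    if val_loss < best_loss and not config.disable_checkpointing:
        # Save this epoch's checkpoint and delete the previous best one
        best_checkpoint_file = replace_checkpoint(
            model=model,
            config=config,
            new_epoch=epoch,
            old_epoch=best_epoch,
            train_logger=train_logger,
            val_logger=val_logger,
            optimizer=optimizer,
            scheduler=scheduler,
            config_info_dict=None,
        )
        best_epoch, best_loss = epoch, val_loss
```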
