Adding support for DDP training #223

Open · wants to merge 1 commit into master
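
The added `--local_rank` argument and the `WORLD_SIZE` environment-variable lookup in the diff suggest the script is meant to be started with the `torch.distributed.launch` utility, which supplies both. A minimal single-node launch sketch (the GPU count and dataset path are placeholders, not values from this PR):

python -m torch.distributed.launch --nproc_per_node=4 examples/train.py -d /path/to/dataset --cuda

Each spawned process then binds to `cuda:{local_rank}` and wraps the model in the `DistributedDataParallelCompressionModel` class added below.
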
142 changes: 87 additions & 55 deletions examples/train.py
@@ -38,13 +38,14 @@

from torch.utils.data import DataLoader
from torchvision import transforms

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from compressai.datasets import ImageFolder
from compressai.losses import RateDistortionLoss
from compressai.optimizers import net_aux_optimizer
from compressai.zoo import image_models


import numpy as np
import os
class AverageMeter:
"""Compute running average."""

@@ -71,6 +72,13 @@ def __getattr__(self, key):
return getattr(self.module, key)


class DistributedDataParallelCompressionModel(nn.parallel.DistributedDataParallel):
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.module, name)

def configure_optimizers(net, args):
"""Separate parameters for the main optimizer and the auxiliary optimizer.
Return two optimizers"""
@@ -83,10 +91,9 @@ def configure_optimizers(net, args):


def train_one_epoch(
model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm
model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm, local_rank, device, world_size
):
model.train()
device = next(model.parameters()).device

for i, d in enumerate(train_dataloader):
d = d.to(device)
@@ -105,23 +112,21 @@ def train_one_epoch(
aux_loss = model.aux_loss()
aux_loss.backward()
aux_optimizer.step()

if i % 10 == 0:
if local_rank == 0 and i % 10 == 0:

print(
f"Train epoch {epoch}: ["
f"{i*len(d)}/{len(train_dataloader.dataset)}"
f" ({100. * i / len(train_dataloader):.0f}%)]"
f'\tLoss: {out_criterion["loss"].item():.3f} |'
f'\tMSE loss: {out_criterion["mse_loss"].item():.3f} |'
f'\tBpp loss: {out_criterion["bpp_loss"].item():.2f} |'
f"{(i * world_size )*len(d)}/{len(train_dataloader.dataset)}"
f" ({100. * (i ) / len(train_dataloader):.0f}%)]"
f'\tLoss: {out_criterion["loss"].item():.4f} |'
f'\tMSE loss: {out_criterion["mse_loss"].item():.4f} |'
f'\tBpp loss: {out_criterion["bpp_loss"].item():.4f} |'
f"\tAux loss: {aux_loss.item():.2f}"
)


def test_epoch(epoch, test_dataloader, model, criterion):
def test_epoch(epoch, test_dataloader, model, criterion, device):
model.eval()
device = next(model.parameters()).device

loss = AverageMeter()
bpp_loss = AverageMeter()
mse_loss = AverageMeter()
@@ -138,6 +143,7 @@ def test_epoch(epoch, test_dataloader, model, criterion):
loss.update(out_criterion["loss"])
mse_loss.update(out_criterion["mse_loss"])


print(
f"Test epoch {epoch}: Average losses:"
f"\tLoss: {loss.avg:.3f} |"
@@ -150,23 +156,26 @@ def test_epoch(epoch, test_dataloader, model, criterion):


def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
torch.save(state, filename)
cwd = os.getcwd()
filepath = os.path.join(cwd, filename)
torch.save(state, filepath)
if is_best:
shutil.copyfile(filename, "checkpoint_best_loss.pth.tar")
shutil.copyfile(filepath, "checkpoint_best_loss.pth.tar")


def parse_args(argv):
parser = argparse.ArgumentParser(description="Example training script.")
parser.add_argument(
"-m",
"--model",
default="bmshj2018-factorized",
default="bmshj2018-hyperprior",
choices=image_models.keys(),
help="Model architecture (default: %(default)s)",
)
parser.add_argument(
"-d", "--dataset", type=str, required=True, help="Training dataset"
)
parser.add_argument("--local_rank", type=int, default=0, help="Local rank. Necessary for using the torch.distributed.launch utility.")
parser.add_argument(
"-e",
"--epochs",
@@ -185,7 +194,7 @@ def parse_args(argv):
"-n",
"--num-workers",
type=int,
default=4,
default=8,
help="Dataloaders threads (default: %(default)s)",
)
parser.add_argument(
@@ -221,7 +230,7 @@ def parse_args(argv):
parser.add_argument(
"--save", action="store_true", default=True, help="Save model to disk"
)
parser.add_argument("--seed", type=int, help="Set random seed for reproducibility")
parser.add_argument("--seed", type=int, default=0, help="Set random seed for reproducibility")
parser.add_argument(
"--clip_max_norm",
default=1.0,
@@ -232,13 +241,33 @@ def parse_args(argv):
args = parser.parse_args(argv)
return args

def set_random_seeds(random_seed=0):

torch.manual_seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)

def main(argv):
args = parse_args(argv)

world_size = int(os.environ["WORLD_SIZE"])

if args.seed is not None:
torch.manual_seed(args.seed)
random.seed(args.seed)
print("setting random seed")
set_random_seeds(random_seed=args.seed)

torch.distributed.init_process_group(backend="nccl")

device = torch.device("cuda:{}".format(args.local_rank))

net = image_models[args.model](quality=3)
net = net.to(device)

if args.cuda and torch.cuda.device_count() > 1:
#net = CustomDataParallel(net)
net = DistributedDataParallelCompressionModel(net,device_ids=[args.local_rank], output_device=args.local_rank)

train_transforms = transforms.Compose(
[transforms.RandomCrop(args.patch_size), transforms.ToTensor()]
@@ -251,40 +280,38 @@ def main(argv):
train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms)
test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms)

device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
train_sampler = DistributedSampler(dataset=train_dataset)
test_sampler = DistributedSampler(dataset=test_dataset)

train_dataloader = DataLoader(
train_dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
shuffle=True,
pin_memory=(device == "cuda"),
sampler=train_sampler,
)

test_dataloader = DataLoader(
test_dataset,
batch_size=args.test_batch_size,
num_workers=args.num_workers,
shuffle=False,
pin_memory=(device == "cuda"),
)
sampler=test_sampler

net = image_models[args.model](quality=3)
net = net.to(device)
)

if args.cuda and torch.cuda.device_count() > 1:
net = CustomDataParallel(net)

optimizer, aux_optimizer = configure_optimizers(net, args)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode= "min")
criterion = RateDistortionLoss(lmbda=args.lmbda)

last_epoch = 0
if args.checkpoint: # load from previous checkpoint

if args.checkpoint : # load from previous checkpoint
dist.barrier()
map_location = {"cuda:0": "cuda:{}".format(args.local_rank)}
print("Loading", args.checkpoint)
checkpoint = torch.load(args.checkpoint, map_location=device)
checkpoint = torch.load(args.checkpoint, map_location=map_location)
last_epoch = checkpoint["epoch"] + 1
net.load_state_dict(checkpoint["state_dict"])
net.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
aux_optimizer.load_state_dict(checkpoint["aux_optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
@@ -300,26 +327,31 @@ def main(argv):
aux_optimizer,
epoch,
args.clip_max_norm,
args.local_rank,
device,
world_size
)
loss = test_epoch(epoch, test_dataloader, net, criterion)
lr_scheduler.step(loss)

is_best = loss < best_loss
best_loss = min(loss, best_loss)

if args.save:
save_checkpoint(
{
"epoch": epoch,
"state_dict": net.state_dict(),
"loss": loss,
"optimizer": optimizer.state_dict(),
"aux_optimizer": aux_optimizer.state_dict(),
"lr_scheduler": lr_scheduler.state_dict(),
},
is_best,
)


loss = test_epoch(epoch, test_dataloader, net, criterion,device)
dist.barrier()
dist.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
lr_scheduler.step(loss / world_size)
is_best = (loss / world_size) < best_loss
best_loss = min((loss / world_size ), best_loss)
if args.local_rank == 0:
if args.save:
save_checkpoint(
{
"epoch": epoch,
"state_dict": net.state_dict(),
"loss": loss,
"optimizer": optimizer.state_dict(),
"aux_optimizer": aux_optimizer.state_dict(),
"lr_scheduler": lr_scheduler.state_dict(),
},
is_best,
)


if __name__ == "__main__":
main(sys.argv[1:])
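
One thing the diff does not add, and that usually accompanies DistributedSampler, is a per-epoch call to set_epoch; without it every rank replays the same shuffled order in every epoch. A hedged sketch of the epoch loop with that call added (names follow the diff, but this is not part of the PR):

# Sketch only -- not in this PR. DistributedSampler derives its shuffle seed from an
# internal epoch counter, so set_epoch(epoch) is needed to reshuffle between epochs.
for epoch in range(last_epoch, args.epochs):
    train_sampler.set_epoch(epoch)  # re-seed the per-rank shuffle
    train_one_epoch(
        net, criterion, train_dataloader, optimizer, aux_optimizer,
        epoch, args.clip_max_norm, args.local_rank, device, world_size,
    )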