From f8438b25ee5fcd2428f7a6b235b3ddc4eaddb16c Mon Sep 17 00:00:00 2001
From: danish nazir
Date: Sun, 9 Apr 2023 08:29:44 +0000
Subject: [PATCH] Adding support for DDP training

---
 examples/train.py | 142 ++++++++++++++++++++++++++++------------------
 1 file changed, 87 insertions(+), 55 deletions(-)

diff --git a/examples/train.py b/examples/train.py
index 609b4b14..c1c9e8ca 100644
--- a/examples/train.py
+++ b/examples/train.py
@@ -38,13 +38,14 @@
 from torch.utils.data import DataLoader
 from torchvision import transforms
 
-
+import torch.distributed as dist
+from torch.utils.data.distributed import DistributedSampler
 from compressai.datasets import ImageFolder
 from compressai.losses import RateDistortionLoss
 from compressai.optimizers import net_aux_optimizer
 from compressai.zoo import image_models
-
-
+import numpy as np
+import os
 
 class AverageMeter:
     """Compute running average."""
@@ -71,6 +72,13 @@ def __getattr__(self, key):
         return getattr(self.module, key)
 
 
+class DistributedDataParallelCompressionModel(nn.parallel.DistributedDataParallel):
+    def __getattr__(self, name):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.module, name)
+
 def configure_optimizers(net, args):
     """Separate parameters for the main optimizer and the auxiliary optimizer.
     Return two optimizers"""
@@ -83,10 +91,9 @@
 
 
 def train_one_epoch(
-    model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm
+    model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm, local_rank, device, world_size
 ):
     model.train()
-    device = next(model.parameters()).device
 
     for i, d in enumerate(train_dataloader):
         d = d.to(device)
@@ -105,23 +112,21 @@
         aux_loss = model.aux_loss()
         aux_loss.backward()
         aux_optimizer.step()
-
-        if i % 10 == 0:
+        if local_rank == 0 and i % 10 == 0:
+
             print(
                 f"Train epoch {epoch}: ["
-                f"{i*len(d)}/{len(train_dataloader.dataset)}"
-                f" ({100. * i / len(train_dataloader):.0f}%)]"
-                f'\tLoss: {out_criterion["loss"].item():.3f} |'
-                f'\tMSE loss: {out_criterion["mse_loss"].item():.3f} |'
-                f'\tBpp loss: {out_criterion["bpp_loss"].item():.2f} |'
+                f"{(i * world_size) * len(d)}/{len(train_dataloader.dataset)}"
+                f" ({100. * i / len(train_dataloader):.0f}%)]"
+                f'\tLoss: {out_criterion["loss"].item():.4f} |'
+                f'\tMSE loss: {out_criterion["mse_loss"].item():.4f} |'
+                f'\tBpp loss: {out_criterion["bpp_loss"].item():.4f} |'
                 f"\tAux loss: {aux_loss.item():.2f}"
             )
 
 
-def test_epoch(epoch, test_dataloader, model, criterion):
+def test_epoch(epoch, test_dataloader, model, criterion, device):
     model.eval()
-    device = next(model.parameters()).device
-
     loss = AverageMeter()
     bpp_loss = AverageMeter()
     mse_loss = AverageMeter()
@@ -138,6 +143,7 @@
             loss.update(out_criterion["loss"])
             mse_loss.update(out_criterion["mse_loss"])
 
+
     print(
         f"Test epoch {epoch}: Average losses:"
         f"\tLoss: {loss.avg:.3f} |"
@@ -150,9 +156,11 @@
 
 
 def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
-    torch.save(state, filename)
+    cwd = os.getcwd()
+    filepath = os.path.join(cwd, filename)
+    torch.save(state, filepath)
     if is_best:
-        shutil.copyfile(filename, "checkpoint_best_loss.pth.tar")
+        shutil.copyfile(filepath, "checkpoint_best_loss.pth.tar")
 
 
 def parse_args(argv):
@@ -160,13 +168,14 @@
     parser.add_argument(
         "-m",
         "--model",
-        default="bmshj2018-factorized",
+        default="bmshj2018-hyperprior",
         choices=image_models.keys(),
         help="Model architecture (default: %(default)s)",
     )
     parser.add_argument(
         "-d", "--dataset", type=str, required=True, help="Training dataset"
     )
+    parser.add_argument("--local_rank", type=int, default=0, help="Local rank. Necessary for using the torch.distributed.launch utility.")
     parser.add_argument(
         "-e",
         "--epochs",
@@ -185,7 +194,7 @@
         "-n",
         "--num-workers",
         type=int,
-        default=4,
+        default=8,
         help="Dataloaders threads (default: %(default)s)",
     )
     parser.add_argument(
@@ -221,7 +230,7 @@
     parser.add_argument(
         "--save", action="store_true", default=True, help="Save model to disk"
     )
-    parser.add_argument("--seed", type=int, help="Set random seed for reproducibility")
+    parser.add_argument("--seed", type=int, default=0, help="Set random seed for reproducibility")
     parser.add_argument(
         "--clip_max_norm",
         default=1.0,
@@ -232,13 +241,33 @@
     args = parser.parse_args(argv)
     return args
 
+def set_random_seeds(random_seed=0):
+
+    torch.manual_seed(random_seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    np.random.seed(random_seed)
+    random.seed(random_seed)
 
 def main(argv):
     args = parse_args(argv)
+    world_size = int(os.environ["WORLD_SIZE"])
+
 
     if args.seed is not None:
-        torch.manual_seed(args.seed)
-        random.seed(args.seed)
+        print("setting random seed")
+        set_random_seeds(random_seed=args.seed)
+
+    torch.distributed.init_process_group(backend="nccl")
+
+    device = torch.device("cuda:{}".format(args.local_rank))
+
+    net = image_models[args.model](quality=3)
+    net = net.to(device)
+
+    if args.cuda and torch.cuda.device_count() > 1:
+        #net = CustomDataParallel(net)
+        net = DistributedDataParallelCompressionModel(net, device_ids=[args.local_rank], output_device=args.local_rank)
 
     train_transforms = transforms.Compose(
         [transforms.RandomCrop(args.patch_size), transforms.ToTensor()]
@@ -251,40 +280,38 @@
     train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms)
     test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms)
 
-    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
+    train_sampler = DistributedSampler(dataset=train_dataset)
+    test_sampler = DistributedSampler(dataset=test_dataset)
 
     train_dataloader = DataLoader(
         train_dataset,
         batch_size=args.batch_size,
         num_workers=args.num_workers,
-        shuffle=True,
-        pin_memory=(device == "cuda"),
+        sampler=train_sampler,
     )
 
     test_dataloader = DataLoader(
         test_dataset,
         batch_size=args.test_batch_size,
         num_workers=args.num_workers,
-        shuffle=False,
-        pin_memory=(device == "cuda"),
-    )
+        sampler=test_sampler
 
-    net = image_models[args.model](quality=3)
-    net = net.to(device)
+    )
 
-    if args.cuda and torch.cuda.device_count() > 1:
-        net = CustomDataParallel(net)
 
     optimizer, aux_optimizer = configure_optimizers(net, args)
-    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
+    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
     criterion = RateDistortionLoss(lmbda=args.lmbda)
 
     last_epoch = 0
-    if args.checkpoint:  # load from previous checkpoint
+
+    if args.checkpoint:  # load from previous checkpoint
+        dist.barrier()
+        map_location = {"cuda:0": "cuda:{}".format(args.local_rank)}
         print("Loading", args.checkpoint)
-        checkpoint = torch.load(args.checkpoint, map_location=device)
+        checkpoint = torch.load(args.checkpoint, map_location=map_location)
         last_epoch = checkpoint["epoch"] + 1
-        net.load_state_dict(checkpoint["state_dict"])
+        net.load_state_dict(checkpoint["state_dict"])
         optimizer.load_state_dict(checkpoint["optimizer"])
         aux_optimizer.load_state_dict(checkpoint["aux_optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
@@ -300,26 +327,31 @@
             aux_optimizer,
             epoch,
             args.clip_max_norm,
+            args.local_rank,
+            device,
+            world_size
         )
-        loss = test_epoch(epoch, test_dataloader, net, criterion)
-        lr_scheduler.step(loss)
-
-        is_best = loss < best_loss
-        best_loss = min(loss, best_loss)
-
-        if args.save:
-            save_checkpoint(
-                {
-                    "epoch": epoch,
-                    "state_dict": net.state_dict(),
-                    "loss": loss,
-                    "optimizer": optimizer.state_dict(),
-                    "aux_optimizer": aux_optimizer.state_dict(),
-                    "lr_scheduler": lr_scheduler.state_dict(),
-                },
-                is_best,
-            )
-
+
+        loss = test_epoch(epoch, test_dataloader, net, criterion, device)
+        dist.barrier()
+        dist.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
+        lr_scheduler.step(loss / world_size)
+        is_best = (loss / world_size) < best_loss
+        best_loss = min(loss / world_size, best_loss)
+        if args.local_rank == 0:
+            if args.save:
+                save_checkpoint(
+                    {
+                        "epoch": epoch,
+                        "state_dict": net.state_dict(),
+                        "loss": loss,
+                        "optimizer": optimizer.state_dict(),
+                        "aux_optimizer": aux_optimizer.state_dict(),
+                        "lr_scheduler": lr_scheduler.state_dict(),
+                    },
+                    is_best,
+                )
+
 
 if __name__ == "__main__":
     main(sys.argv[1:])
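
Launch note (not part of the diff): the script now reads WORLD_SIZE from the
environment and accepts a per-process --local_rank argument, which matches the
torch.distributed.launch utility named in the new argument's help text. A
single-node launch would look roughly like this (the GPU count and dataset
path are placeholders):

    python -m torch.distributed.launch --nproc_per_node=4 examples/train.py \
        -d /path/to/dataset --cuda --save

torch.distributed.launch exports WORLD_SIZE and passes --local_rank to each
worker; the newer torchrun entry point exposes LOCAL_RANK as an environment
variable instead, so the argparse flag would need a small adjustment there.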
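
One caveat the patch does not address: set_epoch() is never called on the new
samplers, so DistributedSampler reuses its epoch-0 shuffle order on every
epoch. The usual fix is one extra line at the top of the epoch loop in main();
a minimal sketch using the patch's own variable names (a suggestion, not part
of the diff):

    for epoch in range(last_epoch, args.epochs):
        # Seed the sampler's shuffle with the epoch number; without this,
        # DistributedSampler yields the same permutation every epoch.
        train_sampler.set_epoch(epoch)
        train_one_epoch(...)  # unchanged call from the patch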