diff --git a/docs/examples/pytorch/imagenet_training/README.md b/docs/examples/pytorch/imagenet_training/README.md
index 251d51355..1e889863e 100644
--- a/docs/examples/pytorch/imagenet_training/README.md
+++ b/docs/examples/pytorch/imagenet_training/README.md
@@ -63,6 +63,14 @@
 Node 1:
 python imagenet_training.py -a resnet50 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' --dist-backend 'nccl' --multiprocessing-distributed --world-size 2 --rank 1 [imagenet-folder with train and val folders]
 ```
+## Calculating dataloader performance
+
+You can use the training example to measure the performance of the torch and rocAL dataloaders. Passing `--calculate-ips` reports the dataloader throughput (in images per second) by running the dataloader, without training, for 3 epochs. Choose a rocAL dataloader by passing `--rocal-cpu` or `--rocal-gpu`.
+
+```shell
+python imagenet_training.py --dist-url 'tcp://127.0.0.1:FREEPORT' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 [imagenet-folder with train and val folders] --calculate-ips
+```
+
 ## Usage
 
 ```bash
diff --git a/docs/examples/pytorch/imagenet_training/imagenet_training.py b/docs/examples/pytorch/imagenet_training/imagenet_training.py
index 4fe15a03e..428073177 100644
--- a/docs/examples/pytorch/imagenet_training/imagenet_training.py
+++ b/docs/examples/pytorch/imagenet_training/imagenet_training.py
@@ -2,6 +2,7 @@
 import os
 import random
 import shutil
+import statistics
 import time
 import warnings
 from enum import Enum
@@ -90,6 +91,8 @@
                     'multi node data parallel training')
 parser.add_argument('--dummy', action='store_true',
                     help="use fake data to benchmark")
+parser.add_argument('--calculate-ips', action='store_true',
+                    help="calculate images per second for the chosen dataloader")
 
 best_acc1 = 0
 
@@ -215,6 +218,13 @@
     def __next__(self):
         return images, targets
 
+def calc_ips(batch_size, elapsed):
+    world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
+    total_batch_size = world_size * batch_size
+    assert elapsed != 0, "elapsed time must be a non-zero value"
+    return total_batch_size / elapsed
+
+
 def main():
     args = parser.parse_args()
 
@@ -332,6 +342,80 @@ def main_worker(gpu, ngpus_per_node, args):
         device = torch.device("mps")
     else:
         device = torch.device("cpu")
+
+    local_rank = 0
+    world_size = 1
+    crop_size = 224
+
+    if args.distributed or args.gpu:
+        local_rank = args.rank if args.distributed else args.gpu
+        if args.world_size != -1:
+            world_size = args.world_size
+        if local_rank is None:
+            local_rank = 0
+
+    if args.rocal_gpu or args.rocal_cpu:
+        # Create the rocAL training dataloader
+        get_train_loader = get_rocal_train_loader
+        train_loader = get_train_loader(data_path=args.data, batch_size=args.batch_size, local_rank=local_rank, world_size=world_size,
+                                        num_thread=args.workers, crop=crop_size, rocal_cpu=False if args.rocal_gpu else True, fp16=False)
+    else:
+        traindir = os.path.join(args.data, 'train')
+        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                         std=[0.229, 0.224, 0.225])
+
+        train_dataset = datasets.ImageFolder(
+            traindir,
+            transforms.Compose([
+                transforms.RandomResizedCrop(224),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                normalize,
+            ]))
+
+        if args.distributed:
+            train_sampler = torch.utils.data.distributed.DistributedSampler(
+                train_dataset)
+        else:
+            train_sampler = None
+
+        train_loader = torch.utils.data.DataLoader(
+            train_dataset, batch_size=args.batch_size, shuffle=(
+                train_sampler is None),
+            num_workers=args.workers, pin_memory=True, sampler=train_sampler)
+
+    if args.calculate_ips:
+
+        avg_ips_values = []
+        total_data_time = 0.0
+        # Average images per second and dataloader time over 3 epochs
+        for epoch in range(3):
+            avg_ips = AverageMeter('data_ips', ':6.3f')
+            progress = ProgressMeter(
+                len(train_loader),
+                [avg_ips],
+                prefix="Epoch: [{}]".format(epoch))
+            end = time.time()
+            for i, (images, target) in enumerate(train_loader):
+                # measure data loading time
+                bs = images.size(0)
+                data_time = time.time() - end
+                avg_ips.update(calc_ips(bs, data_time))
+                total_data_time += data_time
+                if i % 100 == 0:
+                    if local_rank == 0:
+                        progress.display(i + 1)
+                end = time.time()
+
+            if args.rocal_gpu or args.rocal_cpu:
+                train_loader.reset()
+            avg_ips_values.append(avg_ips.avg)
+        if local_rank == 0:
+            print("Average dataloader time per epoch: ", total_data_time / 3)
+            print("Average images per second: ",
+                  statistics.mean(avg_ips_values))
+        return
+
     # define loss function (criterion), optimizer, and learning rate scheduler
     criterion = nn.CrossEntropyLoss().to(device)
 
@@ -373,38 +457,12 @@ def main_worker(gpu, ngpus_per_node, args):
         val_dataset = datasets.FakeData(
             50000, (3, 224, 224), 1000, transforms.ToTensor())
     if args.rocal_gpu or args.rocal_cpu:
-        get_train_loader = get_rocal_train_loader
-        get_val_loader = get_rocal_val_loader
-        local_rank = 0
-        world_size = 1
-
-        crop_size = 224
-        if args.distributed or args.gpu:
-            local_rank = args.rank if args.distributed else args.gpu
-            if args.world_size != -1:
-                world_size = args.world_size
-            if local_rank == None:
-                local_rank = 0
-        train_loader = get_train_loader(data_path=args.data, batch_size=args.batch_size, local_rank=local_rank, world_size=world_size,
-                                        num_thread=args.workers, crop=crop_size, rocal_cpu=False if args.rocal_gpu else True, fp16=False)
+        get_val_loader = get_rocal_val_loader
         val_loader = get_val_loader(data_path=args.data, batch_size=args.batch_size, local_rank=local_rank, world_size=world_size,
                                     num_thread=args.workers, crop=crop_size, rocal_cpu=False if args.rocal_gpu else True, fp16=False)
     else:
-        traindir = os.path.join(args.data, 'train')
         valdir = os.path.join(args.data, 'val')
-        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                         std=[0.229, 0.224, 0.225])
-
-        train_dataset = datasets.ImageFolder(
-            traindir,
-            transforms.Compose([
-                transforms.RandomResizedCrop(224),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                normalize,
-            ]))
-
         val_dataset = datasets.ImageFolder(
             valdir,
             transforms.Compose([
@@ -415,19 +473,11 @@ def main_worker(gpu, ngpus_per_node, args):
             ]))
 
         if args.distributed:
-            train_sampler = torch.utils.data.distributed.DistributedSampler(
-                train_dataset)
             val_sampler = torch.utils.data.distributed.DistributedSampler(
                 val_dataset, shuffle=False, drop_last=True)
         else:
-            train_sampler = None
             val_sampler = None
 
-        train_loader = torch.utils.data.DataLoader(
-            train_dataset, batch_size=args.batch_size, shuffle=(
-                train_sampler is None),
-            num_workers=args.workers, pin_memory=True, sampler=train_sampler)
-
         val_loader = torch.utils.data.DataLoader(
             val_dataset, batch_size=args.batch_size, shuffle=False,
             num_workers=args.workers, pin_memory=True, sampler=val_sampler)
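
The throughput computed by `calc_ips` is the effective global batch divided by wall-clock time: each of the `world_size` ranks loads `batch_size` images per step, so one step delivers `world_size * batch_size` images. A minimal standalone sketch of that arithmetic (the name `images_per_second` is illustrative, and `world_size` is passed explicitly instead of being queried from `torch.distributed`):

```python
def images_per_second(batch_size: int, elapsed: float, world_size: int = 1) -> float:
    """Dataloader throughput: effective global batch / wall-clock seconds."""
    if elapsed <= 0:
        raise ValueError("elapsed must be a positive number of seconds")
    # Every rank loads batch_size images per step, so the effective
    # batch is world_size * batch_size.
    return (world_size * batch_size) / elapsed

# Example: 8 ranks, batch size 256, 0.5 s per step -> 4096 images/s.
print(images_per_second(256, 0.5, world_size=8))
```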
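The `--calculate-ips` path follows the usual dataloader-benchmarking pattern: iterate the loader with no forward or backward pass, time each batch, and average over the run. A self-contained sketch of that pattern, using `torchvision.datasets.FakeData` in place of ImageNet so it runs without a dataset on disk (single process, so the world size is 1):

```python
import statistics
import time

import torch
from torchvision import datasets, transforms

# FakeData stands in for the ImageNet folder used in the example.
loader = torch.utils.data.DataLoader(
    datasets.FakeData(1024, (3, 224, 224), 1000, transforms.ToTensor()),
    batch_size=64, num_workers=2)

per_batch_ips = []
end = time.time()
for images, _ in loader:
    elapsed = time.time() - end  # time spent waiting on the dataloader
    per_batch_ips.append(images.size(0) / elapsed)
    end = time.time()

print("Average images per second:", statistics.mean(per_batch_ips))
```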
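Per-batch IPS values are accumulated with the example's existing `AverageMeter`, whose `avg` is a count-weighted running mean (as in the upstream PyTorch ImageNet example this script is based on). A stripped-down sketch of that bookkeeping, with `RunningMean` as a hypothetical stand-in:

```python
class RunningMean:
    """Count-weighted running mean, like the example's AverageMeter."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value: float, n: int = 1):
        # Accumulate n observations of `value`.
        self.total += value * n
        self.count += n

    @property
    def avg(self) -> float:
        return self.total / self.count if self.count else 0.0
```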