diff --git a/args.py b/args.py
index 7ec848c..aaf7f61 100644
--- a/args.py
+++ b/args.py
@@ -19,6 +19,24 @@ def parse_arguments():
     parser.add_argument(
         "-a", "--arch", metavar="ARCH", default="ResNet18", help="model architecture"
     )
+    parser.add_argument(
+        "--num-train-examples",
+        type=int,
+        default=None,
+        help="Number of train examples to use for MultiMNIST",
+    )
+    parser.add_argument(
+        "--num-val-examples",
+        type=int,
+        default=None,
+        help="Number of val examples to use for MultiMNIST",
+    )
+    parser.add_argument(
+        "--num-concat",
+        help="Number of MNIST digits to concatenate per MultiMNIST example",
+        type=int,
+        default=None,
+    )
     parser.add_argument(
         "--config", help="Config file to use (see configs dir)", default=None
     )
@@ -244,7 +262,8 @@ def parse_arguments():
 
     args = parser.parse_args()
 
-    get_config(args)
+    if args.config is not None:
+        get_config(args)
 
     return args
 
diff --git a/configs/smallscale/multimnist/lenet5.yaml b/configs/smallscale/multimnist/lenet5.yaml
new file mode 100644
index 0000000..67deabd
--- /dev/null
+++ b/configs/smallscale/multimnist/lenet5.yaml
@@ -0,0 +1,35 @@
+# Architecture
+arch: LeNet5
+
+# ===== Dataset ===== #
+data: /usr/data
+set: MultiMNIST
+name: baseline
+num_train_examples: 5000000
+num_val_examples: 50000
+num_concat: 5
+num_classes: 100000
+
+
+# ===== Learning Rate Policy ======== #
+optimizer: sgd
+lr: 0.1
+lr_policy: cosine_lr
+warmup_length: 5
+
+# ===== Network training config ===== #
+epochs: 100
+weight_decay: 0.0001
+momentum: 0.9
+batch_size: 256
+
+
+# ===== Sparsity =========== #
+conv_type: DenseConv
+bn_type: LearnedBatchNorm
+init: kaiming_normal
+mode: fan_in
+nonlinearity: relu
+
+# ===== Hardware setup ===== #
+workers: 12
diff --git a/data/__init__.py b/data/__init__.py
index a011a41..70602e9 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -1,2 +1,3 @@
 from data.imagenet import ImageNet
-from data.imagenet import TinyImageNet
\ No newline at end of file
+from data.imagenet import TinyImageNet
+from data.mnist import MultiMNIST
\ No newline at end of file
diff --git a/data/mnist.py b/data/mnist.py
new file mode 100644
index 0000000..cdced19
--- /dev/null
+++ b/data/mnist.py
@@ -0,0 +1,102 @@
+from typing import Any, Callable, Optional, Tuple
+from torchvision import datasets, transforms
+from PIL import Image
+from args import args
+import os
+import torch
+import torchvision
+import numpy as np
+
+
+class MultiMNIST:
+    def __init__(self, args):
+        super(MultiMNIST, self).__init__()
+
+        data_root = os.path.join(args.data, "mnist")
+
+        use_cuda = torch.cuda.is_available()
+
+        # Data loading code
+        kwargs = {"num_workers": args.workers, "pin_memory": True} if use_cuda else {}
+        self.train_loader = torch.utils.data.DataLoader(
+            MultiMNISTDataset(
+                data_root,
+                train=True,
+                download=True,
+                num_concat=args.num_concat,
+                transform=transforms.Compose(
+                    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+                ),
+            ),
+            batch_size=args.batch_size,
+            shuffle=True,
+            **kwargs
+        )
+        self.val_loader = torch.utils.data.DataLoader(
+            MultiMNISTDataset(
+                data_root,
+                num_concat=args.num_concat,
+                train=False,
+                transform=transforms.Compose(
+                    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+                ),
+            ),
+            batch_size=args.batch_size,
+            shuffle=True,
+            **kwargs
+        )
+
+
+class MultiMNISTDataset(datasets.MNIST):
+    def __init__(
+        self,
+        root: str,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+        num_concat: int = 1,
+    ) -> None:
+        super().__init__(
+            root,
+            train=train,
+            transform=transform,
+            target_transform=target_transform,
+            download=download,
+        )
+
+        self.length = int(super().__len__() ** num_concat)
+        if self.train:
+            self.length = args.num_train_examples or self.length
+        else:
+            self.length = args.num_val_examples or self.length
+
+        self.num_concat = num_concat
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        # Pick num_concat random MNIST digits, deterministically derived from the index
+        if self.train:
+            rng = np.random.RandomState(index * 2)
+        else:
+            rng = np.random.RandomState(index * 2 + 1)
+
+
+        indices = rng.randint(0, super().__len__(), (self.num_concat,))
+        img, target = self.data[indices], self.targets[indices]
+        base = 10 ** torch.arange(self.num_concat - 1, -1, -1)
+
+        img = torch.cat([img[i] for i in range(self.num_concat)], dim=-1)
+        target = (base * target).sum()
+
+        img = Image.fromarray(img.numpy(), mode='L')
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
diff --git a/models/__init__.py b/models/__init__.py
index aea32d0..c5b4116 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1,5 +1,6 @@
 from models.resnet import ResNet18, ResNet50
 from models.mobilenetv1 import MobileNetV1
+from models.lenet import LeNet5
 
 __all__ = [
     "ResNet18",
diff --git a/models/lenet.py b/models/lenet.py
new file mode 100644
index 0000000..5ee80c9
--- /dev/null
+++ b/models/lenet.py
@@ -0,0 +1,39 @@
+"""
+LeNet-5 implementation from https://github.com/ChawDoe/LeNet5-MNIST-PyTorch/blob/master/model.py
+"""
+
+from args import args
+from torch.nn import Module
+from torch import nn
+
+
+class LeNet5(Module):
+    def __init__(self):
+        super(LeNet5, self).__init__()
+
+        self.conv1 = nn.Conv2d(1, 6, 5)
+        self.relu1 = nn.ReLU()
+        self.pool1 = nn.MaxPool2d(2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.relu2 = nn.ReLU()
+        self.pool2 = nn.MaxPool2d(2)
+        self.fc1 = nn.Linear(256 + 448 * (args.num_concat - 1), 120)
+        self.relu3 = nn.ReLU()
+        self.fc2 = nn.Linear(120, 84)
+        self.relu4 = nn.ReLU()
+        self.fc3 = nn.Linear(84, args.num_classes)
+
+    def forward(self, x):
+        y = self.conv1(x)
+        y = self.relu1(y)
+        y = self.pool1(y)
+        y = self.conv2(y)
+        y = self.relu2(y)
+        y = self.pool2(y)
+        y = y.view(y.shape[0], -1)
+        y = self.fc1(y)
+        y = self.relu3(y)
+        y = self.fc2(y)
+        y = self.relu4(y)
+        y = self.fc3(y)
+        return y
\ No newline at end of file
diff --git a/tests/multimnist.py b/tests/multimnist.py
new file mode 100644
index 0000000..bffdff6
--- /dev/null
+++ b/tests/multimnist.py
@@ -0,0 +1,31 @@
+import sys, os
+sys.path.append(os.path.abspath('.'))
+
+from args import args
+from data import MultiMNIST
+from collections import defaultdict
+
+import seaborn as sns
+import matplotlib.pyplot as plt
+import tqdm
+
+args.data = "/usr/data"
+args.num_train_examples = 200000
+args.num_val_examples = 50000
+args.num_concat = 4
+args.workers = 16
+
+mnist = MultiMNIST(args)
+
+label_counts = defaultdict(int)
+for i in tqdm.tqdm(range(len(mnist.train_loader.dataset)), ascii=True):
+    _, label = mnist.train_loader.dataset[i]
+    label_counts[label.item()] += 1
+
+
+fig, (ax1, ax2) = plt.subplots(2)
+
+sns.kdeplot(list(label_counts.values()), ax=ax1)
+
+plt.plot(*zip(*sorted(label_counts.items())))
+plt.savefig("tests/images/multimnist.pdf", bbox_inches="tight")
diff --git a/trainer.py b/trainer.py
index 0e4a1da..d541e9d 100644
--- a/trainer.py
+++ b/trainer.py
@@ -43,8 +43,10 @@ def train(train_loader, model, criterion, optimizer, epoch, args, writer):
 
         loss = criterion(output, target.view(-1))
 
+        # measure accuracy and record loss
         acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
         losses.update(loss.item(), images.size(0))
         top1.update(acc1.item(), images.size(0))
         top5.update(acc5.item(), images.size(0))
 
diff --git a/utils/eval_utils.py b/utils/eval_utils.py
index 45345c3..1147e6f 100644
--- a/utils/eval_utils.py
+++ b/utils/eval_utils.py
@@ -13,6 +13,6 @@ def accuracy(output, target, topk=(1,)):
 
     res = []
     for k in topk:
-        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
         res.append(correct_k.mul_(100.0 / batch_size))
     return res
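
For reference, a small standalone sanity check (an illustrative sketch, not part of the patch) of the two pieces of arithmetic the new code relies on: the base-10 label encoding in MultiMNISTDataset.__getitem__ and the fc1 input size 256 + 448 * (num_concat - 1) in models/lenet.py, assuming 28x28 MNIST digits concatenated along the width:

    import torch

    # Label encoding: concatenated digits read as one base-10 number, e.g. [3, 1, 4] -> 314.
    target = torch.tensor([3, 1, 4])
    base = 10 ** torch.arange(len(target) - 1, -1, -1)  # tensor([100, 10, 1])
    assert (base * target).sum().item() == 314

    # fc1 sizing: two (5x5 conv -> 2x2 max-pool) stages shrink a 28 x (28*n) input to
    # 16 feature maps of size 4 x (7*n - 3), i.e. 256 + 448 * (n - 1) flattened features.
    for n in range(1, 6):
        w = 28 * n          # width of n concatenated digits
        w = (w - 4) // 2    # conv1 (5x5, no padding) then pool1 (2x2)
        w = (w - 4) // 2    # conv2 (5x5, no padding) then pool2 (2x2)
        assert 16 * 4 * w == 256 + 448 * (n - 1)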