
Commit

Fix bugs found from more testing
cedrickchee committed Nov 5, 2017
1 parent e49e2f4 commit bfe137c
Showing 5 changed files with 38 additions and 32 deletions.
4 changes: 3 additions & 1 deletion README.md
```diff
@@ -54,11 +54,13 @@ $ python main.py
 | Training batch size | 128 | --batch-size 128 |
 | Testing batch size | 128 | --test-batch-size 128 |
 | Loss threshold | 0.001 | --loss-threshold 0.001 |
-| Log interval | 1 | --log-interval 1 |
+| Log interval | 10 | --log-interval 10 |
 | Disables CUDA training | false | --no-cuda |
 | Num. of convolutional channel | 256 | --num-conv-channel 256 |
 | Num. of primary unit | 8 | --num-primary-unit 8 |
 | Primary unit size | 1152 | --primary-unit-size 1152 |
 | Output unit size | 16 | --output-unit-size 16 |
+| Num. routing iteration | 3 | --num-routing 3 |
 
 ## Results
 Coming soon!
```
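Read together with the table, a typical invocation that exercises the changed options would be (the values shown are just the defaults, spelled out for clarity):

```
$ python main.py --log-interval 10 --num-routing 3
```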
14 changes: 9 additions & 5 deletions capsule_layer.py
```diff
@@ -17,14 +17,17 @@ class CapsuleLayer(nn.Module):
     """
     The core implementation of the idea of capsules
     """
-    def __init__(self, in_unit, in_channel, num_unit, unit_size, use_routing, cuda):
+
+    def __init__(self, in_unit, in_channel, num_unit, unit_size, use_routing,
+                 num_routing, cuda_enabled):
         super(CapsuleLayer, self).__init__()
 
         self.in_unit = in_unit
         self.in_channel = in_channel
         self.num_unit = num_unit
         self.use_routing = use_routing
-        self.cuda = cuda
+        self.num_routing = num_routing
+        self.cuda_enabled = cuda_enabled
 
         if self.use_routing:
             """
@@ -50,7 +53,8 @@ def create_conv_unit(idx):
                 self.add_module("conv_unit" + str(idx), unit)
                 return unit
 
-            self.conv_units = [create_conv_unit(u) for u in range(self.num_unit)]
+            self.conv_units = [create_conv_unit(
+                u) for u in range(self.num_unit)]
 
     @staticmethod
     def squash(sj):
@@ -88,12 +92,12 @@ def routing(self, x):
         # All the routing logits (b_ij in the paper) are initialized to zero.
         b_ij = Variable(torch.zeros(
             1, self.in_channel, self.num_unit, 1))
-        if self.cuda:
+        if self.cuda_enabled:
             b_ij = b_ij.cuda()
 
         # From the paper in the "Capsules on MNIST" section,
         # the sample MNIST test reconstructions of a CapsNet with 3 routing iterations.
-        num_iterations = 3
+        num_iterations = self.num_routing
 
         for iteration in range(num_iterations):
             # Routing algorithm
```
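For context on what this hunk wires up: `num_iterations` is the loop count of the dynamic routing algorithm from Sabour et al., "Dynamic Routing Between Capsules" (2017), and the commit makes it configurable as `self.num_routing`. Below is a minimal, self-contained sketch of that loop; the standalone `squash` and `dynamic_routing` helpers and the tensor layout are illustrative assumptions, not the repo's exact code:

```python
import torch
import torch.nn.functional as F

def squash(s, dim=-1):
    # Eq. 1 of the paper: scale by ||s||^2 / (1 + ||s||^2), then normalize,
    # so short vectors shrink toward 0 and long ones approach unit length.
    mag_sq = (s ** 2).sum(dim=dim, keepdim=True)
    return (mag_sq / (1.0 + mag_sq)) * (s / (mag_sq.sqrt() + 1e-8))

def dynamic_routing(u_hat, num_routing=3):
    # u_hat: [batch, num_in_caps, num_out_caps, out_dim] prediction vectors.
    # Routing logits b_ij start at zero, as in the hunk above.
    b_ij = torch.zeros(u_hat.size(0), u_hat.size(1), u_hat.size(2), 1)
    v_j = None
    for _ in range(num_routing):
        c_ij = F.softmax(b_ij, dim=2)    # coupling coefficients, sum to 1 over output caps
        s_j = (c_ij * u_hat).sum(dim=1)  # weighted sum over input capsules
        v_j = squash(s_j)                # [batch, num_out_caps, out_dim]
        # Agreement between predictions and current outputs raises the logits.
        b_ij = b_ij + (u_hat * v_j.unsqueeze(1)).sum(dim=-1, keepdim=True)
    return v_j

# Toy shapes matching the defaults: 1152 primary capsules -> 10 digit capsules of dim 16.
print(dynamic_routing(torch.randn(2, 1152, 10, 16)).shape)  # torch.Size([2, 10, 16])
```

Promoting the hard-coded 3 to a constructor argument is what lets the new `--num-routing` flag reach this loop.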
34 changes: 15 additions & 19 deletions main.py
```diff
@@ -13,8 +13,6 @@
 
 from __future__ import print_function
 import argparse
-import sys
-import time
 
 import torch
 import torch.optim as optim
@@ -57,8 +55,7 @@ def train(model, data_loader, optimizer, epoch):
         optimizer.step()
 
         if batch_idx % args.log_interval == 0:
-            mesg = '{}\tEpoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                time.ctime(),
+            mesg = 'Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                 epoch,
                 batch_idx * len(data),
                 len(data_loader.dataset),
@@ -87,7 +84,7 @@ def test(model, data_loader):
     for data, target in data_loader:
         target_indices = target
         target_one_hot = utils.one_hot_encode(
-            target_indices, length=model.digits.num_units)
+            target_indices, length=model.digits.num_unit)
 
         data, target = Variable(data, volatile=True), Variable(target_one_hot)
 
@@ -133,12 +130,12 @@ def main():
                         default=128, help='testing batch size. default=128')
     parser.add_argument('--loss-threshold', type=float, default=0.0001,
                         help='stop training if loss goes below this threshold. default=0.0001')
-    parser.add_argument("--log-interval", type=int, default=1,
-                        help='number of images after which the training loss is logged, default is 1')
-    parser.add_argument('--cuda', action='store_true',
-                        help='set it to 1 for running on GPU, 0 for CPU')
+    parser.add_argument('--log-interval', type=int, default=10,
+                        help='how many batches to wait before logging training status, default=10')
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training, default=false')
     parser.add_argument('--threads', type=int, default=4,
-                        help='number of threads for data loader to use')
+                        help='number of threads for data loader to use, default=4')
     parser.add_argument('--seed', type=int, default=42,
                         help='random seed for training. default=42')
     parser.add_argument('--num-conv-channel', type=int, default=256,
@@ -149,20 +146,18 @@
                         default=1152, help='primary unit size. default=1152')
     parser.add_argument('--output-unit-size', type=int,
                         default=16, help='output unit size. default=16')
+    parser.add_argument('--num-routing', type=int,
+                        default=3, help='number of routing iteration. default=3')
 
     args = parser.parse_args()
 
     print(args)
 
-    # Check GPU or CUDA is available
-    cuda = args.cuda
-    if cuda and not torch.cuda.is_available():
-        print(
-            "ERROR: No GPU/cuda is not available. Try running on CPU or run without --cuda")
-        sys.exit(1)
+    args.cuda = not args.no_cuda and torch.cuda.is_available()
 
     torch.manual_seed(args.seed)
-    if cuda:
+    if args.cuda:
         torch.cuda.manual_seed(args.seed)
 
     # Load data
@@ -174,10 +169,11 @@
                 num_primary_unit=args.num_primary_unit,
                 primary_unit_size=args.primary_unit_size,
                 output_unit_size=args.output_unit_size,
-                cuda=args.cuda)
+                num_routing=args.num_routing,
+                cuda_enabled=args.cuda)
 
-    if cuda:
-        model = model.cuda()
+    if args.cuda:
+        model.cuda()
 
     optimizer = optim.Adam(model.parameters(), lr=args.lr)
 
```
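The deleted `sys.exit(1)` branch is replaced by the opt-out pattern used in the official PyTorch MNIST example: derive one boolean from the flag and the hardware, and fall back to CPU silently instead of refusing to run. Reduced to its essentials (the literal 42 stands in for `args.seed`):

```python
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training, default=false')
args = parser.parse_args()

# True only when the user did not opt out AND a GPU is actually present,
# so no error path is needed.
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(42)
if args.cuda:
    # Seed the CUDA RNG too, so GPU runs are as reproducible as CPU runs.
    torch.cuda.manual_seed(42)
```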
14 changes: 9 additions & 5 deletions model.py
```diff
@@ -19,14 +19,16 @@ class Net(nn.Module):
     """
     A simple CapsNet with 3 layers
     """
-    def __init__(self, num_conv_channel, num_primary_unit, primary_unit_size, output_unit_size, cuda):
+
+    def __init__(self, num_conv_channel, num_primary_unit, primary_unit_size,
+                 output_unit_size, num_routing, cuda_enabled):
         """
         In the constructor we instantiate one ConvLayer module and two CapsuleLayer modules
         and assign them as member variables.
         """
         super(Net, self).__init__()
 
-        self.cuda = cuda
+        self.cuda_enabled = cuda_enabled
 
         self.conv1 = ConvLayer(in_channel=1,
                                out_channel=num_conv_channel,
@@ -38,15 +40,17 @@ def __init__(self, num_conv_channel, num_primary_unit, primary_unit_size, output
                                    num_unit=num_primary_unit,
                                    unit_size=primary_unit_size,
                                    use_routing=False,
-                                   cuda=cuda)
+                                   num_routing=num_routing,
+                                   cuda_enabled=cuda_enabled)
 
         # DigitCaps
         self.digits = CapsuleLayer(in_unit=num_primary_unit,
                                    in_channel=primary_unit_size,
                                    num_unit=10,
                                    unit_size=output_unit_size,
                                    use_routing=True,
-                                   cuda=cuda)
+                                   num_routing=num_routing,
+                                   cuda_enabled=cuda_enabled)
 
     def forward(self, x):
         """
@@ -74,7 +78,7 @@ def margin_loss(self, input, target, size_average=True):
 
         # Calculate left and right max() terms.
         zero = Variable(torch.zeros(1))
-        if self.cuda:
+        if self.cuda_enabled:
             zero = zero.cuda()
         m_plus = 0.9
         m_minus = 0.1
```
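The `m_plus` and `m_minus` constants visible above belong to the margin loss of the paper, L_k = T_k · max(0, m⁺ − ‖v_k‖)² + λ(1 − T_k) · max(0, ‖v_k‖ − m⁻)². A minimal sketch, assuming the capsule lengths ‖v_k‖ are already computed and using the paper's λ = 0.5; the repo's exact reduction over the batch may differ:

```python
import torch

def margin_loss(v_norm, target_one_hot, m_plus=0.9, m_minus=0.1, lam=0.5):
    # v_norm: [batch, num_classes] capsule lengths ||v_k||.
    # target_one_hot: [batch, num_classes] one-hot labels (T_k in the paper).
    zero = torch.zeros(1)
    left = torch.max(zero, m_plus - v_norm) ** 2    # penalize a short capsule for the true class
    right = torch.max(zero, v_norm - m_minus) ** 2  # penalize long capsules for absent classes
    per_class = target_one_hot * left + lam * (1.0 - target_one_hot) * right
    return per_class.sum(dim=1).mean()

print(margin_loss(torch.rand(4, 10), torch.eye(10)[:4]))
```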
4 changes: 2 additions & 2 deletions utils.py
```diff
@@ -46,13 +46,13 @@ def load_mnist(args):
 
     print('===> Loading training datasets')
     training_set = datasets.MNIST(
-        '../data', train=True, download=True, transform=data_transform)
+        './data', train=True, download=True, transform=data_transform)
     training_data_loader = DataLoader(
         training_set, batch_size=args.batch_size, shuffle=True, **kwargs)
 
     print('===> Loading testing datasets')
     testing_set = datasets.MNIST(
-        '../data', train=False, download=True, transform=data_transform)
+        './data', train=False, download=True, transform=data_transform)
     testing_data_loader = DataLoader(
         testing_set, batch_size=args.test_batch_size, shuffle=True, **kwargs)
 
```
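`utils.one_hot_encode` itself is outside this diff; the main.py hunk above only fixes its `length` argument to match the attribute's actual name (`num_unit`, not `num_units`). One plausible implementation of such a helper, assuming integer class labels, looks like this:

```python
import torch

def one_hot_encode(target_indices, length):
    # target_indices: LongTensor of shape [batch] holding class labels.
    one_hot = torch.zeros(target_indices.size(0), length)
    one_hot.scatter_(1, target_indices.unsqueeze(1), 1.0)  # put a 1 in each label's column
    return one_hot

print(one_hot_encode(torch.tensor([3, 0]), length=10))  # rows peak at columns 3 and 0
```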
