Skip to content

Commit

Permalink
correct FieldDropout. Fix #14
Browse files Browse the repository at this point in the history
- fix parameter to buffer
- fix dropout from pointwise to whole field
  • Loading branch information
Gabri95 committed Aug 25, 2020
2 parents 2d42e0b + 93840bb commit b753697
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 5 deletions.
40 changes: 35 additions & 5 deletions e2cnn/nn/modules/dropout/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,34 @@
from torch.nn import Parameter
from typing import List, Tuple, Any


__all__ = ["FieldDropout"]


def dropout_field(input: torch.Tensor, p: float, training: bool, inplace: bool):

if training:
shape = list(input.size())
shape[2] = 1

if input.device == torch.device('cpu'):
mask = torch.FloatTensor(*shape)
else:
device = input.device
mask = torch.cuda.FloatTensor(*shape, device=device)

mask = mask.uniform_() > p
mask = mask.to(torch.float)

if inplace:
input *= mask / (1. - p)
return input
else:
return input * mask / (1. - p)
else:
return input


class FieldDropout(EquivariantModule):

def __init__(self,
Expand Down Expand Up @@ -89,7 +114,7 @@ def __init__(self,
_indices[s] = torch.LongTensor(_indices[s])

# register the indices tensors as parameters of this module
self.register_parameter('indices_{}'.format(s), _indices[s])
self.register_buffer('indices_{}'.format(s), _indices[s])

self._order = list(self._contiguous.keys())

Expand Down Expand Up @@ -118,16 +143,21 @@ def forward(self, input: GeometricTensor) -> GeometricTensor:
for s in self._order:

indices = getattr(self, f"indices_{s}")

shape = input.shape[:1] + (self._nfields[s], s) + input.shape[2:]

if self._contiguous[s]:
# if the fields are contiguous, we can use slicing
out = F.dropout(input[:, indices[0]:indices[1], ...], self.p, self.training, self.inplace)
out = dropout_field(input[:, indices[0]:indices[1], ...].view(shape), self.p, self.training, self.inplace)
if not self.inplace:
output[:, indices[0]:indices[1], ...] = out
shape = input.shape[:1] + (self._nfields[s] * s, ) + input.shape[2:]
output[:, indices[0]:indices[1], ...] = out.view(shape)
else:
# otherwise we have to use indexing
out = F.dropout(input[:, indices[0], ...], self.p, self.training, self.inplace)
out = dropout_field(input[:, indices, ...].view(shape), self.p, self.training, self.inplace)
if not self.inplace:
output[:, indices[0], ...] = out
shape = input.shape[:1] + (self._nfields[s] * s, ) + input.shape[2:]
output[:, indices, ...] = out.view(shape)

if self.inplace:
output = input
Expand Down
126 changes: 126 additions & 0 deletions test/nn/test_dropout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import unittest
from unittest import TestCase

from e2cnn.nn import *
from e2cnn.gspaces import *

import torch
import torch.nn.functional as F
import numpy as np

import random


class TestDropout(TestCase):

def test_pointwise_do_unsorted_inplace(self):
N = 8
g = FlipRot2dOnR2(N)

r = FieldType(g, [r for r in g.representations.values() if 'pointwise' in r.supported_nonlinearities]*3)

do = PointwiseDropout(r, inplace=True)

self.check_do(do)

def test_pointwise_do_unsorted(self):
N = 8
g = FlipRot2dOnR2(N)

r = FieldType(g, [r for r in g.representations.values() if 'pointwise' in r.supported_nonlinearities]*3)

do = PointwiseDropout(r)

self.check_do(do)

def test_pointwise_do_sorted_inplace(self):
N = 8
g = FlipRot2dOnR2(N)

r = FieldType(g, [r for r in g.representations.values() if 'pointwise' in r.supported_nonlinearities]*3).sorted()

do = PointwiseDropout(r, inplace=True)

self.check_do(do)

def test_pointwise_do_sorted(self):
N = 8
g = FlipRot2dOnR2(N)

r = FieldType(g, [r for r in g.representations.values() if 'pointwise' in r.supported_nonlinearities]*3).sorted()

do = PointwiseDropout(r)

self.check_do(do)

def test_field_do_sorted(self):
N = 8
g = FlipRot2dOnR2(N)

r = FieldType(g, list(g.representations.values())*3).sorted()

bn = FieldDropout(r)

self.check_do(bn)

def test_field_do_unsorted(self):
N = 8
g = FlipRot2dOnR2(N)

r = FieldType(g, list(g.representations.values())*3)

bn = FieldDropout(r)

self.check_do(bn)

def test_field_do_sorted_inplace(self):
N = 8
g = FlipRot2dOnR2(N)

r = FieldType(g, list(g.representations.values())*3).sorted()

bn = FieldDropout(r, inplace=True)

self.check_do(bn)

def test_field_do_unsorted_inplace(self):
N = 8
g = FlipRot2dOnR2(N)

r = FieldType(g, list(g.representations.values())*3)

bn = FieldDropout(r, inplace=True)

self.check_do(bn)

def check_do(self, do: EquivariantModule):

x = 5 * torch.randn(3000, do.in_type.size, 20, 20) + 10
x = torch.abs(x)
x1 = x
x2 = x.clone()
x1 = GeometricTensor(x1, do.in_type)
x2 = GeometricTensor(x2, do.in_type)

do.train()

y1 = do(x1)

do.eval()

y2 = do(x2)

y1 = y1.tensor.permute(1, 0, 2, 3).reshape(do.in_type.size, -1)
y2 = y2.tensor.permute(1, 0, 2, 3).reshape(do.in_type.size, -1)

m1 = y1.mean(1)
m2 = y2.mean(1)

# print(m1)
# print(m2)

self.assertTrue(torch.allclose(m1, m2, rtol=5e-3, atol=5e-3))


if __name__ == '__main__':
unittest.main()

0 comments on commit b753697

Please sign in to comment.