caching He init variances, move input to CPU in FieldType.transform()
Gabri95 committed Jun 30, 2021
2 parents 1145b55 + 0847b72 commit 8d9270d
Showing 6 changed files with 191 additions and 20 deletions.
14 changes: 10 additions & 4 deletions e2cnn/nn/field_type.py
@@ -210,8 +210,9 @@ def transform(self, input: torch.Tensor, element) -> torch.Tensor:
and its (induced) action on the base space.
.. warning ::
The input tensor is detached before the transformation, therefore no gradient is propagated back
through this operation.
This method is internally implemented using ```numpy```.
This means that the input tensor is detached (and moved to CPU) before the transformation, therefore no
gradient is propagated back through this operation.
.. seealso ::
@@ -229,9 +230,14 @@ def transform(self, input: torch.Tensor, element) -> torch.Tensor:
transformed tensor
"""
transformed = self.gspace.featurefield_action(input.detach().numpy(), self.representation, element)
if input.is_cuda:
import warnings
warnings.warn('The input tensor is on GPU. The `FieldType.transform()` operation is based on `numpy` and,'
' therefore, must temporarily move the tensor on CPU. This can cause performance issues.')

transformed = self.gspace.featurefield_action(input.detach().cpu().numpy(), self.representation, element)
transformed = np.ascontiguousarray(transformed)
return torch.from_numpy(transformed.astype(np.float32))
return torch.from_numpy(transformed.astype(np.float32)).to(device=input.device)

def restrict(self, id) -> 'FieldType':
r"""
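For context, a minimal sketch of how the updated transform() is expected to behave from the caller's side; the gspace, field type, tensor shape and group element below are illustrative, and the CUDA branch assumes a GPU is available:

import torch
from e2cnn import gspaces
from e2cnn import nn as enn

gspace = gspaces.Rot2dOnR2(N=8)                                # rotations by multiples of 45 degrees
feat_type = enn.FieldType(gspace, [gspace.trivial_repr] * 3)   # 3 scalar channels

x = torch.randn(1, 3, 32, 32)
if torch.cuda.is_available():
    x = x.to('cuda')   # transform() now warns and temporarily moves the tensor to CPU

# rotate the feature field by the group element 2 (i.e. by 90 degrees);
# the result is detached (no gradient flows back) and, with this commit,
# returned on the same device as the input
y = feat_type.transform(x, 2)
assert y.device == x.device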
59 changes: 46 additions & 13 deletions e2cnn/nn/init.py
@@ -11,26 +11,24 @@
__all__ = ["generalized_he_init", "deltaorthonormal_init"]


def generalized_he_init(tensor: torch.Tensor, basisexpansion: BasisExpansion):
def _generalized_he_init_variances(basisexpansion: BasisExpansion):
r"""
Initialize the weights of a convolutional layer with a generalized He's weight initialization method.
Compute the variances of the weights of a convolutional layer with a generalized He's weight initialization method.
Args:
tensor (torch.Tensor): the tensor containing the weights
basisexpansion (BasisExpansion): the basis expansion method
"""
# Initialization

assert tensor.shape == (basisexpansion.dimension(), )

vars = torch.ones_like(tensor)
vars = torch.ones((basisexpansion.dimension(),))

inputs_count = defaultdict(lambda: set())
basis_count = defaultdict(int)

for attr in basisexpansion.get_basis_info():
basis_info = list(basisexpansion.get_basis_info())

for attr in basis_info:
i, o = attr["in_irreps_position"], attr["out_irreps_position"]
in_irrep, out_irrep = attr["in_irrep"], attr["out_irrep"]
inputs_count[o].add(in_irrep)
@@ -39,13 +37,48 @@ def generalized_he_init(tensor: torch.Tensor, basisexpansion: BasisExpansion):
for o in inputs_count.keys():
inputs_count[o] = len(inputs_count[o])

for w, attr in enumerate(basisexpansion.get_basis_info()):
for w, attr in enumerate(basis_info):
i, o = attr["in_irreps_position"], attr["out_irreps_position"]
in_irrep, out_irrep = attr["in_irrep"], attr["out_irrep"]
vars[w] = 1. / math.sqrt(inputs_count[o] * basis_count[(in_irrep, o)])

# for i, o in basis_count.keys():
# print(i, o, inputs_count[o], basis_count[(i, o)])
return vars


cached_he_vars = {}


def generalized_he_init(tensor: torch.Tensor, basisexpansion: BasisExpansion, cache: bool = False):
r"""
Initialize the weights of a convolutional layer with a generalized He's weight initialization method.
Because the computation of the variances can be expensive, to save time on consecutive runs of the same model,
it is possible to cache the tensor containing the variance of each weight, for a specific ```basisexpansion```.
This can be useful if a network contains multiple convolution layers of the same kind (same input and output types,
same kernel size, etc.) or if one needs to train the same network from scratch multiple times (e.g. to perform
hyper-parameter search over learning rate or to repeat an experiment with different random seeds).
.. note ::
The variance tensor is cached in memory and therefore is only available to the current process.
Args:
tensor (torch.Tensor): the tensor containing the weights
basisexpansion (BasisExpansion): the basis expansion method
cache (bool, optional): cache the variance tensor. By default, ```cache=False```
"""
# Initialization

assert tensor.shape == (basisexpansion.dimension(),)

if cache and basisexpansion not in cached_he_vars:
cached_he_vars[basisexpansion] = _generalized_he_init_variances(basisexpansion)

if cache:
vars = cached_he_vars[basisexpansion]
else:
vars = _generalized_he_init_variances(basisexpansion)

tensor[:] = vars * torch.randn_like(tensor)

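A sketch of how the new cache flag might be used; the layer names and sizes are illustrative, and the second call is expected to hit the cache because the two layers have equal basis expansions (as enabled by the __eq__/__hash__ methods added in the files below):

import torch
from e2cnn import gspaces
from e2cnn import nn as enn
from e2cnn.nn import init

gspace = gspaces.Rot2dOnR2(8)
in_type = enn.FieldType(gspace, [gspace.regular_repr] * 2)
out_type = enn.FieldType(gspace, [gspace.regular_repr] * 3)

# two convolution layers of the same kind: same input/output types, same kernel size
conv1 = enn.R2Conv(in_type, out_type, kernel_size=5)
conv2 = enn.R2Conv(in_type, out_type, kernel_size=5)

# the first call computes and caches the per-weight variances;
# the second call re-uses the cached tensor and only samples new weights
init.generalized_he_init(conv1.weights.data, conv1.basisexpansion, cache=True)
init.generalized_he_init(conv2.weights.data, conv2.basisexpansion, cache=True)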
6 changes: 6 additions & 0 deletions e2cnn/nn/modules/r2_conv/basisexpansion.py
@@ -80,4 +80,10 @@ def dimension(self) -> int:
"""
pass

@abstractmethod
def __hash__(self):
raise NotImplementedError()

@abstractmethod
def __eq__(self, other):
raise NotImplementedError()
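Requiring __hash__ and __eq__ on every BasisExpansion makes instances usable as dictionary keys, which is what the in-memory variance cache added in e2cnn/nn/init.py above relies on. A generic sketch of that pattern (variance_cache, cached_variances and the compute callable are illustrative names, not part of e2cnn):

variance_cache = {}

def cached_variances(basisexpansion, compute):
    # the membership test below invokes BasisExpansion.__hash__ and,
    # on a hash collision, BasisExpansion.__eq__
    if basisexpansion not in variance_cache:
        variance_cache[basisexpansion] = compute(basisexpansion)
    return variance_cache[basisexpansion]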
44 changes: 43 additions & 1 deletion e2cnn/nn/modules/r2_conv/basisexpansion_blocks.py
@@ -165,7 +165,7 @@ def __init__(self,

# increment the position counter
last_weight_position += total_weights

def get_basis_names(self) -> List[str]:
return self._basis_to_ids

@@ -364,6 +364,48 @@ def forward(self, weights: torch.Tensor) -> torch.Tensor:
# return the new filter
return _filter

def __hash__(self):

_hash = 0
for io in self._representations_pairs:
n_pairs = self._in_count[io[0]] * self._out_count[io[1]]
_hash += hash(getattr(self, f"block_expansion_{io}")) * n_pairs

return _hash

def __eq__(self, other):
if not isinstance(other, BlocksBasisExpansion):
return False

if self.dimension() != other.dimension():
return False

if self._representations_pairs != other._representations_pairs:
return False

for io in self._representations_pairs:
if self._contiguous[io] != other._contiguous[io]:
return False

if self._weights_ranges[io] != other._weights_ranges[io]:
return False

if self._contiguous[io]:
if getattr(self, f"in_indices_{io}") != getattr(other, f"in_indices_{io}"):
return False
if getattr(self, f"out_indices_{io}") != getattr(other, f"out_indices_{io}"):
return False
else:
if torch.any(getattr(self, f"in_indices_{io}") != getattr(other, f"in_indices_{io}")):
return False
if torch.any(getattr(self, f"out_indices_{io}") != getattr(other, f"out_indices_{io}")):
return False

if getattr(self, f"block_expansion_{io}") != getattr(other, f"block_expansion_{io}"):
return False

return True


def _retrieve_indices(type: FieldType):
fiber_position = 0
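As a sanity check of the semantics these __eq__ and __hash__ implementations target — a sketch, assuming (as the generalized_he_init docstring above suggests) that two layers of the same kind end up with basis expansions that compare equal:

from e2cnn import gspaces
from e2cnn import nn as enn

gspace = gspaces.Rot2dOnR2(8)
in_type = enn.FieldType(gspace, [gspace.regular_repr] * 2)
out_type = enn.FieldType(gspace, [gspace.regular_repr] * 3)

conv_a = enn.R2Conv(in_type, out_type, kernel_size=5)
conv_b = enn.R2Conv(in_type, out_type, kernel_size=5)

# equal basis expansions must also hash identically, so both layers map to the
# same entry of the variance cache in e2cnn/nn/init.py
assert conv_a.basisexpansion == conv_b.basisexpansion
assert hash(conv_a.basisexpansion) == hash(conv_b.basisexpansion)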
20 changes: 18 additions & 2 deletions e2cnn/nn/modules/r2_conv/basisexpansion_singleblock.py
@@ -32,6 +32,8 @@ def __init__(self,

super(SingleBlockBasisExpansion, self).__init__()

self.basis = basis

# compute the mask of the sampled basis containing only the elements allowed by the filter
mask = np.zeros(len(basis), dtype=bool)
for b, attr in enumerate(basis):
@@ -69,7 +71,8 @@ def __init__(self,
if not any(norms):
raise EmptyBasisException
sampled_basis = sampled_basis[norms, ...]

self._mask = mask

self.attributes = [attr for b, attr in enumerate(attributes) if norms[b]]

# register the bases tensors as parameters of this module
@@ -111,7 +114,20 @@ def get_basis_info(self) -> Iterable:

def dimension(self) -> int:
return self.sampled_basis.shape[0]


def __eq__(self, other):
if isinstance(other, SingleBlockBasisExpansion):
return (
self.basis == other.basis and
torch.allclose(self.sampled_basis, other.sampled_basis) and
(self._mask == other._mask).all()
)
else:
return False

def __hash__(self):
return 10000 * hash(self.basis) + 100 * hash(self.sampled_basis) + hash(self._mask)


# dictionary storing references to already built basis tensors
# when a new filter tensor is built, it is also stored here
68 changes: 68 additions & 0 deletions test/nn/test_he_init.py
@@ -24,18 +24,21 @@ def test_one_block(self):
# t1 = FieldType(gspace, [gspace.regular_repr]*2)
# t2 = FieldType(gspace, [gspace.regular_repr]*3)
self.check(t1, t2)
self.check_caching(t1, t2)

def test_many_block_discontinuous(self):
gspace = Rot2dOnR2(8)
t1 = FieldType(gspace, list(gspace.representations.values()) * 2)
t2 = FieldType(gspace, list(gspace.representations.values()) * 3)
self.check(t1, t2)
self.check_caching(t1, t2)

def test_many_block_sorted(self):
gspace = Rot2dOnR2(8)
t1 = FieldType(gspace, list(gspace.representations.values()) * 2).sorted()
t2 = FieldType(gspace, list(gspace.representations.values()) * 3).sorted()
self.check(t1, t2)
self.check_caching(t1, t2)

def test_different_irreps_ratio(self):
N = 8
@@ -45,6 +48,33 @@ def test_different_irreps_ratio(self):
t1 = FieldType(gspace, [irreps_in]*3)
t2 = FieldType(gspace, [irreps_out]*3)
self.check(t1, t2)
self.check_caching(t1, t2)

def test_caching(self):

N = 7
gspace = Rot2dOnR2(N)

# try combinations of field types which are similar but should still be cached independently

irreps = directsum([gspace.fibergroup.irrep(k) for k in range(N // 2 + 1)], name="irrepssum")
t1 = FieldType(gspace, [irreps] * 2)
t2 = FieldType(gspace, [irreps] * 3)
self.check_caching(t1, t2)

t1 = FieldType(gspace, [directsum([irreps]*2)])
t2 = FieldType(gspace, [directsum([irreps]*3)])
self.check_caching(t1, t2)

irreps = [gspace.fibergroup.irrep(k) for k in range(N // 2 + 1)]
t1 = FieldType(gspace, irreps * 2)
t2 = FieldType(gspace, irreps * 3)
self.check_caching(t1, t2)

irreps = [gspace.fibergroup.irrep(k) for k in range(N // 2 + 1)]
t1 = FieldType(gspace, irreps * 2).sorted()
t2 = FieldType(gspace, irreps * 3).sorted()
self.check_caching(t1, t2)

def check(self, r1: FieldType, r2: FieldType):

@@ -83,6 +113,44 @@ def check(self, r1: FieldType, r2: FieldType):
self.assertTrue(torch.allclose(torch.zeros_like(mean), mean, rtol=2e-2, atol=5e-2))
self.assertTrue(torch.allclose(torch.ones_like(std), std, rtol=1e-1, atol=6e-2))

def check_caching(self, r1: FieldType, r2: FieldType):

np.set_printoptions(precision=7, threshold=60000, suppress=True)

assert r1.gspace == r2.gspace

with torch.no_grad():

s = 7

cl = R2Conv(r1, r2, s,
# sigma=[0.01] + [0.6]*int(s//2),
frequencies_cutoff=3.)

from datetime import datetime

torch.manual_seed(42)
weights1 = torch.zeros_like(cl.weights.data)

start = datetime.now()
init.generalized_he_init(weights1, cl.basisexpansion, cache=True)
end = datetime.now()
elapsed1 = (end - start).total_seconds()

for i in range(10):
torch.manual_seed(42)
weights2 = torch.zeros_like(cl.weights.data)

start = datetime.now()
init.generalized_he_init(weights2, cl.basisexpansion, cache=True)
end = datetime.now()
elapsed2 = (end - start).total_seconds()

self.assertTrue(torch.allclose(weights1, weights2))

# this ensures that the first run was indeed the first time this was cached
self.assertTrue(elapsed1 >= 4 * elapsed2)


if __name__ == '__main__':
unittest.main()
