Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make_parallel for network containers #985

Merged
merged 3 commits into from
Aug 26, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 204 additions & 3 deletions alf/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.
"""Some basic layers."""

from absl import logging
import copy

from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable
from typing import Optional, Union, Callable
emailweixu marked this conversation as resolved.
Show resolved Hide resolved

import alf
from alf.initializers import variance_scaling_init
Expand Down Expand Up @@ -84,6 +84,9 @@ def __init__(self, dtype=torch.float32):
def forward(self, x):
return x.to(self._dtype)

def make_parallel(self, n: int):
return Cast(self._dtype)


class Transpose(nn.Module):
"""A layer that perform the transpose of channels.
Expand All @@ -109,6 +112,9 @@ def __init__(self, dim0=0, dim1=1):
def forward(self, x):
return x.transpose(self._dim0, self._dim1)

def make_parallel(self, n: int):
emailweixu marked this conversation as resolved.
Show resolved Hide resolved
return Transpose(self._dim0, self._dim1)


class Permute(nn.Module):
"""A layer that perform the permutation of channels."""
Expand All @@ -127,6 +133,9 @@ def __init__(self, *dims):
def forward(self, x):
return x.permute(*self._dims)

def make_parallel(self, n: int):
return Permute(*self._dims)


class BatchSquash(object):
"""Facilitates flattening and unflattening batch dims of a tensor. Copied
Expand Down Expand Up @@ -183,6 +192,9 @@ def forward(self, input):
return nn.functional.one_hot(
input, num_classes=self._num_classes).to(torch.float32)

def make_parallel(self, n: int):
return OneHot(self._num_classes)


@alf.configurable
class FixedDecodingLayer(nn.Module):
Expand Down Expand Up @@ -432,7 +444,7 @@ def weight(self):
def bias(self):
return self._bias

def make_parallel(self, n):
def make_parallel(self, n: int):
"""Create a ``ParallelFC`` using ``n`` replicas of ``self``.
The initialized layer parameters will be different.
"""
Expand Down Expand Up @@ -639,6 +651,8 @@ def __init__(self,
bias_init_value (float): a constant
"""
super().__init__()
self._input_size = input_size
self._output_size = output_size
self._activation = activation
self._weight = nn.Parameter(torch.Tensor(n, output_size, input_size))
if use_bias:
Expand Down Expand Up @@ -1085,6 +1099,11 @@ def __init__(self,
``kernel_initializer`` is not None.
bias_init_value (float): a constant
"""
# get the argument list with vals
self._kwargs = copy.deepcopy(locals())
self._kwargs.pop('self')
self._kwargs.pop('__class__')

super(Conv2D, self).__init__()
if use_bias is None:
use_bias = not use_bn
Expand Down Expand Up @@ -1136,6 +1155,9 @@ def weight(self):
def bias(self):
return self._conv2d.bias

def make_parallel(self, n: int):
return ParallelConv2D(n=n, **self._kwargs)


@alf.configurable
class Conv2DBatchEnsemble(Conv2D):
Expand Down Expand Up @@ -1490,6 +1512,10 @@ def __init__(self,
``kernel_initializer`` is not None.
bias_init_value (float): a constant
"""
# get the argument list with vals
self._kwargs = copy.deepcopy(locals())
self._kwargs.pop('self')
self._kwargs.pop('__class__')
super(ConvTranspose2D, self).__init__()
if use_bias is None:
use_bias = not use_bn
Expand Down Expand Up @@ -1533,6 +1559,9 @@ def weight(self):
def bias(self):
return self._conv_trans2d.bias

def make_parallel(self, n: int):
return ParallelConvTranspose2D(n=n, **self._kwargs)


@alf.configurable
class ParallelConvTranspose2D(nn.Module):
Expand All @@ -1544,6 +1573,7 @@ def __init__(self,
activation=torch.relu_,
strides=1,
padding=0,
output_padding=0,
use_bias=None,
use_bn=False,
kernel_initializer=None,
Expand All @@ -1560,6 +1590,9 @@ def __init__(self,
activation (torch.nn.functional):
strides (int or tuple):
padding (int or tuple):
output_padding (int or tuple): Additional size added to one side of
each dimension in the output shape. Default: 0. See pytorch
documentation for more detail.
use_bias (bool|None): If None, will use ``not use_bn``
use_bn (bool):
kernel_initializer (Callable): initializer for the conv_trans layer.
Expand All @@ -1585,6 +1618,7 @@ def __init__(self,
groups=n,
stride=strides,
padding=padding,
output_padding=output_padding,
bias=use_bias)

for i in range(n):
Expand Down Expand Up @@ -2074,6 +2108,9 @@ def __init__(self, shape):
def forward(self, x):
return x.reshape(x.shape[0], *self._shape)

def make_parallel(self, n: int):
return Reshape((n, ) + self._shape)


def _tuplify2d(x):
if isinstance(x, tuple):
Expand Down Expand Up @@ -2604,6 +2641,9 @@ def forward(self, input):
return alf.nest.map_structure(
lambda path: alf.nest.get_field(input, path), self._fields)

def make_parallel(self, n: int):
return GetFields(self._fields)


class Sum(nn.Module):
"""Sum over given dimension(s).
Expand All @@ -2624,6 +2664,9 @@ def __init__(self, dim):
def forward(self, input):
return input.sum(dim=self._dim)

def make_parallel(self, n: int):
return Sum(self._dim)
emailweixu marked this conversation as resolved.
Show resolved Hide resolved


def reset_parameters(module):
"""Reset the parameters for ``module``.
Expand Down Expand Up @@ -2655,6 +2698,9 @@ def __init__(self):
def forward(self, input):
return common.detach(input)

def make_parallel(self, n: int):
return Detach()


class Branch(nn.Module):
"""Apply multiple modules on the same input.
Expand Down Expand Up @@ -2710,6 +2756,18 @@ def forward(self, inputs):
def reset_parameters(self):
alf.nest.map_structure(reset_parameters, self._networks)

def make_parallel(self, n: int):
"""Create a parallelized version of this network.

Args:
n (int): the number of copies
Returns:
the parallelized version of this network
"""
new_networks = alf.nest.map_structure(
lambda net: make_parallel_net(net, n), self._networks)
return Branch(new_networks)


class Sequential(nn.Module):
"""A more flexible Sequential than torch.nn.Sequential.
Expand Down Expand Up @@ -2854,3 +2912,146 @@ def reset_parameters(self):

def __getitem__(self, i):
return self._networks[i]

def make_parallel(self, n: int):
"""Create a parallelized version of this network.

Args:
n (int): the number of copies
Returns:
the parallelized version of this network
"""
new_networks = []
new_named_networks = {}
for net, input, output in zip(self._networks, self._inputs,
self._outputs):
pnet = alf.layers.make_parallel_net(net, n)
if not output:
new_networks.append((input, pnet))
else:
new_named_networks[output] = (input, pnet)
return Sequential(
*new_networks, output=self._output, **new_named_networks)


def make_parallel_net(module, n: int):
"""Make a parallelized version of ``module``.

A parallel network has ``n`` copies of network with the same structure but
different independently initialized parameters. The parallel network can
process a batch of the data with shape [batch_size, n, ...] using ``n``
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we automatically convert data with shape [batch_size, ...] to [batch_size, n, ...]?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As network can have sub-networks, doing this check for all of them can be wasteful. So it's intended to use make_parallel_input to do the conversion by the user.

networks with same structure.

If ``module`` has member function make_parallel, it will be called to make
the parallel network. Otherwise, it will creates a ``NaiveParallelLayer``,
which simply making ``n`` copies of ``module`` and use a loop to call them
in ``forward()``.

Examples:

Applying parallel net on same input:

.. code-block:: python

pnet = make_parallel_net(net, n)
# replicate input.
# pinput will have shape [batch_size, n, ...], if input has shape [batch_size, ...]
pinput = make_parallel_input(input)
emailweixu marked this conversation as resolved.
Show resolved Hide resolved
poutput = pnet(pinput)

If you already have parallel input with shape [batch_size, n, ...], you can
omit the call to ``make_parallel_input`` in the above code.

Args:
module (Network | nn.Module | Callable): the network to be parallelized.
n (int): the number of copies
Returns:
the parallelized network.
"""
if hasattr(module, 'make_parallel'):
return module.make_parallel(n)
else:
logging.warning(
"%s does not have make_parallel. A naive parallel layer "
"will be created." % str(module))
return NaiveParallelLayer(module, n)


class NaiveParallelLayer(nn.Module):
def __init__(self, module: Union[nn.Module, Callable], n: int):
"""
A parallel network has ``n`` copies of network with the same structure but
different indepently initialized parameters.

``NaiveParallelLayer`` creats ``n`` independent networks with the same
emailweixu marked this conversation as resolved.
Show resolved Hide resolved
structure as ``network`` and evaluate them separately in a loop during
``forward()``.

Args:
module (nn.Module | Callable): the parallel network will have ``n`
copies of ``module``.
n (int): ``n`` copies of ``module``
"""
super().__init__()
if isinstance(module, nn.Module):
self._networks = nn.ModuleList(
[copy.deepcopy(module) for i in range(n)])
else:
self._networks = [module] * n
self._n = n

def forward(self, inputs):
"""Compute the output.

Args:
inputs (nested torch.Tensor): its shape is ``[B, n, ...]``
Returns:
output (nested torch.Tensor): its shape is ``[B, n, ...]``
"""
outputs = []
for i in range(self._n):
inp = alf.nest.map_structure(lambda x: x[:, i, ...], inputs)
ret = self._networks[i](inp)
outputs.append(ret)
if self._n > 1:
output = alf.nest.map_structure(
lambda *tensors: torch.stack(tensors, dim=1), *outputs)
else:
output = alf.nest.map_structure(lambda tensor: tensor.unsqueeze(1),
outputs[0])

return output


def make_parallel_input(inputs, n: int):
"""Replicate ``inputs`` over dim 1 for ``n`` times so it can be processed by
parallel networks.

Args:
inputs (nested Tensor): a nest of Tensor
n (int): ``inputs`` will be replicated ``n`` times.
Returns:
inputs replicated over dim 1
"""
return map_structure(lambda x: x.unsqueeze(1).expand(-1, n, *x.shape[1:]),
inputs)


def make_parallel_spec(specs, n: int):
"""Make the spec for parallel network.

Args:
specs (nested TensorSpec): the input spec for the non-parallelized network
n (int): the number of copies of the parallelized network
Returns:
input tensor spec for the parallelized network
"""

def _make_spec(spec):
if type(spec) == alf.TensorSpec:
return alf.TensorSpec((n, ) + spec.shape, spec.dtype)
else: # BoundedTensorSpec
return alf.BoundedTensorSpec((n, ) + spec.shape, spec.dtype,
spec.minimum, spec.maximum)

return map_structure(_make_spec, specs)
Loading