From b8cb37ab1d040384d9df7a7b3db61b06dbbc22f2 Mon Sep 17 00:00:00 2001 From: emailweixu Date: Wed, 25 Aug 2021 18:30:04 -0700 Subject: [PATCH] make_parallel for network containers (#985) * make_parallel for network containers * Address comments * Address further comments --- .../ppg/disjoint_policy_value_network_test.py | 2 +- alf/layers.py | 278 +++++++++++++++--- alf/layers_test.py | 107 +++++++ alf/nest/utils.py | 26 +- .../actor_distribution_networks_test.py | 2 +- alf/networks/containers.py | 63 ++++ alf/networks/containers_test.py | 27 ++ alf/networks/network.py | 13 +- alf/networks/networks.py | 24 +- alf/networks/preprocessors_test.py | 2 +- alf/test/case.py | 5 +- alf/utils/spec_utils.py | 2 +- alf/utils/tensor_utils.py | 45 +++ 13 files changed, 528 insertions(+), 68 deletions(-) diff --git a/alf/algorithms/ppg/disjoint_policy_value_network_test.py b/alf/algorithms/ppg/disjoint_policy_value_network_test.py index c5154780a..5dd97866d 100644 --- a/alf/algorithms/ppg/disjoint_policy_value_network_test.py +++ b/alf/algorithms/ppg/disjoint_policy_value_network_test.py @@ -61,7 +61,7 @@ def test_architecture(self, is_sharing_encoder): EncodingNetwork, conv_layer_params=self._conv_layer_params, fc_layer_params=self._fc_layer_params, - preprocessing_combiner=NestConcat(dim=1)), + preprocessing_combiner=NestConcat(dim=0)), is_sharing_encoder=is_sharing_encoder) # Verify that the output specs are correct diff --git a/alf/layers.py b/alf/layers.py index 7d7deabce..c9f7d5724 100644 --- a/alf/layers.py +++ b/alf/layers.py @@ -13,14 +13,14 @@ # limitations under the License. """Some basic layers.""" +from absl import logging import copy - from functools import partial import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from typing import Callable +from typing import Union, Callable import alf from alf.initializers import variance_scaling_init @@ -30,6 +30,7 @@ from alf.tensor_specs import TensorSpec from alf.utils import common from alf.utils.math_ops import identity +from alf.utils.tensor_utils import BatchSquash def normalize_along_batch_dims(x, mean, variance, variance_epsilon): @@ -84,6 +85,9 @@ def __init__(self, dtype=torch.float32): def forward(self, x): return x.to(self._dtype) + def make_parallel(self, n: int): + return Cast(self._dtype) + class Transpose(nn.Module): """A layer that perform the transpose of channels. @@ -109,6 +113,19 @@ def __init__(self, dim0=0, dim1=1): def forward(self, x): return x.transpose(self._dim0, self._dim1) + def make_parallel(self, n: int): + """Create a Transpose layer to handle parallel batch. + + It is assumed that a parallel batch has shape [B, n, ...] and both the + batch dimension and replica dimension are not considered for transpose. + + Args: + n (int): the number of replicas. + Returns: + a ``Transpose`` layer to handle parallel batch. + """ + return Transpose(self._dim0, self._dim1) + class Permute(nn.Module): """A layer that perform the permutation of channels.""" @@ -127,50 +144,18 @@ def __init__(self, *dims): def forward(self, x): return x.permute(*self._dims) + def make_parallel(self, n: int): + """Create a Permute layer to handle parallel batch. -class BatchSquash(object): - """Facilitates flattening and unflattening batch dims of a tensor. Copied - from `tf_agents`. - - Exposes a pair of matched flatten and unflatten methods. After flattening - only 1 batch dimension will be left. This facilitates evaluating networks - that expect inputs to have only 1 batch dimension. 
- """ - - def __init__(self, batch_dims): - """Create two tied ops to flatten and unflatten the front dimensions. + It is assumed that a parallel batch has shape [B, n, ...] and both the + batch dimension and replica dimension are not considered for permute. Args: - batch_dims (int): Number of batch dimensions the flatten/unflatten - ops should handle. - - Raises: - ValueError: if batch dims is negative. + n (int): the number of replicas. + Returns: + a ``Permute`` layer to handle parallel batch. """ - if batch_dims < 0: - raise ValueError('Batch dims must be non-negative.') - self._batch_dims = batch_dims - self._original_tensor_shape = None - - def flatten(self, tensor): - """Flattens and caches the tensor's batch_dims.""" - if self._batch_dims == 1: - return tensor - self._original_tensor_shape = tensor.shape - return torch.reshape(tensor, - (-1, ) + tuple(tensor.shape[self._batch_dims:])) - - def unflatten(self, tensor): - """Unflattens the tensor's batch_dims using the cached shape.""" - if self._batch_dims == 1: - return tensor - - if self._original_tensor_shape is None: - raise ValueError('Please call flatten before unflatten.') - - return torch.reshape( - tensor, (tuple(self._original_tensor_shape[:self._batch_dims]) + - tuple(tensor.shape[1:]))) + return Permute(*self._dims) @alf.configurable @@ -183,6 +168,9 @@ def forward(self, input): return nn.functional.one_hot( input, num_classes=self._num_classes).to(torch.float32) + def make_parallel(self, n: int): + return OneHot(self._num_classes) + @alf.configurable class FixedDecodingLayer(nn.Module): @@ -432,7 +420,7 @@ def weight(self): def bias(self): return self._bias - def make_parallel(self, n): + def make_parallel(self, n: int): """Create a ``ParallelFC`` using ``n`` replicas of ``self``. The initialized layer parameters will be different. """ @@ -639,6 +627,8 @@ def __init__(self, bias_init_value (float): a constant """ super().__init__() + self._input_size = input_size + self._output_size = output_size self._activation = activation self._weight = nn.Parameter(torch.Tensor(n, output_size, input_size)) if use_bias: @@ -1085,6 +1075,11 @@ def __init__(self, ``kernel_initializer`` is not None. bias_init_value (float): a constant """ + # get the argument list with vals + self._kwargs = copy.deepcopy(locals()) + self._kwargs.pop('self') + self._kwargs.pop('__class__') + super(Conv2D, self).__init__() if use_bias is None: use_bias = not use_bn @@ -1136,6 +1131,9 @@ def weight(self): def bias(self): return self._conv2d.bias + def make_parallel(self, n: int): + return ParallelConv2D(n=n, **self._kwargs) + @alf.configurable class Conv2DBatchEnsemble(Conv2D): @@ -1490,6 +1488,10 @@ def __init__(self, ``kernel_initializer`` is not None. 
bias_init_value (float): a constant """ + # get the argument list with vals + self._kwargs = copy.deepcopy(locals()) + self._kwargs.pop('self') + self._kwargs.pop('__class__') super(ConvTranspose2D, self).__init__() if use_bias is None: use_bias = not use_bn @@ -1533,6 +1535,9 @@ def weight(self): def bias(self): return self._conv_trans2d.bias + def make_parallel(self, n: int): + return ParallelConvTranspose2D(n=n, **self._kwargs) + @alf.configurable class ParallelConvTranspose2D(nn.Module): @@ -1544,6 +1549,7 @@ def __init__(self, activation=torch.relu_, strides=1, padding=0, + output_padding=0, use_bias=None, use_bn=False, kernel_initializer=None, @@ -1560,6 +1566,9 @@ def __init__(self, activation (torch.nn.functional): strides (int or tuple): padding (int or tuple): + output_padding (int or tuple): Additional size added to one side of + each dimension in the output shape. Default: 0. See pytorch + documentation for more detail. use_bias (bool|None): If None, will use ``not use_bn`` use_bn (bool): kernel_initializer (Callable): initializer for the conv_trans layer. @@ -1585,6 +1594,7 @@ def __init__(self, groups=n, stride=strides, padding=padding, + output_padding=output_padding, bias=use_bias) for i in range(n): @@ -2074,12 +2084,8 @@ def __init__(self, shape): def forward(self, x): return x.reshape(x.shape[0], *self._shape) - -def _tuplify2d(x): - if isinstance(x, tuple): - assert len(x) == 2 - return x - return (x, x) + def make_parallel(self, n: int): + return Reshape((n, ) + self._shape) def _conv_transpose_2d(in_channels, @@ -2604,6 +2610,9 @@ def forward(self, input): return alf.nest.map_structure( lambda path: alf.nest.get_field(input, path), self._fields) + def make_parallel(self, n: int): + return GetFields(self._fields) + class Sum(nn.Module): """Sum over given dimension(s). @@ -2624,6 +2633,19 @@ def __init__(self, dim): def forward(self, input): return input.sum(dim=self._dim) + def make_parallel(self, n: int): + """Create a Sum layer to handle parallel batch. + + It is assumed that a parallel batch has shape [B, n, ...] and both the + batch dimension and replica dimension are not counted for ``dim`` + + Args: + n (int): the number of replicas. + Returns: + a ``Sum`` layer to handle parallel batch. + """ + return Sum(self._dim) + def reset_parameters(module): """Reset the parameters for ``module``. @@ -2655,6 +2677,9 @@ def __init__(self): def forward(self, input): return common.detach(input) + def make_parallel(self, n: int): + return Detach() + class Branch(nn.Module): """Apply multiple modules on the same input. @@ -2710,6 +2735,18 @@ def forward(self, inputs): def reset_parameters(self): alf.nest.map_structure(reset_parameters, self._networks) + def make_parallel(self, n: int): + """Create a parallelized version of this network. + + Args: + n (int): the number of copies + Returns: + the parallelized version of this network + """ + new_networks = alf.nest.map_structure( + lambda net: make_parallel_net(net, n), self._networks) + return Branch(new_networks) + class Sequential(nn.Module): """A more flexible Sequential than torch.nn.Sequential. @@ -2854,3 +2891,146 @@ def reset_parameters(self): def __getitem__(self, i): return self._networks[i] + + def make_parallel(self, n: int): + """Create a parallelized version of this network. 
+
+        Args:
+            n (int): the number of copies
+        Returns:
+            the parallelized version of this network
+        """
+        new_networks = []
+        new_named_networks = {}
+        for net, input, output in zip(self._networks, self._inputs,
+                                      self._outputs):
+            pnet = alf.layers.make_parallel_net(net, n)
+            if not output:
+                new_networks.append((input, pnet))
+            else:
+                new_named_networks[output] = (input, pnet)
+        return Sequential(
+            *new_networks, output=self._output, **new_named_networks)
+
+
+def make_parallel_net(module, n: int):
+    """Make a parallelized version of ``module``.
+
+    A parallel network has ``n`` copies of a network with the same structure but
+    different independently initialized parameters. The parallel network can
+    process a batch of data with shape [batch_size, n, ...] using ``n``
+    networks with the same structure.
+
+    If ``module`` has a member function ``make_parallel``, it will be called to
+    make the parallel network. Otherwise, a ``NaiveParallelLayer`` will be
+    created, which simply makes ``n`` copies of ``module`` and uses a loop to
+    call them in ``forward()``.
+
+    Examples:
+
+    Applying a parallel net to the same input:
+
+    .. code-block:: python
+
+        pnet = make_parallel_net(net, n)
+        # replicate input.
+        # pinput will have shape [batch_size, n, ...], if input has shape [batch_size, ...]
+        pinput = make_parallel_input(input, n)
+        poutput = pnet(pinput)
+
+    If you already have a parallel input with shape [batch_size, n, ...], you can
+    omit the call to ``make_parallel_input`` in the above code.
+
+    Args:
+        module (Network | nn.Module | Callable): the network to be parallelized.
+        n (int): the number of copies
+    Returns:
+        the parallelized network.
+    """
+    if hasattr(module, 'make_parallel'):
+        return module.make_parallel(n)
+    else:
+        logging.warning(
+            "%s does not have make_parallel. A naive parallel layer "
+            "will be created." % str(module))
+        return NaiveParallelLayer(module, n)
+
+
+class NaiveParallelLayer(nn.Module):
+    def __init__(self, module: Union[nn.Module, Callable], n: int):
+        """
+        A parallel network has ``n`` copies of a network with the same structure
+        but different independently initialized parameters.
+
+        ``NaiveParallelLayer`` creates ``n`` independent networks with the same
+        structure as ``module`` and evaluates them separately in a loop during
+        ``forward()``.
+
+        Args:
+            module (nn.Module | Callable): the parallel network will have ``n``
+                copies of ``module``.
+            n (int): the number of copies of ``module``
+        """
+        super().__init__()
+        if isinstance(module, nn.Module):
+            self._networks = nn.ModuleList(
+                [copy.deepcopy(module) for i in range(n)])
+        else:
+            self._networks = [module] * n
+        self._n = n
+
+    def forward(self, inputs):
+        """Compute the output.
+
+        Args:
+            inputs (nested torch.Tensor): its shape is ``[B, n, ...]``
+        Returns:
+            output (nested torch.Tensor): its shape is ``[B, n, ...]``
+        """
+        outputs = []
+        for i in range(self._n):
+            inp = alf.nest.map_structure(lambda x: x[:, i, ...], inputs)
+            ret = self._networks[i](inp)
+            outputs.append(ret)
+        if self._n > 1:
+            output = alf.nest.map_structure(
+                lambda *tensors: torch.stack(tensors, dim=1), *outputs)
+        else:
+            output = alf.nest.map_structure(lambda tensor: tensor.unsqueeze(1),
+                                            outputs[0])
+
+        return output
+
+
+def make_parallel_input(inputs, n: int):
+    """Replicate ``inputs`` over dim 1 ``n`` times so they can be processed by
+    parallel networks.
+
+    Args:
+        inputs (nested Tensor): a nest of Tensor
+        n (int): ``inputs`` will be replicated ``n`` times.
+ Returns: + inputs replicated over dim 1 + """ + return map_structure(lambda x: x.unsqueeze(1).expand(-1, n, *x.shape[1:]), + inputs) + + +def make_parallel_spec(specs, n: int): + """Make the spec for parallel network. + + Args: + specs (nested TensorSpec): the input spec for the non-parallelized network + n (int): the number of copies of the parallelized network + Returns: + input tensor spec for the parallelized network + """ + + def _make_spec(spec): + if type(spec) == alf.TensorSpec: + return alf.TensorSpec((n, ) + spec.shape, spec.dtype) + else: # BoundedTensorSpec + return alf.BoundedTensorSpec((n, ) + spec.shape, spec.dtype, + spec.minimum, spec.maximum) + + return map_structure(_make_spec, specs) diff --git a/alf/layers_test.py b/alf/layers_test.py index 394016a84..673a09e22 100644 --- a/alf/layers_test.py +++ b/alf/layers_test.py @@ -23,6 +23,31 @@ class LayersTest(parameterized.TestCase, alf.test.TestCase): + def _test_make_parallel( + self, + net, + spec, + tolerance=1e-6, + get_pnet_parameters=lambda pnet: pnet.parameters()): + batch_size = 10 + for n in (1, 2, 5): + pnet = net.make_parallel(n) + nnet = alf.layers.NaiveParallelLayer(net, n) + for i in range(n): + for pp, np in zip( + get_pnet_parameters(pnet), + nnet._networks[i].parameters()): + self.assertEqual(pp.shape, (n, ) + np.shape) + np.data.copy_(pp[i]) + pspec = alf.layers.make_parallel_spec(spec, n) + input = alf.nest.map_structure(lambda s: s.sample([batch_size]), + pspec) + presult = pnet(input) + nresult = nnet(input) + alf.nest.map_structure( + lambda p, n: self.assertTensorClose(p, n, tolerance), presult, + nresult) + @parameterized.parameters( dict(n=1, act=torch.relu, use_bias=False, parallel_x=False), dict(n=1, act=math_ops.identity, use_bias=False, parallel_x=False), @@ -881,6 +906,9 @@ def test_sequential1(self): x4 = net[3]((x2, x3, x)) self.assertEqual(x4, y) + input_spec = alf.BoundedTensorSpec((4, )) + self._test_make_parallel(net, input_spec) + def test_sequential2(self): # test wrong field name net = alf.layers.Sequential( @@ -914,6 +942,9 @@ def test_sequential3(self): self.assertEqual(y['b'], b) self.assertEqual(y['c'], c) + input_spec = alf.BoundedTensorSpec((4, )) + self._test_make_parallel(net, input_spec) + def test_sequential4(self): # test output net = alf.layers.Sequential( @@ -932,6 +963,82 @@ def test_sequential4(self): c = a + b self.assertEqual(c, y) + input_spec = alf.BoundedTensorSpec((4, )) + self._test_make_parallel(net, input_spec) + + def test_branch(self): + net = alf.layers.Branch(alf.layers.FC(4, 6), alf.layers.FC(4, 8)) + input_spec = alf.BoundedTensorSpec((4, )) + self._test_make_parallel(net, input_spec) + + def test_fc(self): + input_spec = alf.BoundedTensorSpec((5, )) + layer = alf.layers.FC(5, 7) + self._test_make_parallel(layer, input_spec) + + def test_conv2d(self): + input_spec = alf.BoundedTensorSpec((3, 10, 10)) + layer = alf.layers.Conv2D(3, 5, 3) + self._test_make_parallel( + layer, + input_spec, + get_pnet_parameters=lambda pnet: (pnet.weight, pnet.bias)) + + def test_conv_transpose_2d(self): + input_spec = alf.BoundedTensorSpec((3, 10, 10)) + layer = alf.layers.ConvTranspose2D(3, 5, 3) + self._test_make_parallel( + layer, + input_spec, + get_pnet_parameters=lambda pnet: (pnet.weight, pnet.bias)) + + def test_cast(self): + input_spec = alf.BoundedTensorSpec((8, ), dtype=torch.uint8) + layer = alf.layers.Cast() + self._test_make_parallel(layer, input_spec) + + def test_transpose(self): + input_spec = alf.BoundedTensorSpec((8, 4, 10)) + layer = 
alf.layers.Transpose() + self._test_make_parallel(layer, input_spec) + layer = alf.layers.Transpose(0, 2) + self._test_make_parallel(layer, input_spec) + layer = alf.layers.Transpose(-2, -1) + self._test_make_parallel(layer, input_spec) + + def test_permute(self): + input_spec = alf.BoundedTensorSpec((8, 4, 10)) + layer = alf.layers.Permute(2, 1, 0) + self._test_make_parallel(layer, input_spec) + + def test_onehot(self): + input_spec = alf.BoundedTensorSpec((10, ), + dtype=torch.int64, + minimum=0, + maximum=11) + layer = alf.layers.OneHot(12) + self._test_make_parallel(layer, input_spec) + + def test_reshape(self): + input_spec = alf.BoundedTensorSpec((8, 4, 10)) + layer = alf.layers.Reshape((32, 10)) + self._test_make_parallel(layer, input_spec) + + def test_get_fields(self): + input_spec = dict( + a=alf.BoundedTensorSpec((8, 4, 10)), + b=alf.BoundedTensorSpec((4, )), + c=alf.BoundedTensorSpec((3, ))) + layer = alf.layers.GetFields(('a', 'c')) + self._test_make_parallel(layer, input_spec) + + def test_sum(self): + input_spec = alf.BoundedTensorSpec((8, 4, 10)) + layer = alf.layers.Sum(dim=1) + self._test_make_parallel(layer, input_spec) + layer = alf.layers.Sum(dim=-1) + self._test_make_parallel(layer, input_spec) + if __name__ == "__main__": alf.test.main() diff --git a/alf/nest/utils.py b/alf/nest/utils.py index 1827b1393..e81dc14c3 100644 --- a/alf/nest/utils.py +++ b/alf/nest/utils.py @@ -78,6 +78,9 @@ def __init__(self, nest_mask=None, dim=-1, name="NestConcat"): It assumes that all the selected tensors have the same tensor spec. Can be used as a preprocessing combiner in ``EncodingNetwork``. + Note that batch dimension is not considered for concat. This means that + dim=0 means the first dimension after batch dimension. + Args: nest_mask (nest|None): nest structured mask indicating which of the tensors in the nest to be selected or not, indicated by a @@ -88,8 +91,9 @@ def __init__(self, nest_mask=None, dim=-1, name="NestConcat"): name (str): """ super(NestConcat, self).__init__(name) + self._nest_mask = nest_mask self._flat_mask = nest.flatten(nest_mask) if nest_mask else nest_mask - self._dim = dim + self._dim = dim if dim < 0 else dim + 1 def _combine_flat(self, tensors): if self._flat_mask is not None: @@ -104,6 +108,19 @@ def _combine_flat(self, tensors): else: return torch.cat(tensors, dim=self._dim) + def make_parallel(self, n): + """Create a ``NestConcat`` layer to handle parallel batch. + + It is assumed that a parallel batch has shape [B, n, ...] and both the + batch dimension and replica dimension are not considered for concat. + + Args: + n (int): the number of replicas. + Returns: + a ``NestConcat`` layer to handle parallel batch. + """ + return NestConcat(self._nest_mask, self._dim, "parallel_" + self._name) + @alf.configurable class NestSum(NestCombiner): @@ -129,6 +146,10 @@ def _combine_flat(self, tensors): ret *= 1 / float(len(tensors)) return self._activation(ret) + def make_parallel(self, n): + return NestSum(self._average, self._activation, + "parallel_" + self._name) + @alf.configurable class NestMultiply(NestCombiner): @@ -151,6 +172,9 @@ def _combine_flat(self, tensors): ret = alf.utils.math_ops.mul_n(tensors) return self._activation(ret) + def make_parallel(self, n): + return NestMultiply(self._activation, "parallel_" + self._name) + def stack_nests(nests, dim=0): """Stack tensors to a sequence. 
diff --git a/alf/networks/actor_distribution_networks_test.py b/alf/networks/actor_distribution_networks_test.py index 7921d0e1e..40957b0ca 100644 --- a/alf/networks/actor_distribution_networks_test.py +++ b/alf/networks/actor_distribution_networks_test.py @@ -42,7 +42,7 @@ def setUp(self): self._conv_layer_params = ((8, 3, 1), (16, 3, 2, 1)) self._fc_layer_params = (100, ) self._input_preprocessors = [torch.tanh, None] - self._preprocessing_combiner = NestConcat(dim=1) + self._preprocessing_combiner = NestConcat(dim=0) def _init(self, lstm_hidden_size): if lstm_hidden_size is not None: diff --git a/alf/networks/containers.py b/alf/networks/containers.py index 900f62865..3d587fad4 100644 --- a/alf/networks/containers.py +++ b/alf/networks/containers.py @@ -23,6 +23,7 @@ from alf.nest.utils import get_nested_field from alf.utils.spec_utils import is_same_spec from .network import Network, get_input_tensor_spec, wrap_as_network +from alf.layers import make_parallel_spec def Sequential(*modules, @@ -253,6 +254,27 @@ def copy(self, name=None): def __getitem__(self, i): return self._networks[i] + def make_parallel(self, n: int): + """Create a parallelized version of this network. + + Args: + n (int): the number of copies + Returns: + the parallelized version of this network + """ + new_networks = [] + new_named_networks = {} + for net, input, output in zip(self._networks, self._inputs, + self._outputs): + pnet = alf.layers.make_parallel_net(net, n) + if not output: + new_networks.append((input, pnet)) + else: + new_named_networks[output] = (input, pnet) + input_spec = make_parallel_spec(self._input_tensor_spec, n) + return _Sequential(new_networks, new_named_networks, self._output, + input_spec, "parallel_" + self.name) + class Parallel(Network): """Apply each Network in the nest of Network to the corresponding input. @@ -343,6 +365,19 @@ def copy(self, name=None): def networks(self): return self._networks + def make_parallel(self, n: int): + """Create a parallelized version of this network. + + Args: + n (int): the number of copies + Returns: + the parallelized version of this network + """ + networks = map_structure( + lambda net: alf.layers.make_parallel_net(net, n), self._networks) + input_spec = make_parallel_spec(self._input_tensor_spec, n) + return Parallel(networks, input_spec, 'parallel_' + self.name) + def Branch(*modules, input_tensor_spec=None, name="Branch", **named_modules): """Apply multiple networks on the same input. @@ -456,6 +491,22 @@ def copy(self, name=None): def networks(self): return self._networks + def make_parallel(self, n: int): + """Create a parallelized version of this network. + + Args: + n (int): the number of copies + Returns: + the parallelized version of this network + """ + networks = map_structure( + lambda net: alf.layers.make_parallel_net(net, n), self._networks) + input_spec = make_parallel_spec(self._input_tensor_spec, n) + return Branch( + networks, + input_tensor_spec=input_spec, + name='parallel_' + self.name) + class Echo(Network): """Echo network. @@ -567,3 +618,15 @@ def forward(self, input, state): real_output = block_output['output'] echo_output = block_output['echo'] return real_output, (block_state, echo_output) + + def make_parallel(self, n: int): + """Create a parallelized version of this network. 
+ + Args: + n (int): the number of copies + Returns: + the parallelized version of this network + """ + return Echo( + alf.layers.make_parallel_net(self._block), + make_parallel_spec(self._input_tensor_spec, n)) diff --git a/alf/networks/containers_test.py b/alf/networks/containers_test.py index 6b053fb16..cd95a823f 100644 --- a/alf/networks/containers_test.py +++ b/alf/networks/containers_test.py @@ -24,6 +24,28 @@ def _randn_from_spec(specs, batch_size): class ContainersTest(alf.test.TestCase): + def _test_make_parallel(self, net, tolerance=1e-5): + batch_size = 10 + spec = net.input_tensor_spec + for n in (1, 2, 5): + pnet = net.make_parallel(n) + nnet = alf.nn.NaiveParallelNetwork(net, n) + for i in range(n): + for pp, np in zip(pnet.parameters(), + nnet._networks[i].parameters()): + self.assertEqual(pp.shape, (n, ) + np.shape) + np.data.copy_(pp[i]) + pspec = alf.layers.make_parallel_spec(spec, n) + input = _randn_from_spec(pspec, batch_size) + presult = pnet(input) + nresult = nnet(input) + alf.nest.map_structure( + lambda p, n: self.assertEqual(p.shape, n.shape), presult, + nresult) + alf.nest.map_structure( + lambda p, n: self.assertTensorClose(p, n, tolerance), presult, + nresult) + def _verify_parameter_copy(self, src, copy): """net.copy() only copy the structure, not the values of parameters.""" for s, c in zip(src.parameters(), copy.parameters()): @@ -112,6 +134,7 @@ def test_sequential2(self): net_copy = net.copy() self._verify_parameter_copy(net, net_copy) + self._test_make_parallel(net) def test_sequential_complex2(self): net = alf.nn.Sequential( @@ -136,6 +159,7 @@ def test_sequential_complex2(self): net_copy = net.copy() self._verify_parameter_copy(net, net_copy) + self._test_make_parallel(net) def test_sequential_complex3(self): net = alf.nn.Sequential( @@ -160,6 +184,7 @@ def test_sequential_complex3(self): net_copy = net.copy() self._verify_parameter_copy(net, net_copy) + self._test_make_parallel(net) def test_parallel1(self): net = alf.nn.Parallel((alf.layers.FC(4, 6), alf.nn.GRUCell(6, 8), @@ -213,6 +238,7 @@ def test_parallel2(self): net_copy = net.copy() self._verify_parameter_copy(net, net_copy) + self._test_make_parallel(net) def test_branch1(self): net = alf.nn.Branch((alf.layers.FC(4, 6), alf.nn.GRUCell(4, 8), @@ -262,6 +288,7 @@ def test_branch2(self): net_copy = net.copy() self._verify_parameter_copy(net, net_copy) + self._test_make_parallel(net) if __name__ == '__main__': diff --git a/alf/networks/network.py b/alf/networks/network.py index 3bca52d61..f6d806c7b 100644 --- a/alf/networks/network.py +++ b/alf/networks/network.py @@ -218,7 +218,7 @@ def make_parallel(self, n): """Make a parallelized version of this network. A parallel network has ``n`` copies of network with the same structure but - different indepently initialized parameters. + different independently initialized parameters. By default, it creates ``NaiveParallelNetwork``, which simply making ``n`` copies of this network and use a loop to call them in ``forward()``. 
@@ -293,13 +293,13 @@ def forward(self, inputs, state=()): inp = alf.nest.map_structure(lambda x: x[:, i, ...], inputs) s = alf.nest.map_structure(lambda x: x[:, i, ...], state) ret = self._networks[i](inp, s) - ret = alf.nest.map_structure(lambda x: x.unsqueeze(1), ret) output_states.append(ret) if self._n > 1: output, new_state = alf.nest.map_structure( - lambda *tensors: torch.cat(tensors, dim=1), *output_states) + lambda *tensors: torch.stack(tensors, dim=1), *output_states) else: - output, new_state = output_states[0] + output, new_state = alf.nest.map_structure( + lambda x: x.unsqueeze(1), output_states[0]) return output, new_state @@ -332,6 +332,11 @@ def copy(self): return NetworkWrapper(module, self._input_tensor_spec) + def make_parallel(self, n: int): + return NetworkWrapper( + alf.layers.make_parallel_net(self._module, n), + alf.layers.make_parallel_spec(self.input_tensor_spec, n)) + def get_input_tensor_spec(net): """Get the input_tensor_spec of net if possible diff --git a/alf/networks/networks.py b/alf/networks/networks.py index b237e3722..6a45cb09a 100644 --- a/alf/networks/networks.py +++ b/alf/networks/networks.py @@ -41,7 +41,7 @@ class LSTMCell(Network): where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. """ - def __init__(self, input_size, hidden_size): + def __init__(self, input_size, hidden_size, name='LSTMCell'): """ Args: input_size (int): The number of expected features in the input `x` @@ -51,7 +51,8 @@ def __init__(self, input_size, hidden_size): alf.TensorSpec((hidden_size, ))) super().__init__( input_tensor_spec=alf.TensorSpec((input_size, )), - state_spec=state_spec) + state_spec=state_spec, + name=name) self._cell = nn.LSTMCell( input_size=input_size, hidden_size=hidden_size) @@ -75,7 +76,7 @@ class GRUCell(Network): where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. """ - def __init__(self, input_size, hidden_size): + def __init__(self, input_size, hidden_size, name='GRUCell'): """ Args: input_size (int): The number of expected features in the input `x` @@ -83,7 +84,8 @@ def __init__(self, input_size, hidden_size): """ super().__init__( input_tensor_spec=alf.TensorSpec((input_size, )), - state_spec=alf.TensorSpec((hidden_size, ))) + state_spec=alf.TensorSpec((hidden_size, )), + name=name) self._cell = nn.GRUCell(input_size, hidden_size) def forward(self, input, state): @@ -97,7 +99,11 @@ class Residue(Network): It performs ``y = activation(x + block(x))``. 
""" - def __init__(self, block, input_tensor_spec=None, activation=torch.relu_): + def __init__(self, + block, + input_tensor_spec=None, + activation=torch.relu_, + name='Residue'): """ Args: block (Callable): @@ -108,7 +114,8 @@ def __init__(self, block, input_tensor_spec=None, activation=torch.relu_): block = wrap_as_network(block, input_tensor_spec) super().__init__( input_tensor_spec=block.input_tensor_spec, - state_spec=block.state_spec) + state_spec=block.state_spec, + name='Residue') self._block = block self._activation = activation @@ -171,7 +178,8 @@ def __init__(self, stack_size, pooling_size=1, dtype=torch.float32, - mode='skip'): + mode='skip', + name='TemporalPool'): """ Args: input_size (int|tuple[int]): shape of the input @@ -218,7 +226,7 @@ def __init__(self, state_spec = (alf.TensorSpec(shape, input_tensor_spec.dtype), pool_state_spec, alf.TensorSpec((), dtype=torch.int64)) - super().__init__(input_tensor_spec, state_spec=state_spec) + super().__init__(input_tensor_spec, state_spec=state_spec, name=name) def forward(self, x, state): if self._pooling_size == 1: diff --git a/alf/networks/preprocessors_test.py b/alf/networks/preprocessors_test.py index 484a4c27d..75bb182bc 100644 --- a/alf/networks/preprocessors_test.py +++ b/alf/networks/preprocessors_test.py @@ -86,7 +86,7 @@ def _check_with_shared_param(net1, net2, shared_subnet=None): TestInputpreprocessor.input_spec ], input_preprocessors=[input_preprocessor, torch.relu], - preprocessing_combiner=NestConcat(dim=1)) + preprocessing_combiner=NestConcat(dim=0)) # 2) test copied network has its own parameters, including # parameters from input preprocessors diff --git a/alf/test/case.py b/alf/test/case.py index 7e598eb81..e9ad555ad 100644 --- a/alf/test/case.py +++ b/alf/test/case.py @@ -43,8 +43,9 @@ def assertTensorClose(self, t1, t2, epsilon=1e-6, msg=None): 'First argument is not a Tensor') self.assertIsInstance(t2, torch.Tensor, 'Second argument is not a Tensor') - if not (torch.max(torch.abs(t1 - t2)) < epsilon): - standardMsg = '%s is not close to %s' % (t1, t2) + diff = torch.max(torch.abs(t1 - t2)) + if not (diff <= epsilon): + standardMsg = '%s is not close to %s. diff=%s' % (t1, t2, diff) self.fail(self._formatMessage(msg, standardMsg)) def assertTensorNotClose(self, t1, t2, epsilon=1e-6, msg=None): diff --git a/alf/utils/spec_utils.py b/alf/utils/spec_utils.py index 9efcf9aa6..56ce511dd 100644 --- a/alf/utils/spec_utils.py +++ b/alf/utils/spec_utils.py @@ -17,11 +17,11 @@ import torch from typing import Iterable -from alf.layers import BatchSquash import alf.nest as nest from alf.nest.utils import get_outer_rank from alf.tensor_specs import TensorSpec, BoundedTensorSpec from . import dist_utils +from alf.utils.tensor_utils import BatchSquash def spec_means_and_magnitudes(spec: BoundedTensorSpec): diff --git a/alf/utils/tensor_utils.py b/alf/utils/tensor_utils.py index e9a71fa13..86189311c 100644 --- a/alf/utils/tensor_utils.py +++ b/alf/utils/tensor_utils.py @@ -350,3 +350,48 @@ def scale_gradient(tensor, scale, clone_input=True): output = tensor output.register_hook(lambda grad: grad * scale) return output + + +class BatchSquash(object): + """Facilitates flattening and unflattening batch dims of a tensor. Copied + from `tf_agents`. + + Exposes a pair of matched flatten and unflatten methods. After flattening + only 1 batch dimension will be left. This facilitates evaluating networks + that expect inputs to have only 1 batch dimension. 
+ """ + + def __init__(self, batch_dims): + """Create two tied ops to flatten and unflatten the front dimensions. + + Args: + batch_dims (int): Number of batch dimensions the flatten/unflatten + ops should handle. + + Raises: + ValueError: if batch dims is negative. + """ + if batch_dims < 0: + raise ValueError('Batch dims must be non-negative.') + self._batch_dims = batch_dims + self._original_tensor_shape = None + + def flatten(self, tensor): + """Flattens and caches the tensor's batch_dims.""" + if self._batch_dims == 1: + return tensor + self._original_tensor_shape = tensor.shape + return torch.reshape(tensor, + (-1, ) + tuple(tensor.shape[self._batch_dims:])) + + def unflatten(self, tensor): + """Unflattens the tensor's batch_dims using the cached shape.""" + if self._batch_dims == 1: + return tensor + + if self._original_tensor_shape is None: + raise ValueError('Please call flatten before unflatten.') + + return torch.reshape( + tensor, (tuple(self._original_tensor_shape[:self._batch_dims]) + + tuple(tensor.shape[1:])))
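Usage sketch for the new parallel API added to ``alf/layers.py``. This is a minimal example assuming the patch above is applied; the layer sizes, replica count, and batch size are arbitrary illustration values.

.. code-block:: python

    import alf

    # A single FC layer mapping [B, 4] -> [B, 6].
    layer = alf.layers.FC(4, 6)

    # Three independently initialized replicas (FC.make_parallel returns a ParallelFC).
    pnet = alf.layers.make_parallel_net(layer, 3)

    # The corresponding parallel input spec has shape (3, 4).
    spec = alf.BoundedTensorSpec((4, ))
    pspec = alf.layers.make_parallel_spec(spec, 3)

    x = spec.sample([8])                      # [8, 4]
    px = alf.layers.make_parallel_input(x, 3) # replicate over dim 1 -> [8, 3, 4]
    y = pnet(px)                              # [8, 3, 6], one slice per replica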