make_parallel for network containers #985

Merged: 3 commits merged on Aug 26, 2021

Changes from 1 commit
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion alf/algorithms/ppg/disjoint_policy_value_network_test.py
@@ -61,7 +61,7 @@ def test_architecture(self, is_sharing_encoder):
EncodingNetwork,
conv_layer_params=self._conv_layer_params,
fc_layer_params=self._fc_layer_params,
preprocessing_combiner=NestConcat(dim=1)),
preprocessing_combiner=NestConcat(dim=0)),
is_sharing_encoder=is_sharing_encoder)

# Verify that the output specs are correct
83 changes: 31 additions & 52 deletions alf/layers.py
@@ -20,7 +20,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union, Callable
from typing import Union, Callable

import alf
from alf.initializers import variance_scaling_init
@@ -30,6 +30,7 @@
from alf.tensor_specs import TensorSpec
from alf.utils import common
from alf.utils.math_ops import identity
from alf.utils.spec_utils import BatchSquash


def normalize_along_batch_dims(x, mean, variance, variance_epsilon):
@@ -113,6 +114,16 @@ def forward(self, x):
return x.transpose(self._dim0, self._dim1)

def make_parallel(self, n: int):
"""Create a Transpose layer to handle parallel batch.

It is assumed that a parallel batch has shape [B, n, ...] and both the
batch dimension and replica dimension are not considered for transpose.

Args:
n (int): the number of replicas.
Returns:
a ``Transpose`` layer to handle parallel batch.
"""
return Transpose(self._dim0, self._dim1)
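
A minimal sketch of how this works on a parallel batch, in plain PyTorch rather than ALF code (the names TransposeSketch and shift_dim are hypothetical, and the +1 shift in the constructor is an assumption mirroring the shift this PR adds to NestConcat.__init__): non-negative dims skip the batch dimension once at construction, so re-passing the already-shifted dims through make_parallel skips the replica dimension as well.

import torch

def shift_dim(dim: int) -> int:
    # Non-negative dims are given relative to the dims after the batch
    # dimension, so skip one leading dim; negative dims count from the end
    # and need no adjustment.
    return dim + 1 if dim >= 0 else dim

class TransposeSketch:
    """Illustrative stand-in for a Transpose layer with make_parallel."""

    def __init__(self, dim0: int = 0, dim1: int = 1):
        self._dim0 = shift_dim(dim0)
        self._dim1 = shift_dim(dim1)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return x.transpose(self._dim0, self._dim1)

    def make_parallel(self, n: int):
        # Re-passing the already-shifted dims shifts them once more, which
        # skips the replica dimension of a [B, n, ...] parallel batch.
        return TransposeSketch(self._dim0, self._dim1)

layer = TransposeSketch(0, 1)          # swap the first two dims after batch
x = torch.randn(4, 3, 5)               # [B, d0, d1]
assert layer(x).shape == (4, 5, 3)

player = layer.make_parallel(n=2)
px = torch.randn(4, 2, 3, 5)           # [B, n, d0, d1]
assert player(px).shape == (4, 2, 5, 3)

The same reasoning applies to the Permute.make_parallel change below.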


@@ -134,52 +145,17 @@ def forward(self, x):
return x.permute(*self._dims)

def make_parallel(self, n: int):
"""Create a Permute layer to handle parallel batch.

It is assumed that a parallel batch has shape [B, n, ...] and both the
batch dimension and replica dimension are not considered for permute.

Args:
n (int): the number of replicas.
Returns:
a ``Permute`` layer to handle parallel batch.
"""
return Permute(*self._dims)

(Removed from this spot: the ``BatchSquash`` class with its ``__init__``, ``flatten``, and ``unflatten`` methods, which this PR moves verbatim from ``alf/layers.py`` to ``alf/utils/spec_utils.py``; the full class appears in that file's diff below.)


@alf.configurable
@@ -2112,13 +2088,6 @@ def make_parallel(self, n: int):
return Reshape((n, ) + self._shape)


def _tuplify2d(x):
if isinstance(x, tuple):
assert len(x) == 2
return x
return (x, x)


def _conv_transpose_2d(in_channels,
out_channels,
kernel_size,
@@ -2665,6 +2634,16 @@ def forward(self, input):
return input.sum(dim=self._dim)

def make_parallel(self, n: int):
"""Create a Sum layer to handle parallel batch.

It is assumed that a parallel batch has shape [B, n, ...] and both the
batch dimension and replica dimension are counted for ``dim``.

Args:
n (int): the number of replicas.
Returns:
a ``Sum`` layer to handle parallel batch.
"""
return Sum(self._dim)
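
A quick sanity check of the behavior described above, as a plain-PyTorch sketch rather than ALF code: for a negative dim, summation counts from the end of the shape, so inserting a replica dimension does not change which dimension is reduced.

import torch

x = torch.randn(4, 5)                     # [B, d]
px = x.unsqueeze(1).expand(4, 3, 5)       # [B, n, d]: three replicas of the same input
# Summing over the last dim gives matching results per replica.
assert torch.allclose(px.sum(dim=-1)[:, 0], x.sum(dim=-1))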


@@ -2956,7 +2935,7 @@ def make_parallel_net(module, n: int):
pnet = make_parallel_net(net, n)
# replicate input.
# pinput will have shape [batch_size, n, ...], if input has shape [batch_size, ...]
pinput = make_parallel_input(input)
pinput = make_parallel_input(input, n)
poutput = pnet(pinput)

If you already have parallel input with shape [batch_size, n, ...], you can
@@ -2983,7 +2962,7 @@ def __init__(self, module: Union[nn.Module, Callable], n: int):
A parallel network has ``n`` copies of network with the same structure but
different independently initialized parameters.

``NaiveParallelLayer`` creats ``n`` independent networks with the same
``NaiveParallelLayer`` creates ``n`` independent networks with the same
structure as ``network`` and evaluate them separately in a loop during
``forward()``.
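
A minimal sketch of the naive strategy described above, under the stated assumptions (the class NaiveParallelSketch is hypothetical and is not ALF's NaiveParallelLayer): n independently initialized copies, each applied to its own slice of the replica dimension.

import copy
import torch
import torch.nn as nn

class NaiveParallelSketch(nn.Module):
    """n independent copies of a module, evaluated one replica at a time."""

    def __init__(self, module: nn.Module, n: int):
        super().__init__()
        self._nets = nn.ModuleList(
            [copy.deepcopy(module) for _ in range(n)])
        # Re-initialize each copy so the replicas do not share parameter values.
        for net in self._nets:
            if hasattr(net, 'reset_parameters'):
                net.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape [B, n, ...]; replica i processes x[:, i, ...] and the
        # outputs are stacked back along the replica dimension.
        outs = [net(x[:, i, ...]) for i, net in enumerate(self._nets)]
        return torch.stack(outs, dim=1)

layer = NaiveParallelSketch(nn.Linear(8, 2), n=3)
assert layer(torch.randn(4, 3, 8)).shape == (4, 3, 2)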

2 changes: 1 addition & 1 deletion alf/layers_test.py
@@ -1011,7 +1011,7 @@ def test_permute(self):
layer = alf.layers.Permute(2, 1, 0)
self._test_make_parallel(layer, input_spec)

def test_onehost(self):
def test_onehot(self):
input_spec = alf.BoundedTensorSpec((10, ),
dtype=torch.int64,
minimum=0,
18 changes: 15 additions & 3 deletions alf/nest/utils.py
@@ -78,6 +78,9 @@ def __init__(self, nest_mask=None, dim=-1, name="NestConcat"):
It assumes that all the selected tensors have the same tensor spec.
Can be used as a preprocessing combiner in ``EncodingNetwork``.

Note that the batch dimension is not considered for concat. This means that
``dim=0`` refers to the first dimension after the batch dimension.

Args:
nest_mask (nest|None): nest structured mask indicating which of the
tensors in the nest to be selected or not, indicated by a
@@ -90,7 +93,7 @@ def __init__(self, nest_mask=None, dim=-1, name="NestConcat"):
super(NestConcat, self).__init__(name)
self._nest_mask = nest_mask
self._flat_mask = nest.flatten(nest_mask) if nest_mask else nest_mask
self._dim = dim
self._dim = dim if dim < 0 else dim + 1

def _combine_flat(self, tensors):
if self._flat_mask is not None:
@@ -106,8 +109,17 @@ def _combine_flat(self, tensors):
return torch.cat(tensors, dim=self._dim)

def make_parallel(self, n):
dim = self._dim if self._dim < 0 else self._dim + 1
return NestConcat(self._nest_mask, dim, "parallel_" + self._name)
"""Create a NestConcat layer to handle parallel batch.

It is assumed that a parallel batch has shape [B, n, ...] and both the
batch dimension and replica dimension are not considered for concat.

Args:
n (int): the number of replicas.
Returns:
a ``NestConcat`` layer to handle parallel batch.
"""
return NestConcat(self._nest_mask, self._dim, "parallel_" + self._name)
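
For context on the dim=1 to dim=0 test changes elsewhere in this PR: the old code passed dim straight to torch.cat, so it counted the batch dimension, while the new code interprets dim relative to the dimensions after the batch dimension and shifts it internally. A plain-PyTorch sketch of the shift, illustrative only:

import torch

x = torch.randn(4, 3)
y = torch.randn(4, 5)

dim = 0                                    # logical dim, batch dim excluded
internal_dim = dim + 1 if dim >= 0 else dim
out = torch.cat([x, y], dim=internal_dim)  # concat along the first dim after batch
assert out.shape == (4, 8)

# For a parallel batch [B, n, ...], applying the same shift once more skips
# the replica dimension as well, which is what make_parallel above relies on.
px = torch.randn(4, 2, 3)
py = torch.randn(4, 2, 5)
parallel_dim = internal_dim + 1 if internal_dim >= 0 else internal_dim
pout = torch.cat([px, py], dim=parallel_dim)
assert pout.shape == (4, 2, 8)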


@alf.configurable
2 changes: 1 addition & 1 deletion alf/networks/actor_distribution_networks_test.py
@@ -42,7 +42,7 @@ def setUp(self):
self._conv_layer_params = ((8, 3, 1), (16, 3, 2, 1))
self._fc_layer_params = (100, )
self._input_preprocessors = [torch.tanh, None]
self._preprocessing_combiner = NestConcat(dim=1)
self._preprocessing_combiner = NestConcat(dim=0)

def _init(self, lstm_hidden_size):
if lstm_hidden_size is not None:
2 changes: 1 addition & 1 deletion alf/networks/preprocessors_test.py
@@ -86,7 +86,7 @@ def _check_with_shared_param(net1, net2, shared_subnet=None):
TestInputpreprocessor.input_spec
],
input_preprocessors=[input_preprocessor, torch.relu],
preprocessing_combiner=NestConcat(dim=1))
preprocessing_combiner=NestConcat(dim=0))

# 2) test copied network has its own parameters, including
# parameters from input preprocessors
46 changes: 45 additions & 1 deletion alf/utils/spec_utils.py
@@ -17,13 +17,57 @@
import torch
from typing import Iterable

from alf.layers import BatchSquash
import alf.nest as nest
from alf.nest.utils import get_outer_rank
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from . import dist_utils


class BatchSquash(object):
"""Facilitates flattening and unflattening batch dims of a tensor. Copied
from `tf_agents`.

Exposes a pair of matched flatten and unflatten methods. After flattening
only 1 batch dimension will be left. This facilitates evaluating networks
that expect inputs to have only 1 batch dimension.
"""

def __init__(self, batch_dims):
"""Create two tied ops to flatten and unflatten the front dimensions.

Args:
batch_dims (int): Number of batch dimensions the flatten/unflatten
ops should handle.

Raises:
ValueError: if batch dims is negative.
"""
if batch_dims < 0:
raise ValueError('Batch dims must be non-negative.')
self._batch_dims = batch_dims
self._original_tensor_shape = None

def flatten(self, tensor):
"""Flattens and caches the tensor's batch_dims."""
if self._batch_dims == 1:
return tensor
self._original_tensor_shape = tensor.shape
return torch.reshape(tensor,
(-1, ) + tuple(tensor.shape[self._batch_dims:]))

def unflatten(self, tensor):
"""Unflattens the tensor's batch_dims using the cached shape."""
if self._batch_dims == 1:
return tensor

if self._original_tensor_shape is None:
raise ValueError('Please call flatten before unflatten.')

return torch.reshape(
tensor, (tuple(self._original_tensor_shape[:self._batch_dims]) +
tuple(tensor.shape[1:])))
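
A small usage sketch of the class above, assuming the post-PR import location alf.utils.spec_utils (consistent with the new import added to alf/layers.py); the arithmetic on the flattened tensor is just a stand-in for evaluating a network that expects a single batch dimension.

import torch
from alf.utils.spec_utils import BatchSquash

bs = BatchSquash(batch_dims=2)
x = torch.randn(4, 3, 16)      # e.g. a parallel batch [B, n, d] with two batch dims
flat = bs.flatten(x)           # [12, 16]: one batch dim for the wrapped computation
out = flat * 2.0               # stand-in for a network that expects [batch, d]
restored = bs.unflatten(out)   # back to [4, 3, 16]
assert restored.shape == (4, 3, 16)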


def spec_means_and_magnitudes(spec: BoundedTensorSpec):
"""Get the center and magnitude of the ranges for the input spec.
