
Commit

Address comments
emailweixu committed Aug 25, 2021
1 parent d9c560b commit 438ca21
Showing 7 changed files with 95 additions and 60 deletions.
2 changes: 1 addition & 1 deletion alf/algorithms/ppg/disjoint_policy_value_network_test.py
@@ -61,7 +61,7 @@ def test_architecture(self, is_sharing_encoder):
             EncodingNetwork,
             conv_layer_params=self._conv_layer_params,
             fc_layer_params=self._fc_layer_params,
-            preprocessing_combiner=NestConcat(dim=1)),
+            preprocessing_combiner=NestConcat(dim=0)),
         is_sharing_encoder=is_sharing_encoder)

# Verify that the output specs are correct
83 changes: 31 additions & 52 deletions alf/layers.py
@@ -20,7 +20,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Optional, Union, Callable
+from typing import Union, Callable

import alf
from alf.initializers import variance_scaling_init
@@ -30,6 +30,7 @@
 from alf.tensor_specs import TensorSpec
 from alf.utils import common
 from alf.utils.math_ops import identity
+from alf.utils.spec_utils import BatchSquash


def normalize_along_batch_dims(x, mean, variance, variance_epsilon):
@@ -113,6 +114,16 @@ def forward(self, x):
         return x.transpose(self._dim0, self._dim1)

     def make_parallel(self, n: int):
+        """Create a Transpose layer to handle parallel batch.
+
+        It is assumed that a parallel batch has shape [B, n, ...] and both the
+        batch dimension and replica dimension are not considered for transpose.
+
+        Args:
+            n (int): the number of replicas.
+        Returns:
+            a ``Transpose`` layer to handle parallel batch.
+        """
         return Transpose(self._dim0, self._dim1)


@@ -134,52 +145,17 @@ def forward(self, x):
         return x.permute(*self._dims)

     def make_parallel(self, n: int):
+        """Create a Permute layer to handle parallel batch.
+
+        It is assumed that a parallel batch has shape [B, n, ...] and both the
+        batch dimension and replica dimension are not considered for permute.
+
+        Args:
+            n (int): the number of replicas.
+        Returns:
+            a ``Permute`` layer to handle parallel batch.
+        """
         return Permute(*self._dims)
-
-
-class BatchSquash(object):
-    """Facilitates flattening and unflattening batch dims of a tensor. Copied
-    from `tf_agents`.
-
-    Exposes a pair of matched flatten and unflatten methods. After flattening
-    only 1 batch dimension will be left. This facilitates evaluating networks
-    that expect inputs to have only 1 batch dimension.
-    """
-
-    def __init__(self, batch_dims):
-        """Create two tied ops to flatten and unflatten the front dimensions.
-
-        Args:
-            batch_dims (int): Number of batch dimensions the flatten/unflatten
-                ops should handle.
-
-        Raises:
-            ValueError: if batch dims is negative.
-        """
-        if batch_dims < 0:
-            raise ValueError('Batch dims must be non-negative.')
-        self._batch_dims = batch_dims
-        self._original_tensor_shape = None
-
-    def flatten(self, tensor):
-        """Flattens and caches the tensor's batch_dims."""
-        if self._batch_dims == 1:
-            return tensor
-        self._original_tensor_shape = tensor.shape
-        return torch.reshape(tensor,
-                             (-1, ) + tuple(tensor.shape[self._batch_dims:]))
-
-    def unflatten(self, tensor):
-        """Unflattens the tensor's batch_dims using the cached shape."""
-        if self._batch_dims == 1:
-            return tensor
-
-        if self._original_tensor_shape is None:
-            raise ValueError('Please call flatten before unflatten.')
-
-        return torch.reshape(
-            tensor, (tuple(self._original_tensor_shape[:self._batch_dims]) +
-                     tuple(tensor.shape[1:])))


 @alf.configurable
@@ -2112,13 +2088,6 @@ def make_parallel(self, n: int):
         return Reshape((n, ) + self._shape)


-def _tuplify2d(x):
-    if isinstance(x, tuple):
-        assert len(x) == 2
-        return x
-    return (x, x)
-
-
 def _conv_transpose_2d(in_channels,
                        out_channels,
                        kernel_size,
@@ -2665,6 +2634,16 @@ def forward(self, input):
         return input.sum(dim=self._dim)

     def make_parallel(self, n: int):
+        """Create a Sum layer to handle parallel batch.
+
+        It is assumed that a parallel batch has shape [B, n, ...] and both the
+        batch dimension and replica dimension are counted for ``dim``.
+
+        Args:
+            n (int): the number of replicas.
+        Returns:
+            a ``Sum`` layer to handle parallel batch.
+        """
         return Sum(self._dim)


@@ -2956,7 +2935,7 @@ def make_parallel_net(module, n: int):
     pnet = make_parallel_net(net, n)
     # replicate input.
     # pinput will have shape [batch_size, n, ...], if input has shape [batch_size, ...]
-    pinput = make_parallel_input(input)
+    pinput = make_parallel_input(input, n)
     poutput = pnet(pinput)
If you already have parallel input with shape [batch_size, n, ...], you can
@@ -2983,7 +2962,7 @@ def __init__(self, module: Union[nn.Module, Callable], n: int):
     A parallel network has ``n`` copies of network with the same structure but
     different indepently initialized parameters.

-    ``NaiveParallelLayer`` creats ``n`` independent networks with the same
+    ``NaiveParallelLayer`` creates ``n`` independent networks with the same
     structure as ``network`` and evaluate them separately in a loop during
     ``forward()``.
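The ``make_parallel`` docstrings above all rely on the same parallel-batch layout: a parallel batch has shape [B, n, ...], and replica i reads the slice [:, i, ...]. A minimal sketch of that convention, using a hypothetical ``naive_parallel_forward`` helper rather than ALF's actual ``NaiveParallelLayer``:

import torch
import torch.nn as nn

def naive_parallel_forward(replicas, pinput):
    # pinput has shape [B, n, ...]; replica i sees pinput[:, i, ...], and the
    # per-replica outputs are stacked back into a [B, n, ...] parallel batch.
    outputs = [m(pinput[:, i, ...]) for i, m in enumerate(replicas)]
    return torch.stack(outputs, dim=1)

# Two replicas with the same structure but independently initialized parameters.
replicas = [nn.Linear(8, 3) for _ in range(2)]
pinput = torch.randn(4, 2, 8)      # [B=4, n=2, input_size=8]
poutput = naive_parallel_forward(replicas, pinput)
assert poutput.shape == (4, 2, 3)  # [B, n, output_size]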
2 changes: 1 addition & 1 deletion alf/layers_test.py
@@ -1011,7 +1011,7 @@ def test_permute(self):
         layer = alf.layers.Permute(2, 1, 0)
         self._test_make_parallel(layer, input_spec)

-    def test_onehost(self):
+    def test_onehot(self):
         input_spec = alf.BoundedTensorSpec((10, ),
                                            dtype=torch.int64,
                                            minimum=0,
18 changes: 15 additions & 3 deletions alf/nest/utils.py
@@ -78,6 +78,9 @@ def __init__(self, nest_mask=None, dim=-1, name="NestConcat"):
     It assumes that all the selected tensors have the same tensor spec.
     Can be used as a preprocessing combiner in ``EncodingNetwork``.

+    Note that batch dimension is not considered for concat. This means that
+    dim=0 means the first dimension after batch dimension.
+
     Args:
         nest_mask (nest|None): nest structured mask indicating which of the
             tensors in the nest to be selected or not, indicated by a
@@ -90,7 +93,7 @@ def __init__(self, nest_mask=None, dim=-1, name="NestConcat"):
         super(NestConcat, self).__init__(name)
         self._nest_mask = nest_mask
         self._flat_mask = nest.flatten(nest_mask) if nest_mask else nest_mask
-        self._dim = dim
+        self._dim = dim if dim < 0 else dim + 1

     def _combine_flat(self, tensors):
         if self._flat_mask is not None:
@@ -106,8 +109,17 @@ def _combine_flat(self, tensors):
         return torch.cat(tensors, dim=self._dim)

     def make_parallel(self, n):
-        dim = self._dim if self._dim < 0 else self._dim + 1
-        return NestConcat(self._nest_mask, dim, "parallel_" + self._name)
+        """Create a NestConcat layer to handle parallel batch.
+
+        It is assumed that a parallel batch has shape [B, n, ...] and both the
+        batch dimension and replica dimension are not considered for concat.
+
+        Args:
+            n (int): the number of replicas.
+        Returns:
+            a ``Transpose`` layer to handle parallel batch.
+        """
+        return NestConcat(self._nest_mask, self._dim, "parallel_" + self._name)


 @alf.configurable

Inline review comment from Haichao-Zhang (Contributor), Aug 25, 2021, on the line "a ``Transpose`` layer to handle parallel batch.":

    Transpose -> NestConcat
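To make the new ``dim`` convention concrete: ``NestConcat(dim=0)`` concatenates along the first dimension after the batch dimension, and for a parallel batch the replica dimension is skipped as well. A sketch with plain ``torch.cat`` (tensor shapes below are made up for illustration):

import torch

x = torch.randn(4, 3, 5)           # [B=4, 3, 5]
y = torch.randn(4, 3, 5)
# dim=0 under the new convention maps to torch.cat(..., dim=1) on raw tensors.
assert torch.cat([x, y], dim=1).shape == (4, 6, 5)

px = torch.randn(4, 2, 3, 5)       # parallel batch [B=4, n=2, 3, 5]
py = torch.randn(4, 2, 3, 5)
# The layer returned by make_parallel passes the already shifted self._dim back
# through __init__, so it effectively concatenates along dim=2, skipping both
# the batch and the replica dimensions.
assert torch.cat([px, py], dim=2).shape == (4, 2, 6, 5)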
2 changes: 1 addition & 1 deletion alf/networks/actor_distribution_networks_test.py
@@ -42,7 +42,7 @@ def setUp(self):
         self._conv_layer_params = ((8, 3, 1), (16, 3, 2, 1))
         self._fc_layer_params = (100, )
         self._input_preprocessors = [torch.tanh, None]
-        self._preprocessing_combiner = NestConcat(dim=1)
+        self._preprocessing_combiner = NestConcat(dim=0)

     def _init(self, lstm_hidden_size):
         if lstm_hidden_size is not None:
2 changes: 1 addition & 1 deletion alf/networks/preprocessors_test.py
@@ -86,7 +86,7 @@ def _check_with_shared_param(net1, net2, shared_subnet=None):
                 TestInputpreprocessor.input_spec
             ],
             input_preprocessors=[input_preprocessor, torch.relu],
-            preprocessing_combiner=NestConcat(dim=1))
+            preprocessing_combiner=NestConcat(dim=0))

         # 2) test copied network has its own parameters, including
         # parameters from input preprocessors
46 changes: 45 additions & 1 deletion alf/utils/spec_utils.py
@@ -17,13 +17,57 @@
 import torch
 from typing import Iterable

-from alf.layers import BatchSquash
 import alf.nest as nest
 from alf.nest.utils import get_outer_rank
 from alf.tensor_specs import TensorSpec, BoundedTensorSpec
 from . import dist_utils


+class BatchSquash(object):
+    """Facilitates flattening and unflattening batch dims of a tensor. Copied
+    from `tf_agents`.
+
+    Exposes a pair of matched flatten and unflatten methods. After flattening
+    only 1 batch dimension will be left. This facilitates evaluating networks
+    that expect inputs to have only 1 batch dimension.
+    """
+
+    def __init__(self, batch_dims):
+        """Create two tied ops to flatten and unflatten the front dimensions.
+
+        Args:
+            batch_dims (int): Number of batch dimensions the flatten/unflatten
+                ops should handle.
+
+        Raises:
+            ValueError: if batch dims is negative.
+        """
+        if batch_dims < 0:
+            raise ValueError('Batch dims must be non-negative.')
+        self._batch_dims = batch_dims
+        self._original_tensor_shape = None
+
+    def flatten(self, tensor):
+        """Flattens and caches the tensor's batch_dims."""
+        if self._batch_dims == 1:
+            return tensor
+        self._original_tensor_shape = tensor.shape
+        return torch.reshape(tensor,
+                             (-1, ) + tuple(tensor.shape[self._batch_dims:]))
+
+    def unflatten(self, tensor):
+        """Unflattens the tensor's batch_dims using the cached shape."""
+        if self._batch_dims == 1:
+            return tensor
+
+        if self._original_tensor_shape is None:
+            raise ValueError('Please call flatten before unflatten.')
+
+        return torch.reshape(
+            tensor, (tuple(self._original_tensor_shape[:self._batch_dims]) +
+                     tuple(tensor.shape[1:])))
+
+
 def spec_means_and_magnitudes(spec: BoundedTensorSpec):
     """Get the center and magnitude of the ranges for the input spec.
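A short usage sketch of ``BatchSquash`` at its new location (assuming ALF is installed so the import below resolves):

import torch
from alf.utils.spec_utils import BatchSquash

bs = BatchSquash(batch_dims=2)
x = torch.randn(4, 3, 5)        # two batch dims [B=4, n=3], feature size 5
flat = bs.flatten(x)            # batch dims merged -> [12, 5]
assert flat.shape == (12, 5)

y = flat * 2.0                  # stand-in for a network that expects one batch dim
restored = bs.unflatten(y)      # original batch dims restored -> [4, 3, 5]
assert restored.shape == (4, 3, 5)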
