Refactor critic networks using the new encoding network #993

Merged: 2 commits, Aug 30, 2021
62 changes: 37 additions & 25 deletions alf/layers.py
@@ -1352,8 +1352,13 @@ def __init__(self,
use_bias = not use_bn
self._activation = activation
self._n = n
self._use_bias = use_bias
self._in_channels = in_channels
self._out_channels = out_channels
self._kernel_initializer = kernel_initializer
self._kernel_init_gain = kernel_init_gain
self._bias_init_value = bias_init_value

self._kernel_size = common.tuplify2d(kernel_size)
self._conv2d = nn.Conv2d(
in_channels * n,
@@ -1364,34 +1369,30 @@ def __init__(self,
padding=padding,
bias=use_bias)

for i in range(n):
if kernel_initializer is None:
if use_bn:
self._bn = nn.BatchNorm2d(n * out_channels)
else:
self._bn = None
self.reset_parameters()

def reset_parameters(self):
for i in range(self._n):
if self._kernel_initializer is None:
variance_scaling_init(
self._conv2d.weight.data[i * out_channels:(i + 1) *
out_channels],
gain=kernel_init_gain,
self._conv2d.weight.data[i * self._out_channels:(i + 1) *
self._out_channels],
gain=self._kernel_init_gain,
nonlinearity=self._activation)
else:
kernel_initializer(
self._conv2d.weight.data[i * out_channels:(i + 1) *
out_channels])

# [n*C', C, kernel_size, kernel_size]->[n, C', C, kernel_size, kernel_size]
self._weight = self._conv2d.weight.view(
self._n, self._out_channels, self._in_channels,
self._kernel_size[0], self._kernel_size[1])
self._kernel_initializer(
self._conv2d.weight.data[i * self._out_channels:(i + 1) *
self._out_channels])

if use_bias:
nn.init.constant_(self._conv2d.bias.data, bias_init_value)
# [n*C']->[n, C']
self._bias = self._conv2d.bias.view(self._n, self._out_channels)
else:
self._bias = None
if self._use_bias:
nn.init.constant_(self._conv2d.bias.data, self._bias_init_value)

if use_bn:
self._bn = nn.BatchNorm2d(n * out_channels)
else:
self._bn = None
if self._bn:
self._bn.reset_parameters()

def forward(self, img):
"""Forward
@@ -1454,11 +1455,22 @@ def forward(self, img):

@property
def weight(self):
return self._weight
# The reason that weight cannot be pre-computed at __init__ is that deepcopy
# would fail. deepcopy is needed to implement copying for the container networks.
# [n*C', C, kernel_size, kernel_size]->[n, C', C, kernel_size, kernel_size]
return self._conv2d.weight.view(
self._n, self._out_channels, self._in_channels,
self._kernel_size[0], self._kernel_size[1])

@property
def bias(self):
return self._bias
if self._use_bias:
# The reason that bias cannot be pre-computed at __init__ is that deepcopy
# would fail. deepcopy is needed to implement copying for the container networks.
# [n*C']->[n, C']
return self._conv2d.bias.view(self._n, self._out_channels)
else:
return None


@alf.configurable
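The new ``weight`` and ``bias`` properties compute the reshaped view on demand instead of caching it in ``__init__``, because deepcopying a module that stores a view of one of its parameters fails. A minimal sketch of the failure mode in plain PyTorch (the class names here are illustrative, not from ALF):

```python
import copy

import torch.nn as nn


class CachedView(nn.Module):
    """Caches a view of the conv weight at construction time."""

    def __init__(self):
        super().__init__()
        self._conv2d = nn.Conv2d(6, 8, 3)
        # A view of a Parameter is a non-leaf tensor, and deepcopy only
        # supports leaf tensors, so deepcopy(CachedView()) raises.
        self._weight = self._conv2d.weight.view(2, 4, 6, 3, 3)


class LazyView(nn.Module):
    """Computes the view on demand; the module stays deepcopy-able."""

    def __init__(self):
        super().__init__()
        self._conv2d = nn.Conv2d(6, 8, 3)

    @property
    def weight(self):
        # [n*C', C, k, k] -> [n, C', C, k, k], computed on each access
        return self._conv2d.weight.view(2, 4, 6, 3, 3)


copy.deepcopy(LazyView())  # works
try:
    copy.deepcopy(CachedView())
except RuntimeError as e:
    print(e)  # only leaf Tensors support the deepcopy protocol
```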
160 changes: 31 additions & 129 deletions alf/networks/critic_networks.py
@@ -17,15 +17,13 @@
import math

import torch
import torch.nn as nn

import alf
import alf.utils.math_ops as math_ops
import alf.nest as nest
from alf.initializers import variance_scaling_init
from alf.tensor_specs import TensorSpec

from .network import Network
from .encoding_networks import EncodingNetwork, LSTMEncodingNetwork, ParallelEncodingNetwork


@@ -54,12 +52,15 @@ def _check_individual(spec, proc):


@alf.configurable
class CriticNetwork(Network):
class CriticNetwork(EncodingNetwork):
"""Creates an instance of ``CriticNetwork`` for estimating action-value of
continuous or discrete actions. The action-value is defined as the expected
return starting from the given input observation and taking the given action.
This module takes observation as input and action as input and outputs an
action-value tensor with the shape of ``[batch_size]``.

The network take a tuple of (observation, action) as input to computes the
action-value given an observation.
"""

def __init__(self,
@@ -119,8 +120,6 @@ def __init__(self,
situation.
name (str):
"""
super().__init__(input_tensor_spec, name=name)

if kernel_initializer is None:
kernel_initializer = functools.partial(
variance_scaling_init,
@@ -130,7 +129,7 @@ def __init__(self,

observation_spec, action_spec = input_tensor_spec

self._obs_encoder = EncodingNetwork(
obs_encoder = EncodingNetwork(
observation_spec,
input_preprocessors=observation_input_processors,
preprocessing_combiner=observation_preprocessing_combiner,
@@ -139,124 +138,53 @@ def __init__(self,
activation=activation,
kernel_initializer=kernel_initializer,
use_fc_bn=use_fc_bn,
name=self.name + ".obs_encoder")
name=name + ".obs_encoder")

_check_action_specs_for_critic_networks(action_spec,
action_input_processors,
action_preprocessing_combiner)
self._action_encoder = EncodingNetwork(
action_encoder = EncodingNetwork(
action_spec,
input_preprocessors=action_input_processors,
preprocessing_combiner=action_preprocessing_combiner,
fc_layer_params=action_fc_layer_params,
activation=activation,
kernel_initializer=kernel_initializer,
use_fc_bn=use_fc_bn,
name=self.name + ".action_encoder")
name=name + ".action_encoder")

last_kernel_initializer = functools.partial(
torch.nn.init.uniform_, a=-0.003, b=0.003)

self._joint_encoder = EncodingNetwork(
TensorSpec((self._obs_encoder.output_spec.shape[0] +
self._action_encoder.output_spec.shape[0], )),
super().__init__(
input_tensor_spec=input_tensor_spec,
output_tensor_spec=output_tensor_spec,
input_preprocessors=(obs_encoder, action_encoder),
preprocessing_combiner=alf.layers.NestConcat(dim=-1),
fc_layer_params=joint_fc_layer_params,
activation=activation,
kernel_initializer=kernel_initializer,
last_layer_size=output_tensor_spec.numel,
last_activation=math_ops.identity,
use_fc_bn=use_fc_bn,
last_kernel_initializer=last_kernel_initializer,
name=self.name + ".joint_encoder")

name=name)
self._use_naive_parallel_network = use_naive_parallel_network
self._output_spec = output_tensor_spec

def forward(self, inputs, state=()):
"""Computes action-value given an observation.

Args:
inputs: A tuple of Tensors consistent with ``input_tensor_spec``.
state: empty, for API consistency with ``CriticRNNNetwork``.

Returns:
tuple:
- action_value (torch.Tensor): a tensor of the size ``[batch_size]``
- state: empty
"""
observations, actions = inputs

encoded_obs, _ = self._obs_encoder(observations)
encoded_action, _ = self._action_encoder(actions)
joint = torch.cat([encoded_obs, encoded_action], -1)
action_value, _ = self._joint_encoder(joint)
action_value = action_value.reshape(action_value.shape[0],
*self._output_spec.shape)
return action_value, state

def make_parallel(self, n):
"""Create a ``ParallelCriticNetwork`` using ``n`` replicas of ``self``.
"""Create a parallel critic network using ``n`` replicas of ``self``.
The initialized network parameters will be different.
If ``use_naive_parallel_network`` is True, use ``NaiveParallelNetwork``
to create the parallel network.
"""
if self._use_naive_parallel_network:
return alf.networks.NaiveParallelNetwork(self, n)
else:
return ParallelCriticNetwork(self, n, "parallel_" + self._name)


class ParallelCriticNetwork(Network):
"""Perform ``n`` critic computations in parallel."""

def __init__(self,
critic_network: CriticNetwork,
n: int,
name="ParallelCriticNetwork"):
"""
It creates a parallelized version of ``critic_network``.

Args:
critic_network (CriticNetwork): non-parallelized critic network
n (int): make ``n`` replicas from ``critic_network`` with different
initialization.
name (str):
"""
super().__init__(
input_tensor_spec=critic_network.input_tensor_spec, name=name)
self._obs_encoder = critic_network._obs_encoder.make_parallel(n, True)
self._action_encoder = critic_network._action_encoder.make_parallel(
n, True)
self._joint_encoder = critic_network._joint_encoder.make_parallel(n)
self._output_spec = TensorSpec((n, ) +
critic_network.output_spec.shape)

def forward(self, inputs, state=()):
"""Computes action-value given an observation.

Args:
inputs (tuple): A tuple of Tensors consistent with ``input_tensor_spec``.
state (tuple): empty, for API consistency with ``CriticRNNNetwork``.

Returns:
tuple:
- action_value (torch.Tensor): a tensor of shape :math:`[B,n]`, where
:math:`B` is the batch size.
- state: empty
"""
observations, actions = inputs

encoded_obs, _ = self._obs_encoder(observations)
encoded_action, _ = self._action_encoder(actions)
joint = torch.cat([encoded_obs, encoded_action], -1)
action_value, _ = self._joint_encoder(joint)
action_value = action_value.reshape(action_value.shape[0],
*self._output_spec.shape)
return action_value, state
return super().make_parallel(n, True)
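After this change ``CriticNetwork`` is an ``EncodingNetwork`` whose input preprocessors are the observation and action encoders, so callers construct and parallelize it exactly as before. A hedged usage sketch (layer sizes are illustrative; ``observation_fc_layer_params`` and the ``TensorSpec.randn`` helper are assumed from ALF's conventions rather than shown in this diff):

```python
import alf
from alf.networks.critic_networks import CriticNetwork
from alf.tensor_specs import TensorSpec

observation_spec = TensorSpec((10, ))
action_spec = TensorSpec((3, ))

critic = CriticNetwork(
    input_tensor_spec=(observation_spec, action_spec),
    observation_fc_layer_params=(64, ),
    action_fc_layer_params=(64, ),
    joint_fc_layer_params=(64, 64))

obs = observation_spec.randn(outer_dims=(32, ))
action = action_spec.randn(outer_dims=(32, ))
q_value, _ = critic((obs, action))  # q_value: [32]

# Three replicas with independent initialization; q_values: [32, 3].
parallel_critic = critic.make_parallel(3)
q_values, _ = parallel_critic((obs, action))
```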


@alf.configurable
class CriticRNNNetwork(Network):
class CriticRNNNetwork(LSTMEncodingNetwork):
"""Creates an instance of ``CriticRNNNetwork`` for estimating action-value
of continuous or discrete actions. The action-value is defined as the
expected return starting from the given inputs (observation and state) and
@@ -318,8 +246,6 @@ def __init__(self,
with uniform distribution will be used.
name (str):
"""
super().__init__(input_tensor_spec, name=name)

if kernel_initializer is None:
kernel_initializer = functools.partial(
variance_scaling_init,
Expand All @@ -329,7 +255,7 @@ def __init__(self,

observation_spec, action_spec = input_tensor_spec

self._obs_encoder = EncodingNetwork(
obs_encoder = EncodingNetwork(
observation_spec,
input_preprocessors=observation_input_processors,
preprocessing_combiner=observation_preprocessing_combiner,
@@ -341,26 +267,23 @@ def __init__(self,
_check_action_specs_for_critic_networks(action_spec,
action_input_processors,
action_preprocessing_combiner)
self._action_encoder = EncodingNetwork(
action_encoder = EncodingNetwork(
action_spec,
input_preprocessors=action_input_processors,
preprocessing_combiner=action_preprocessing_combiner,
fc_layer_params=action_fc_layer_params,
activation=activation,
kernel_initializer=kernel_initializer)

self._joint_encoder = EncodingNetwork(
TensorSpec((self._obs_encoder.output_spec.shape[0] +
self._action_encoder.output_spec.shape[0], )),
fc_layer_params=joint_fc_layer_params,
activation=activation,
kernel_initializer=kernel_initializer)

last_kernel_initializer = functools.partial(
torch.nn.init.uniform_, a=-0.003, b=0.003)

self._lstm_encoding_net = LSTMEncodingNetwork(
input_tensor_spec=self._joint_encoder.output_spec,
super().__init__(
input_tensor_spec=input_tensor_spec,
output_tensor_spec=output_tensor_spec,
input_preprocessors=(obs_encoder, action_encoder),
preprocessing_combiner=alf.layers.NestConcat(dim=-1),
pre_fc_layer_params=joint_fc_layer_params,
hidden_size=lstm_hidden_size,
post_fc_layer_params=critic_fc_layer_params,
activation=activation,
@@ -369,31 +292,10 @@ def __init__(self,
last_activation=math_ops.identity,
last_kernel_initializer=last_kernel_initializer)

self._output_spec = output_tensor_spec

def forward(self, inputs, state):
"""Computes action-value given an observation.

Args:
inputs: A tuple of Tensors consistent with ``input_tensor_spec``
state (nest[tuple]): a nest structure of state tuples ``(h, c)``

Returns:
tuple:
- action_value (torch.Tensor): a tensor of the size ``[batch_size]``
- new_state (nest[tuple]): the updated states
def make_parallel(self, n):
"""Create a parallel critic RNN network using ``n`` replicas of ``self``.
The initialized network parameters will be different.
If ``use_naive_parallel_network`` is True, use ``NaiveParallelNetwork``
to create the parallel network.
"""
observations, actions = inputs

encoded_obs, _ = self._obs_encoder(observations)
encoded_action, _ = self._action_encoder(actions)
joint = torch.cat([encoded_obs, encoded_action], -1)
encoded_joint, _ = self._joint_encoder(joint)
action_value, state = self._lstm_encoding_net(encoded_joint, state)
action_value = action_value.reshape(action_value.shape[0],
*self._output_spec.shape)
return action_value, state

@property
def state_spec(self):
return self._lstm_encoding_net.state_spec
return super().make_parallel(n, True)
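The recurrent critic follows the same composition: the ``(obs_encoder, action_encoder)`` pair preprocesses the input tuple, ``NestConcat`` merges the two encodings, and the pre-FC/LSTM/post-FC stack of ``LSTMEncodingNetwork`` produces the value. A hedged sketch of stepping the refactored network (sizes are illustrative; ``TensorSpec.zeros``/``randn`` with ``outer_dims`` are assumed from ALF's conventions):

```python
import alf
from alf.networks.critic_networks import CriticRNNNetwork
from alf.tensor_specs import TensorSpec

observation_spec = TensorSpec((10, ))
action_spec = TensorSpec((3, ))

critic = CriticRNNNetwork(
    input_tensor_spec=(observation_spec, action_spec),
    joint_fc_layer_params=(64, ),
    lstm_hidden_size=(32, ),
    critic_fc_layer_params=(64, ))

obs = observation_spec.randn(outer_dims=(8, ))
action = action_spec.randn(outer_dims=(8, ))

# Zero initial LSTM state built from the network's state_spec.
state = alf.nest.map_structure(
    lambda spec: spec.zeros(outer_dims=(8, )), critic.state_spec)
q_value, state = critic((obs, action), state)  # q_value: [8]
```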