Batch ensemble ddpg #1633

Open
wants to merge 5 commits into base: pytorch
127 changes: 117 additions & 10 deletions alf/algorithms/ddpg_algorithm.py
@@ -42,11 +42,18 @@
DdpgActorState = namedtuple(
"DdpgActorState", ['actor', 'critics'], default_value=())
DdpgState = namedtuple(
"DdpgState", ['actor', 'critics', 'noise'], default_value=())
"DdpgState", ['actor', 'critics', 'noise', 'ensemble_ids'],
default_value=())
DdpgInfo = namedtuple(
"DdpgInfo", [
"reward", "step_type", "discount", "action", "action_distribution",
"actor_loss", "critic", "discounted_return"
"reward",
"step_type",
"discount",
"action",
"action_distribution",
"actor_loss",
"critic",
"discounted_return",
],
default_value=())
DdpgLossInfo = namedtuple('DdpgLossInfo', ('actor', 'critic'))
@@ -74,13 +81,18 @@ def __init__(self,
config: TrainerConfig = None,
ou_stddev=0.2,
ou_damping=0.15,
noise_clipping=None,
critic_loss_ctor=None,
num_critic_replicas=1,
target_update_tau=0.05,
target_update_period=1,
actor_update_period=1,
rollout_random_action=0.,
dqda_clipping=None,
action_l2=0,
use_batch_ensemble=False,
ensemble_size=10,
input_with_ensemble_ids=False,
actor_optimizer=None,
critic_optimizer=None,
checkpoint=None,
@@ -124,12 +136,16 @@ def __init__(self,
(OU) noise added in the default collect policy.
ou_damping (float): Damping factor for the OU noise added in the
default collect policy.
noise_clipping (float): when computing the action noise, clips the
noise element-wise to ``[-noise_clipping, noise_clipping]``.
Does not perform clipping if ``noise_clipping`` is None or 0.
critic_loss_ctor (None|OneStepTDLoss|MultiStepLoss): a critic loss
constructor. If ``None``, a default ``OneStepTDLoss`` will be used.
target_update_tau (float): Factor for soft update of the target
networks.
target_update_period (int): Period for soft update of the target
networks.
actor_update_period (int): Period for updating the ``actor_network``:
the actor is updated once every ``actor_update_period`` train steps,
while the critics are updated on every step.
rollout_random_action (float): the probability of taking a uniform
random action during a ``rollout_step()``. 0 means always directly
taking actions added with OU noises and 1 means always sample
@@ -139,6 +155,17 @@
gradient dqda element-wise between ``[-dqda_clipping, dqda_clipping]``.
Does not perform clipping if ``dqda_clipping == 0``.
action_l2 (float): weight of squared action l2-norm on actor loss.
use_batch_ensemble (bool): whether to use BatchEnsemble FC and Conv2D
layers. If True, both BatchEnsemble layers will always be created
with ``output_ensemble_ids=True``, and as a result, the output of
the network is a tuple of ``(output, ensemble_ids)``.

Inline review comment (Contributor), on the ``use_batch_ensemble`` argument:
Ideally, we should make these batch-ensemble-related parameters transparent to ddpg_algorithm; in the ideal case, ddpg_algorithm should not use any batch-ensemble-related parameters.

Reply (Contributor, Author):
That's a good point. Currently ddpg needs use_batch_ensemble to do some post-processing when forwarding the critic networks during training. Let me think over whether there is an alternative way to work around this.

ensemble_size (int): the ensemble size; only effective if
``use_batch_ensemble`` is True.
input_with_ensemble_ids (bool): whether to handle inputs that carry
ensemble_ids. If True, the input to the network should be a tuple of
two tensors: the first is the input data tensor and the second is the
ensemble_ids. This option is only effective if ``use_batch_ensemble``
is True.
actor_optimizer (torch.optim.optimizer): The optimizer for actor.
critic_optimizer (torch.optim.optimizer): The optimizer for critic.
checkpoint (None|str): a string in the format of "prefix@path",
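To make the ``use_batch_ensemble`` / ``input_with_ensemble_ids`` conventions documented above concrete, here is a minimal sketch of the tuple formats involved. The shapes and values are illustrative only, and the commented network call is pseudo-usage rather than a runnable ALF call.

```python
import torch

batch_size, obs_dim, ensemble_size = 4, 8, 3

observation = torch.randn(batch_size, obs_dim)
# One int64 ensemble id per sample, in [0, ensemble_size); this matches the
# scalar int64 ``ensemble_ids`` spec added to DdpgState in this diff.
ensemble_ids = torch.randint(
    0, ensemble_size, (batch_size,), dtype=torch.int64)

# With ``input_with_ensemble_ids=True`` the network expects a tuple input:
network_input = (observation, ensemble_ids)

# With ``use_batch_ensemble=True`` the network also returns a tuple, e.g.
#   output, out_ensemble_ids = actor_network(network_input, state=state)
# which is why the DDPG code below unwraps ``action[0]`` / ``q_values[0]``.
```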
@@ -148,16 +175,27 @@
debug_summaries (bool): True if debug summaries should be created.
name (str): The name of this algorithm.
"""
if use_batch_ensemble:
assert config.use_rollout_state, (
'use_rollout_state needs to be True when use_batch_ensemble.')

self._calculate_priority = calculate_priority
if epsilon_greedy is None:
epsilon_greedy = alf.utils.common.get_epsilon_greedy(config)
self._epsilon_greedy = epsilon_greedy

critic_network = critic_network_ctor(
input_tensor_spec=(observation_spec, action_spec),
output_tensor_spec=reward_spec)
output_tensor_spec=reward_spec,
use_batch_ensemble=use_batch_ensemble,
ensemble_size=ensemble_size,
input_with_ensemble_ids=input_with_ensemble_ids)
actor_network = actor_network_ctor(
input_tensor_spec=observation_spec, action_spec=action_spec)
input_tensor_spec=observation_spec,
action_spec=action_spec,
use_batch_ensemble=use_batch_ensemble,
ensemble_size=ensemble_size,
input_with_ensemble_ids=input_with_ensemble_ids)

critic_networks = critic_network.make_parallel(num_critic_replicas)

Expand All @@ -176,6 +214,8 @@ def __init__(self,

train_state_spec = DdpgState(
noise=noise_state,
ensemble_ids=TensorSpec(
(), dtype=torch.int64) if use_batch_ensemble else (),
actor=DdpgActorState(
actor=actor_network.state_spec,
critics=critic_networks.state_spec),
@@ -221,7 +261,11 @@ def __init__(self,
self._critic_losses[i] = critic_loss_ctor(
name=("critic_loss" + str(i)))

self._use_batch_ensemble = use_batch_ensemble
self._noise_process = noise_process
self._noise_clipping = noise_clipping
self._actor_update_period = actor_update_period
self._train_step_count = 0

self._update_target = common.TargetUpdater(
models=[self._actor_network, self._critic_networks],
@@ -239,6 +283,9 @@ def predict_step(self, inputs: TimeStep, state):
def _predict_step(self, time_step: TimeStep, state, epsilon_greedy=1.):
action, actor_state = self._actor_network(
time_step.observation, state=state.actor.actor)
if self._use_batch_ensemble:
ensemble_ids = action[1]
action = action[0]
empty_state = nest.map_structure(lambda x: (), self.rollout_state_spec)

def _sample(a, noise):
@@ -252,11 +299,15 @@ def _sample(a, noise):
return a

noise, noise_state = self._noise_process(state.noise)
if self._noise_clipping:
noise = torch.clamp(noise, -self._noise_clipping,
self._noise_clipping)
noisy_action = nest.map_structure(_sample, action, noise)
noisy_action = nest.map_structure(spec_utils.clip_to_spec,
noisy_action, self._action_spec)
state = empty_state._replace(
noise=noise_state,
ensemble_ids=ensemble_ids if self._use_batch_ensemble else (),
actor=DdpgActorState(actor=actor_state, critics=()))

return AlgStep(
@@ -265,9 +316,12 @@ def _sample(a, noise):
info=DdpgInfo(action=noisy_action, action_distribution=action))
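A tiny self-contained sketch of the noise handling added to ``_predict_step()`` above, assuming a scalar clipping threshold and action bounds of ``[-1, 1]``; the real code draws the noise from the OU process and clips with ``spec_utils.clip_to_spec``.

```python
import torch

# Illustrative stand-ins for the OU noise output and the action bounds.
noise_clipping = 0.5
action = torch.tensor([0.2, -0.9])
noise = torch.tensor([0.8, -0.1])

# Mirror of the logic above: clamp the noise first (skipped when
# noise_clipping is None or 0), then clip the noisy action to the bounds.
if noise_clipping:
    noise = torch.clamp(noise, -noise_clipping, noise_clipping)
noisy_action = torch.clamp(action + noise, -1.0, 1.0)
# noisy_action -> tensor([0.7000, -1.0000])
```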

def rollout_step(self, time_step: TimeStep, state: DdpgState = None):
if self.need_full_rollout_state():
raise NotImplementedError("Storing RNN state to replay buffer "
"is not supported by DdpgAlgorithm")
"""``rollout_step()`` basically predicts actions like what is done by
``predict_step()``. Additionally, if states are to be stored a in replay
buffer, then this function also call ``_critic_networks``,
``_target_critic_networks``, and ``_target_actor_network`` to maintain
their states.
"""

def _update_random_action(spec, noisy_action):
random_action = spec_utils.scale_to_spec(
@@ -277,18 +331,52 @@ def _update_random_action(spec, noisy_action):
_rollout_random_action)
noisy_action[ind[0], :] = random_action[ind[0], :]

observation = time_step.observation
if self._use_batch_ensemble and torch.count_nonzero(
state.ensemble_ids) > 0:
# If use_batch_ensemble, we want to use the same ensemble_ids
# to forward the actor_network during the rollout of an episode,
# except for the initial rollout_step, where the ensemble_ids
# in the initial rollout_state are all zeros.
time_step = time_step._replace(
observation=(observation, state.ensemble_ids))
pred_step = self._predict_step(time_step, state, epsilon_greedy=1.0)
if self._rollout_random_action > 0:
nest.map_structure(_update_random_action, self._action_spec,
pred_step.output)
return pred_step

Inline review comment (Contributor):
We want the algorithm to use the same ensemble_id during an entire episode. This means it should store the ensemble_id in the state and use that same ensemble_id when calling the actor_network.

Reply (Contributor, Author):
Oh yes, good point. I think that is the reason why I had to tweak ddpg_algorithm_test to pass the toy unittest. Updated.

if self.need_full_rollout_state():

_, critics_state = self._critic_networks(
(observation, pred_step.output), state.critics.critics)
_, target_critics_state = self._target_critic_networks(
(observation, pred_step.output), state.critics.target_critics)
_, target_actor_state = self._target_actor_network(
observation, state=state.critics.target_actor)
critic_state = DdpgCriticState(
critics=critics_state,
target_actor=target_actor_state,
target_critics=target_critics_state)
else:
critics_state = state.critics.critics
critic_state = state.critics

actor_state = pred_step.state.actor._replace(critics=critics_state)

new_state = pred_step.state._replace(
actor=actor_state, critics=critic_state)

return pred_step._replace(state=new_state)
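A minimal standalone sketch of the per-episode ensemble-id policy that ``rollout_step()`` implements above: all-zero ids in the initial rollout state mean "not yet assigned", and once assigned the ids are carried through the episode. In the PR the fresh ids come out of the actor network's BatchEnsemble layers; the sketch simply samples nonzero ids directly for illustration.

```python
import torch

batch_size, ensemble_size = 4, 3
# The initial rollout state carries all-zero ensemble_ids, which
# rollout_step() treats as "not yet assigned".
state_ensemble_ids = torch.zeros(batch_size, dtype=torch.int64)


def rollout_ids(state_ids: torch.Tensor) -> torch.Tensor:
    if torch.count_nonzero(state_ids) > 0:
        # Later steps of the episode: reuse the ids stored in the state so
        # the whole episode is rolled out by the same ensemble member.
        return state_ids
    # First step of an episode: produce fresh ids (sampled nonzero here so
    # the sentinel check above works for the illustration).
    return torch.randint(
        1, ensemble_size, (batch_size,), dtype=torch.int64)


ids_step1 = rollout_ids(state_ensemble_ids)  # freshly assigned
ids_step2 = rollout_ids(ids_step1)           # identical to ids_step1
assert torch.equal(ids_step1, ids_step2)
```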

def _critic_train_step(self, inputs: TimeStep, state: DdpgCriticState,
rollout_info: DdpgInfo):
target_action, target_actor_state = self._target_actor_network(
inputs.observation, state=state.target_actor)
if self._use_batch_ensemble:
target_action = target_action[0]
target_q_values, target_critic_states = self._target_critic_networks(
(inputs.observation, target_action), state=state.target_critics)
if self._use_batch_ensemble:
target_q_values = target_q_values[0]

if self.has_multidim_reward():
sign = self.reward_weights.sign()
@@ -298,6 +386,8 @@ def _critic_train_step(self, inputs: TimeStep, state: DdpgCriticState,

q_values, critic_states = self._critic_networks(
(inputs.observation, rollout_info.action), state=state.critics)
if self._use_batch_ensemble:
q_values = q_values[0]

state = DdpgCriticState(
critics=critic_states,
@@ -312,9 +402,13 @@ def _critic_train_step(self, inputs: TimeStep, state: DdpgCriticState,
def _actor_train_step(self, inputs: TimeStep, state: DdpgActorState):
action, actor_state = self._actor_network(
inputs.observation, state=state.actor)
if self._use_batch_ensemble:
action = action[0]

q_values, critic_states = self._critic_networks(
(inputs.observation, action), state=state.critics)
if self._use_batch_ensemble:
q_values = q_values[0]
if self.has_multidim_reward():
# Multidimensional reward: [B, replicas, reward_dim]
q_values = q_values * self.reward_weights
@@ -343,9 +437,22 @@ def actor_loss_fn(dqda, action):

def train_step(self, inputs: TimeStep, state: DdpgState,
rollout_info: DdpgInfo):
self._train_step_count += 1
critic_states, critic_info = self._critic_train_step(
inputs=inputs, state=state.critics, rollout_info=rollout_info)
policy_step = self._actor_train_step(inputs=inputs, state=state.actor)
if self._train_step_count % self._actor_update_period == 0:
policy_step = self._actor_train_step(
inputs=inputs, state=state.actor)
critic_states = critic_states._replace(
critics=policy_step.state.critics)
else:
batch_dims = nest_utils.get_outer_rank(inputs.prev_action,
self._action_spec)
loss = torch.zeros(*inputs.prev_action.shape[:batch_dims])
policy_step = AlgStep(
output=torch.zeros_like(inputs.prev_action),
state=state.actor,
info=LossInfo(loss=loss, extra=loss))
return policy_step._replace(
state=state._replace(
actor=policy_step.state, critics=critic_states),
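The ``train_step()`` change above trains the critics every step but updates the actor only once every ``actor_update_period`` steps, emitting a zero actor loss otherwise. Below is a minimal standalone sketch of that schedule with stand-in losses; none of the names are ALF APIs.

```python
import torch


class DelayedActorUpdate:
    """Illustrative stand-in for the schedule used in train_step() above."""

    def __init__(self, actor_update_period: int = 2):
        self._actor_update_period = actor_update_period
        self._train_step_count = 0

    def step(self, prev_action: torch.Tensor):
        self._train_step_count += 1
        critic_loss = prev_action.pow(2).mean()  # stand-in for the TD loss
        if self._train_step_count % self._actor_update_period == 0:
            actor_loss = (-prev_action).mean()   # stand-in for the dqda loss
        else:
            # On "off" steps, return a zero loss with one entry per batch
            # element so loss aggregation keeps a consistent shape.
            actor_loss = torch.zeros(prev_action.shape[0])
        return critic_loss, actor_loss


sched = DelayedActorUpdate(actor_update_period=2)
actions = torch.randn(4, 2)  # [batch_size, action_dim]
for _ in range(4):
    critic_loss, actor_loss = sched.step(actions)
```

With ``actor_update_period=1`` (the default), this reduces to the original behavior of updating the actor on every train step.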
15 changes: 12 additions & 3 deletions alf/algorithms/ddpg_algorithm_test.py
@@ -33,9 +33,14 @@


class DDPGAlgorithmTest(parameterized.TestCase, alf.test.TestCase):
@parameterized.parameters((1, 1, None), (2, 3, [1, 2, 3]))
def test_ddpg_algorithm(self, num_critic_replicas, reward_dim,
reward_weights):
@parameterized.parameters((1, 1, None, 2), (1, 1, None, 2, True),
(2, 3, [1, 2, 3]))
def test_ddpg_algorithm(self,
num_critic_replicas,
reward_dim,
reward_weights,
actor_update_period=1,
use_batch_ensemble=False):
num_env = 128
num_eval_env = 100
steps_per_episode = 13
@@ -45,6 +50,7 @@ def test_ddpg_algorithm(self, num_critic_replicas, reward_dim,
mini_batch_length=2,
mini_batch_size=128,
initial_collect_steps=steps_per_episode,
use_rollout_state=use_batch_ensemble,
whole_replay_buffer_training=False,
clear_replay_buffer=False,
)
@@ -86,6 +92,9 @@ def test_ddpg_algorithm(self, num_critic_replicas, reward_dim,
env=env,
config=config,
num_critic_replicas=num_critic_replicas,
actor_update_period=actor_update_period,
use_batch_ensemble=use_batch_ensemble,
ensemble_size=3,
actor_optimizer=alf.optimizers.Adam(lr=1e-2),
critic_optimizer=alf.optimizers.Adam(lr=1e-2),
debug_summaries=False,
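A small self-contained example of how the new parameterized cases map onto the test arguments; only ``absl`` is needed, and the class below is illustrative rather than part of the test suite.

```python
from absl.testing import absltest, parameterized


class ParameterizationExample(parameterized.TestCase):
    # Mirrors the pattern in the diff above: positional tuples fill the test
    # arguments in order, and omitted trailing values fall back to the
    # keyword defaults (actor_update_period=1, use_batch_ensemble=False).
    @parameterized.parameters((1, 1, None, 2), (1, 1, None, 2, True),
                              (2, 3, [1, 2, 3]))
    def test_case_mapping(self,
                          num_critic_replicas,
                          reward_dim,
                          reward_weights,
                          actor_update_period=1,
                          use_batch_ensemble=False):
        # (1, 1, None, 2)       -> actor_update_period=2, batch ensemble off
        # (1, 1, None, 2, True) -> actor_update_period=2, batch ensemble on
        # (2, 3, [1, 2, 3])     -> defaults: period 1, batch ensemble off
        self.assertIn(actor_update_period, (1, 2))


if __name__ == '__main__':
    absltest.main()
```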
41 changes: 38 additions & 3 deletions alf/networks/actor_networks.py
@@ -27,6 +27,7 @@
import alf.nest as nest
from alf.initializers import variance_scaling_init
from alf.networks import Network
from alf.networks.containers import Parallel
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from alf.utils import common, math_ops, spec_utils

@@ -85,10 +86,21 @@ def __init__(self,
a=-0.003, b=0.003)
self._action_layers = nn.ModuleList()
self._squashing_func = squashing_func
fc_layer_ctor = layers.FC
encoder_output_spec = self._encoding_net.output_spec
self._use_batch_ensemble = encoder_kwargs.get('use_batch_ensemble',
False)
if self._use_batch_ensemble:
encoder_output_spec = encoder_output_spec[0]
fc_layer_ctor = functools.partial(
layers.FCBatchEnsemble,
ensemble_size=encoder_kwargs.get('ensemble_size', 10),
output_ensemble_ids=False)

for single_action_spec in flat_action_spec:
self._action_layers.append(
layers.FC(
self._encoding_net.output_spec.shape[0],
fc_layer_ctor(
encoder_output_spec.shape[0],
single_action_spec.shape[0],
kernel_initializer=last_kernel_initializer))

@@ -134,6 +146,11 @@ def forward(self, observation, state=()):
i += 1

output_actions = nest.pack_sequence_as(self._action_spec, actions)
if self._use_batch_ensemble:
# Note that when use_batch_ensemble is True, EncodingNetwork always
# outputs a tuple of (output_tensor, ensemble_ids).
output_actions = (output_actions, encoded_obs[1])

return output_actions, state

@property
@@ -149,12 +166,16 @@ def __init__(self,
input_tensor_spec: TensorSpec,
action_spec: BoundedTensorSpec,
input_preprocessors=None,
input_preprocessors_ctor=None,
preprocessing_combiner=None,
conv_layer_params=None,
fc_layer_params=None,
activation=torch.relu_,
squashing_func=torch.tanh,
kernel_initializer=None,
use_batch_ensemble=False,
ensemble_size=10,
input_with_ensemble_ids=False,
name="ActorNetwork"):
"""Creates an instance of ``ActorNetwork``, which maps the inputs to
actions (single or nested) through a sequence of deterministic layers.
@@ -189,6 +210,17 @@ def __init__(self,
kernel_initializer (Callable): initializer for all the layers but
the last layer. If none is provided a ``variance_scaling_initializer``
with uniform distribution will be used.
use_batch_ensemble (bool): whether to use BatchEnsemble FC and Conv2D
layers. If True, both BatchEnsemble layers will always be created
with ``output_ensemble_ids=True``, and as a result, the output of
the network is a tuple of ``(output, ensemble_ids)``.
ensemble_size (int): the ensemble size; only effective if
``use_batch_ensemble`` is True.
input_with_ensemble_ids (bool): whether to handle inputs that carry
ensemble_ids. If True, the input to the network should be a tuple of
two tensors: the first is the input data tensor and the second is the
ensemble_ids. This option is only effective if ``use_batch_ensemble``
is True.
name (str): name of the network
"""
super(ActorNetwork, self).__init__(
@@ -202,7 +234,10 @@ def __init__(self,
conv_layer_params=conv_layer_params,
fc_layer_params=fc_layer_params,
activation=activation,
kernel_initializer=kernel_initializer)
kernel_initializer=kernel_initializer,
use_batch_ensemble=use_batch_ensemble,
ensemble_size=ensemble_size,
input_with_ensemble_ids=input_with_ensemble_ids)


@alf.configurable
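For readers unfamiliar with the BatchEnsemble layers that ``ActorNetwork`` can now use, here is a minimal, self-contained sketch of the rank-1 BatchEnsemble idea and of the ``(output, ensemble_ids)`` tuple convention. The toy layer below only illustrates the concept and makes no claim about the internals of ALF's ``layers.FCBatchEnsemble``.

```python
import torch
import torch.nn as nn


class TinyBatchEnsembleFC(nn.Module):
    """A minimal rank-1 BatchEnsemble FC layer (illustration only)."""

    def __init__(self, in_size, out_size, ensemble_size):
        super().__init__()
        # One shared weight matrix for all ensemble members.
        self.weight = nn.Parameter(torch.randn(in_size, out_size) * 0.01)
        # Rank-1 per-member factors r_k (input) and s_k (output), plus bias.
        self.r = nn.Parameter(torch.ones(ensemble_size, in_size))
        self.s = nn.Parameter(torch.ones(ensemble_size, out_size))
        self.bias = nn.Parameter(torch.zeros(ensemble_size, out_size))

    def forward(self, x, ensemble_ids):
        # Each sample in the batch is routed to its own ensemble member via
        # the per-sample rank-1 factors; the weight matrix is shared.
        r = self.r[ensemble_ids]      # [B, in_size]
        s = self.s[ensemble_ids]      # [B, out_size]
        b = self.bias[ensemble_ids]   # [B, out_size]
        return ((x * r) @ self.weight) * s + b


batch_size, in_size, out_size, ensemble_size = 4, 8, 2, 3
x = torch.randn(batch_size, in_size)
ensemble_ids = torch.randint(0, ensemble_size, (batch_size,))
layer = TinyBatchEnsembleFC(in_size, out_size, ensemble_size)
actions = torch.tanh(layer(x, ensemble_ids))
# Mirroring the ActorNetwork change above: when use_batch_ensemble is True,
# the network returns a tuple so downstream code can keep using the same ids.
output = (actions, ensemble_ids)
```

The appeal of this parameterization is that all ensemble members share the full weight matrix and differ only in their rank-1 factors and biases, so the parameter cost grows very slowly with the ensemble size.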