update ddpg_algorithm to work with batchensemble and fix unittests

runjerry committed Mar 28, 2024
1 parent 1c14fc5 commit 8930ae0
Showing 4 changed files with 89 additions and 26 deletions.
51 changes: 48 additions & 3 deletions alf/algorithms/ddpg_algorithm.py
@@ -74,13 +74,18 @@ def __init__(self,
config: TrainerConfig = None,
ou_stddev=0.2,
ou_damping=0.15,
noise_clipping=0.5,
critic_loss_ctor=None,
num_critic_replicas=1,
target_update_tau=0.05,
target_update_period=1,
actor_update_period=1,
rollout_random_action=0.,
dqda_clipping=None,
action_l2=0,
use_batch_ensemble=False,
ensemble_size=10,
input_with_ensemble_ids=False,
actor_optimizer=None,
critic_optimizer=None,
checkpoint=None,
@@ -124,12 +129,16 @@ def __init__(self,
(OU) noise added in the default collect policy.
ou_damping (float): Damping factor for the OU noise added in the
default collect policy.
noise_clipping (float): when computing the action noise, clips the
noise element-wise between ``[-noise_clipping, noise_clipping]``.
Does not perform clipping if ``noise_clipping == 0``.
critic_loss_ctor (None|OneStepTDLoss|MultiStepLoss): a critic loss
constructor. If ``None``, a default ``OneStepTDLoss`` will be used.
target_update_tau (float): Factor for soft update of the target
networks.
target_update_period (int): Period for soft update of the target
networks.
actor_update_period (int): Period for update of the actor_network.
rollout_random_action (float): the probability of taking a uniform
random action during a ``rollout_step()``. 0 means always directly
taking actions added with OU noises and 1 means always sample
@@ -139,6 +148,17 @@
gradient dqda element-wise between ``[-dqda_clipping, dqda_clipping]``.
Does not perform clipping if ``dqda_clipping == 0``.
action_l2 (float): weight of squared action l2-norm on actor loss.
use_batch_ensemble (bool): whether to use BatchEnsemble FC and Conv2D
layers. If True, both BatchEnsemble layers will always be created
with ``output_ensemble_ids=True``, and as a result, the output of
the network is a tuple with ensemble_ids.
ensemble_size (int): ensemble size, only effective if use_batch_ensemble
is True.
input_with_ensemble_ids (bool): whether to handle inputs with
ensemble_ids. If True, the input to the network should be a tuple
of two tensors: the first is the input data tensor and the second
is the ensemble_ids. This option is only effective if
use_batch_ensemble is True.
actor_optimizer (torch.optim.optimizer): The optimizer for actor.
critic_optimizer (torch.optim.optimizer): The optimizer for critic.
checkpoint (None|str): a string in the format of "prefix@path",
@@ -155,9 +175,16 @@ def __init__(self,

critic_network = critic_network_ctor(
input_tensor_spec=(observation_spec, action_spec),
output_tensor_spec=reward_spec)
output_tensor_spec=reward_spec,
use_batch_ensemble=use_batch_ensemble,
ensemble_size=ensemble_size,
input_with_ensemble_ids=input_with_ensemble_ids)
actor_network = actor_network_ctor(
input_tensor_spec=observation_spec, action_spec=action_spec)
input_tensor_spec=observation_spec,
action_spec=action_spec,
use_batch_ensemble=use_batch_ensemble,
ensemble_size=ensemble_size,
input_with_ensemble_ids=input_with_ensemble_ids)

critic_networks = critic_network.make_parallel(num_critic_replicas)

@@ -221,7 +248,11 @@ def __init__(self,
self._critic_losses[i] = critic_loss_ctor(
name=("critic_loss" + str(i)))

self._use_batch_ensemble = use_batch_ensemble
self._noise_process = noise_process
self._noise_clipping = noise_clipping
self._actor_update_period = actor_update_period
self._train_step_count = 0

self._update_target = common.TargetUpdater(
models=[self._actor_network, self._critic_networks],
@@ -252,6 +283,9 @@ def _sample(a, noise):
return a

noise, noise_state = self._noise_process(state.noise)
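# Optionally clip the OU noise element-wise to [-noise_clipping, noise_clipping]; 0 disables clipping.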
if self._noise_clipping:
noise = torch.clamp(noise, -self._noise_clipping,
self._noise_clipping)
noisy_action = nest.map_structure(_sample, action, noise)
noisy_action = nest.map_structure(spec_utils.clip_to_spec,
noisy_action, self._action_spec)
@@ -289,6 +323,8 @@ def _critic_train_step(self, inputs: TimeStep, state: DdpgCriticState,
inputs.observation, state=state.target_actor)
target_q_values, target_critic_states = self._target_critic_networks(
(inputs.observation, target_action), state=state.target_critics)
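# With BatchEnsemble layers the critic returns a tuple (q_values, ensemble_ids); keep only the q_values.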
if self._use_batch_ensemble:
target_q_values = target_q_values[0]

if self.has_multidim_reward():
sign = self.reward_weights.sign()
@@ -298,6 +334,8 @@ def _critic_train_step(self, inputs: TimeStep, state: DdpgCriticState,

q_values, critic_states = self._critic_networks(
(inputs.observation, rollout_info.action), state=state.critics)
if self._use_batch_ensemble:
q_values = q_values[0]

state = DdpgCriticState(
critics=critic_states,
@@ -315,6 +353,8 @@ def _actor_train_step(self, inputs: TimeStep, state: DdpgActorState):

q_values, critic_states = self._critic_networks(
(inputs.observation, action), state=state.critics)
if self._use_batch_ensemble:
q_values = q_values[0]
if self.has_multidim_reward():
# Multidimensional reward: [B, replicas, reward_dim]
q_values = q_values * self.reward_weights
@@ -343,9 +383,14 @@ def actor_loss_fn(dqda, action):

def train_step(self, inputs: TimeStep, state: DdpgState,
rollout_info: DdpgInfo):
self._train_step_count += 1
critic_states, critic_info = self._critic_train_step(
inputs=inputs, state=state.critics, rollout_info=rollout_info)
policy_step = self._actor_train_step(inputs=inputs, state=state.actor)
if self._train_step_count % self._actor_update_period == 0:
policy_step = self._actor_train_step(
inputs=inputs, state=state.actor)
else:
policy_step = AlgStep(state=state.actor)
return policy_step._replace(
state=state._replace(
actor=policy_step.state, critics=critic_states),
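
To put the new options in context, enabling BatchEnsemble from a training config could look roughly like the sketch below. This is a hypothetical snippet, not part of this commit; it assumes DdpgAlgorithm is registered as an ALF configurable and that these constructor arguments are configurable.

import alf
import alf.algorithms.ddpg_algorithm  # importing registers the DdpgAlgorithm configurable

# Illustrative values only; ensemble_size and actor_update_period are tuned per task.
alf.config(
    'DdpgAlgorithm',
    use_batch_ensemble=True,
    ensemble_size=10,
    noise_clipping=0.5,
    actor_update_period=2)
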
24 changes: 18 additions & 6 deletions alf/algorithms/ddpg_algorithm_test.py
@@ -33,9 +33,13 @@


class DDPGAlgorithmTest(parameterized.TestCase, alf.test.TestCase):
@parameterized.parameters((1, 1, None), (2, 3, [1, 2, 3]))
def test_ddpg_algorithm(self, num_critic_replicas, reward_dim,
reward_weights):
@parameterized.parameters((1, 1, None), (1, 1, None, True),
(2, 3, [1, 2, 3]))
def test_ddpg_algorithm(self,
num_critic_replicas,
reward_dim,
reward_weights,
use_batch_ensemble=False):
num_env = 128
num_eval_env = 100
steps_per_episode = 13
@@ -65,7 +69,13 @@ def test_ddpg_algorithm(self, num_critic_replicas, reward_dim,
obs_spec = env._observation_spec
action_spec = env._action_spec

fc_layer_params = (16, 16)
if use_batch_ensemble:
n_neuron = 32
init_lr = 2e-3
else:
n_neuron = 16
init_lr = 1e-2
fc_layer_params = (n_neuron, n_neuron)

actor_network = functools.partial(
ActorNetwork, fc_layer_params=fc_layer_params)
@@ -86,8 +96,10 @@ def test_ddpg_algorithm(self, num_critic_replicas, reward_dim,
env=env,
config=config,
num_critic_replicas=num_critic_replicas,
actor_optimizer=alf.optimizers.Adam(lr=1e-2),
critic_optimizer=alf.optimizers.Adam(lr=1e-2),
use_batch_ensemble=use_batch_ensemble,
ensemble_size=3,
actor_optimizer=alf.optimizers.Adam(lr=init_lr),
critic_optimizer=alf.optimizers.Adam(lr=init_lr),
debug_summaries=False,
name="MyDDPG")

4 changes: 2 additions & 2 deletions alf/networks/actor_networks.py
@@ -208,8 +208,8 @@ def __init__(self,
with uniform distribution will be used.
use_batch_ensemble (bool): whether to use BatchEnsemble FC and Conv2D
layers. If True, both BatchEnsemble layers will always be created
with ``output_ensemble_ids=True``, and as a result, the output of
the network is a tuple with ensemble_ids.
with ``output_ensemble_ids=True``; however, the output of the action
network will not contain the ensemble_ids.
ensemble_size (int): ensemble size, only effective if use_batch_ensemble
is True.
input_with_ensemble_ids (bool): whether handle inputs with ensemble_ids,
36 changes: 21 additions & 15 deletions alf/networks/critic_networks.py
@@ -210,21 +210,27 @@ def __init__(self,

if observation_action_combiner is None:
if use_batch_ensemble:
ensemble_ids_spec = TensorSpec((), dtype=torch.int64)
obs_action_spec = (obs_encoder.output_spec,
action_encoder.output_spec)
observation_action_combiner = Sequential(
*[
NetworkWrapper(lambda x: nest.transpose(x),
obs_action_spec),
Parallel(
(alf.layers.NestConcat(dim=-1), lambda x: x[0]),
((obs_encoder.output_spec[0],
action_encoder.output_spec[0]),
(TensorSpec((), dtype=torch.int64),
TensorSpec((), dtype=torch.int64))))
],
input_tensor_spec=obs_action_spec)
obs_spec = obs_encoder.output_spec
action_spec = action_encoder.output_spec
obs_action_spec = (obs_spec, action_spec)

def _obs_action_combiner(inputs):
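# Concatenate the encoded observation and action; if either encoder output
# carries ensemble_ids, pass the ids through alongside the concatenated features.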
obs, action = inputs
ensemble_ids = None
if isinstance(obs_spec, tuple):
ensemble_ids = obs[1]
obs = obs[0]
if isinstance(action_spec, tuple):
if ensemble_ids is None:
ensemble_ids = action[1]
action = action[0]
outputs = alf.layers.NestConcat(dim=-1)((obs, action))
if ensemble_ids is not None:
outputs = (outputs, ensemble_ids)
return outputs

observation_action_combiner = NetworkWrapper(
_obs_action_combiner, obs_action_spec)
else:
observation_action_combiner = alf.layers.NestConcat(dim=-1)

