Batch ensemble ddpg #1633
base: pytorch
Changes from all commits: 1c14fc5, ad5f40e, dae6012, d0d82d4, 8d31c76
@@ -42,11 +42,18 @@
 DdpgActorState = namedtuple(
     "DdpgActorState", ['actor', 'critics'], default_value=())
 DdpgState = namedtuple(
-    "DdpgState", ['actor', 'critics', 'noise'], default_value=())
+    "DdpgState", ['actor', 'critics', 'noise', 'ensemble_ids'],
+    default_value=())
 DdpgInfo = namedtuple(
     "DdpgInfo", [
-        "reward", "step_type", "discount", "action", "action_distribution",
-        "actor_loss", "critic", "discounted_return"
+        "reward",
+        "step_type",
+        "discount",
+        "action",
+        "action_distribution",
+        "actor_loss",
+        "critic",
+        "discounted_return",
     ],
     default_value=())
 DdpgLossInfo = namedtuple('DdpgLossInfo', ('actor', 'critic'))
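For orientation, a small sketch of the new field's behavior (assuming ALF's ``namedtuple`` with ``default_value``, as used above; not code from this PR):

```python
import torch

# All fields default to (), so ensemble_ids stays an empty placeholder
# unless batch ensemble is enabled.
state = DdpgState(noise=(), actor=(), critics=())
assert state.ensemble_ids == ()

# With batch ensemble enabled, it carries one int64 id per batch element.
state = state._replace(ensemble_ids=torch.zeros(4, dtype=torch.int64))
```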
@@ -74,13 +81,18 @@ def __init__(self,
                  config: TrainerConfig = None,
                  ou_stddev=0.2,
                  ou_damping=0.15,
+                 noise_clipping=None,
                  critic_loss_ctor=None,
                  num_critic_replicas=1,
                  target_update_tau=0.05,
                  target_update_period=1,
+                 actor_update_period=1,
                  rollout_random_action=0.,
                  dqda_clipping=None,
                  action_l2=0,
+                 use_batch_ensemble=False,
+                 ensemble_size=10,
+                 input_with_ensemble_ids=False,
                  actor_optimizer=None,
                  critic_optimizer=None,
                  checkpoint=None,
@@ -124,12 +136,16 @@ def __init__(self,
                 (OU) noise added in the default collect policy.
             ou_damping (float): Damping factor for the OU noise added in the
                 default collect policy.
+            noise_clipping (float): when computing the action noise, clips the
+                noise element-wise between ``[-noise_clipping, noise_clipping]``.
+                Does not perform clipping if ``noise_clipping == 0``.
             critic_loss_ctor (None|OneStepTDLoss|MultiStepLoss): a critic loss
                 constructor. If ``None``, a default ``OneStepTDLoss`` will be used.
             target_update_tau (float): Factor for soft update of the target
                 networks.
             target_update_period (int): Period for soft update of the target
                 networks.
+            actor_update_period (int): Period for update of the actor_network.
             rollout_random_action (float): the probability of taking a uniform
                 random action during a ``rollout_step()``. 0 means always directly
                 taking actions added with OU noises and 1 means always sample
@@ -139,6 +155,17 @@ def __init__(self,
                 gradient dqda element-wise between ``[-dqda_clipping, dqda_clipping]``.
                 Does not perform clipping if ``dqda_clipping == 0``.
             action_l2 (float): weight of squared action l2-norm on actor loss.
+            use_batch_ensemble (bool): whether to use BatchEnsemble FC and Conv2D
+                layers. If True, both BatchEnsemble layers will always be created
+                with ``output_ensemble_ids=True``, and as a result, the output of
+                the network is a tuple with ensemble_ids.
+            ensemble_size (int): ensemble size, only effective if
+                ``use_batch_ensemble`` is True.
+            input_with_ensemble_ids (bool): whether to handle inputs with
+                ensemble_ids. If True, the input to the network should be a tuple
+                of two tensors: the first is the input data tensor and the second
+                is the ensemble_ids. This option is only effective if
+                ``use_batch_ensemble`` is True.
             actor_optimizer (torch.optim.optimizer): The optimizer for actor.
             critic_optimizer (torch.optim.optimizer): The optimizer for critic.
             checkpoint (None|str): a string in the format of "prefix@path",
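To make the two input conventions described above concrete, here is a hedged sketch (names and dimensions are hypothetical; the call shapes are illustrative, not this PR's code):

```python
import torch

batch_size, obs_dim, ensemble_size = 32, 10, 10
observation = torch.randn(batch_size, obs_dim)

# Default (input_with_ensemble_ids=False): pass the observation alone; the
# BatchEnsemble layers sample ensemble ids internally and, because
# output_ensemble_ids=True, the network returns (action, ensemble_ids):
#
#   (action, ensemble_ids), state = actor_network(observation, state=state)

# input_with_ensemble_ids=True: pass a (data, ensemble_ids) tuple, e.g. the
# ids saved in the rollout state, so one ensemble member acts for the episode.
ensemble_ids = torch.randint(0, ensemble_size, (batch_size, ), dtype=torch.int64)
#
#   (action, ensemble_ids), state = actor_network(
#       (observation, ensemble_ids), state=state)
```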
@@ -148,16 +175,27 @@ def __init__(self,
             debug_summaries (bool): True if debug summaries should be created.
             name (str): The name of this algorithm.
         """
+        if use_batch_ensemble:
+            assert config.use_rollout_state, (
+                'use_rollout_state needs to be True when use_batch_ensemble.')
+
         self._calculate_priority = calculate_priority
         if epsilon_greedy is None:
             epsilon_greedy = alf.utils.common.get_epsilon_greedy(config)
         self._epsilon_greedy = epsilon_greedy
 
         critic_network = critic_network_ctor(
             input_tensor_spec=(observation_spec, action_spec),
-            output_tensor_spec=reward_spec)
+            output_tensor_spec=reward_spec,
+            use_batch_ensemble=use_batch_ensemble,
+            ensemble_size=ensemble_size,
+            input_with_ensemble_ids=input_with_ensemble_ids)
         actor_network = actor_network_ctor(
-            input_tensor_spec=observation_spec, action_spec=action_spec)
+            input_tensor_spec=observation_spec,
+            action_spec=action_spec,
+            use_batch_ensemble=use_batch_ensemble,
+            ensemble_size=ensemble_size,
+            input_with_ensemble_ids=input_with_ensemble_ids)
 
         critic_networks = critic_network.make_parallel(num_critic_replicas)
@@ -176,6 +214,8 @@ def __init__(self,
 
         train_state_spec = DdpgState(
             noise=noise_state,
+            ensemble_ids=TensorSpec(
+                (), dtype=torch.int64) if use_batch_ensemble else (),
             actor=DdpgActorState(
                 actor=actor_network.state_spec,
                 critics=critic_networks.state_spec),
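A quick note on the spec added above: ``TensorSpec((), dtype=torch.int64)`` describes one scalar id per batch element, so a zero-initialized rollout state yields the all-zero ids that ``rollout_step()`` later checks for. A sketch, assuming ALF's ``TensorSpec.zeros`` (the exact signature may differ):

```python
import torch
from alf.tensor_specs import TensorSpec

spec = TensorSpec((), dtype=torch.int64)  # one scalar ensemble id per sample
init_ids = spec.zeros(outer_dims=(4, ))   # all-zero initial rollout state
assert torch.count_nonzero(init_ids) == 0
```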
@@ -221,7 +261,11 @@ def __init__(self,
                 self._critic_losses[i] = critic_loss_ctor(
                     name=("critic_loss" + str(i)))
 
+        self._use_batch_ensemble = use_batch_ensemble
         self._noise_process = noise_process
+        self._noise_clipping = noise_clipping
+        self._actor_update_period = actor_update_period
+        self._train_step_count = 0
 
         self._update_target = common.TargetUpdater(
             models=[self._actor_network, self._critic_networks],
@@ -239,6 +283,9 @@ def predict_step(self, inputs: TimeStep, state):
     def _predict_step(self, time_step: TimeStep, state, epsilon_greedy=1.):
         action, actor_state = self._actor_network(
             time_step.observation, state=state.actor.actor)
+        if self._use_batch_ensemble:
+            ensemble_ids = action[1]
+            action = action[0]
         empty_state = nest.map_structure(lambda x: (), self.rollout_state_spec)
 
         def _sample(a, noise):
@@ -252,11 +299,15 @@ def _sample(a, noise):
             return a
 
         noise, noise_state = self._noise_process(state.noise)
+        if self._noise_clipping:
+            noise = torch.clamp(noise, -self._noise_clipping,
+                                self._noise_clipping)
         noisy_action = nest.map_structure(_sample, action, noise)
         noisy_action = nest.map_structure(spec_utils.clip_to_spec,
                                           noisy_action, self._action_spec)
         state = empty_state._replace(
             noise=noise_state,
+            ensemble_ids=ensemble_ids if self._use_batch_ensemble else (),
             actor=DdpgActorState(actor=actor_state, critics=()))
 
         return AlgStep(
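The clipped exploration noise above can be shown standalone (a minimal sketch; bounds and shapes are hypothetical, not this PR's code):

```python
import torch

def add_clipped_noise(action, noise, noise_clipping, low, high):
    # Clip the exploration noise element-wise to [-noise_clipping, noise_clipping]
    # before adding it; a falsy noise_clipping (None or 0) disables clipping.
    if noise_clipping:
        noise = torch.clamp(noise, -noise_clipping, noise_clipping)
    # Then clip the noisy action back into the action bounds.
    return torch.clamp(action + noise, low, high)

a = add_clipped_noise(torch.zeros(3), torch.randn(3), 0.5, -1.0, 1.0)
```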
@@ -265,9 +316,12 @@ def _sample(a, noise):
             info=DdpgInfo(action=noisy_action, action_distribution=action))
 
     def rollout_step(self, time_step: TimeStep, state: DdpgState = None):
-        if self.need_full_rollout_state():
-            raise NotImplementedError("Storing RNN state to replay buffer "
-                                      "is not supported by DdpgAlgorithm")
+        """``rollout_step()`` basically predicts actions like what is done by
+        ``predict_step()``. Additionally, if states are to be stored in a
+        replay buffer, then this function also calls ``_critic_networks``,
+        ``_target_critic_networks``, and ``_target_actor_network`` to maintain
+        their states.
+        """
 
         def _update_random_action(spec, noisy_action):
             random_action = spec_utils.scale_to_spec(
@@ -277,18 +331,52 @@ def _update_random_action(spec, noisy_action):
                 _rollout_random_action)
             noisy_action[ind[0], :] = random_action[ind[0], :]
 
+        observation = time_step.observation
+        if self._use_batch_ensemble and torch.count_nonzero(
+                state.ensemble_ids) > 0:
+            # If use_batch_ensemble, we want to use the same ensemble_ids
+            # to forward the actor_network during the rollout of an episode,
+            # except for the initial rollout_step, where the ensemble_ids
+            # in the initial rollout_state are all zeros.
+            time_step = time_step._replace(
+                observation=(observation, state.ensemble_ids))
         pred_step = self._predict_step(time_step, state, epsilon_greedy=1.0)
         if self._rollout_random_action > 0:
             nest.map_structure(_update_random_action, self._action_spec,
                                pred_step.output)
-        return pred_step
+
+        if self.need_full_rollout_state():

Review comment: We want the algorithm to use the same ensemble_id during an entire episode. This means that it should store the ensemble_id in the state and use the same ensemble_id to call the actor_network.

Reply: Oh yes, good point, I think that is the reason why I had to tweak the …

+            _, critics_state = self._critic_networks(
+                (observation, pred_step.output), state.critics.critics)
+            _, target_critics_state = self._target_critic_networks(
+                (observation, pred_step.output), state.critics.target_critics)
+            _, target_actor_state = self._target_actor_network(
+                observation, state=state.critics.target_actor)
+            critic_state = DdpgCriticState(
+                critics=critics_state,
+                target_actor=target_actor_state,
+                target_critics=target_critics_state)
+        else:
+            critics_state = state.critics.critics
+            critic_state = state.critics
+
+        actor_state = pred_step.state.actor._replace(critics=critics_state)
+
+        new_state = pred_step.state._replace(
+            actor=actor_state, critics=critic_state)
+
+        return pred_step._replace(state=new_state)
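The episode-level persistence of ensemble ids implemented above can be isolated in a sketch (hypothetical helper; it treats all-zero ids as the "not yet sampled" marker, matching the all-zero initial rollout state, and samples nonzero ids so the marker stays unambiguous, which simplifies the PR's actual behavior):

```python
import torch

def pick_ensemble_ids(stored_ids, ensemble_size):
    # Reuse the ids already saved in the rollout state for the rest of the
    # episode; all-zero ids mark the first step, where fresh ids are sampled.
    if torch.count_nonzero(stored_ids) > 0:
        return stored_ids
    return torch.randint(1, ensemble_size, stored_ids.shape, dtype=torch.int64)

ids = pick_ensemble_ids(torch.zeros(4, dtype=torch.int64), ensemble_size=10)
assert torch.equal(pick_ensemble_ids(ids, 10), ids)  # same ids within episode
```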
 
     def _critic_train_step(self, inputs: TimeStep, state: DdpgCriticState,
                            rollout_info: DdpgInfo):
         target_action, target_actor_state = self._target_actor_network(
             inputs.observation, state=state.target_actor)
+        if self._use_batch_ensemble:
+            target_action = target_action[0]
         target_q_values, target_critic_states = self._target_critic_networks(
             (inputs.observation, target_action), state=state.target_critics)
+        if self._use_batch_ensemble:
+            target_q_values = target_q_values[0]
 
         if self.has_multidim_reward():
             sign = self.reward_weights.sign()
@@ -298,6 +386,8 @@ def _critic_train_step(self, inputs: TimeStep, state: DdpgCriticState,
 
         q_values, critic_states = self._critic_networks(
             (inputs.observation, rollout_info.action), state=state.critics)
+        if self._use_batch_ensemble:
+            q_values = q_values[0]
 
         state = DdpgCriticState(
             critics=critic_states,
@@ -312,9 +402,13 @@ def _critic_train_step(self, inputs: TimeStep, state: DdpgCriticState,
     def _actor_train_step(self, inputs: TimeStep, state: DdpgActorState):
         action, actor_state = self._actor_network(
             inputs.observation, state=state.actor)
+        if self._use_batch_ensemble:
+            action = action[0]
 
         q_values, critic_states = self._critic_networks(
             (inputs.observation, action), state=state.critics)
+        if self._use_batch_ensemble:
+            q_values = q_values[0]
         if self.has_multidim_reward():
             # Multidimensional reward: [B, replicas, reward_dim]
             q_values = q_values * self.reward_weights
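The ``if self._use_batch_ensemble: x = x[0]`` pattern above repeats for every network call; one way to factor it out (a hypothetical helper, not part of this PR, related to the transparency concern in the review comments below) would be:

```python
def strip_ensemble_ids(output, use_batch_ensemble):
    # BatchEnsemble networks built with output_ensemble_ids=True return a
    # (value, ensemble_ids) tuple; plain networks return the value directly.
    return output[0] if use_batch_ensemble else output
```

Each call site would then read, e.g., ``q_values = strip_ensemble_ids(q_values, self._use_batch_ensemble)``.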
@@ -343,9 +437,22 @@ def actor_loss_fn(dqda, action):
 
     def train_step(self, inputs: TimeStep, state: DdpgState,
                    rollout_info: DdpgInfo):
+        self._train_step_count += 1
         critic_states, critic_info = self._critic_train_step(
             inputs=inputs, state=state.critics, rollout_info=rollout_info)
-        policy_step = self._actor_train_step(inputs=inputs, state=state.actor)
+        if self._train_step_count % self._actor_update_period == 0:
+            policy_step = self._actor_train_step(
+                inputs=inputs, state=state.actor)
+            critic_states = critic_states._replace(
+                critics=policy_step.state.critics)
+        else:
+            batch_dims = nest_utils.get_outer_rank(inputs.prev_action,
+                                                   self._action_spec)
+            loss = torch.zeros(*inputs.prev_action.shape[:batch_dims])
+            policy_step = AlgStep(
+                output=torch.zeros_like(inputs.prev_action),
+                state=state.actor,
+                info=LossInfo(loss=loss, extra=loss))
         return policy_step._replace(
             state=state._replace(
                 actor=policy_step.state, critics=critic_states),
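The delayed actor update above mirrors the TD3-style trick of updating the policy only once every ``actor_update_period`` critic updates, while emitting a zero, batch-shaped loss on skipped steps so downstream loss aggregation stays consistent. A minimal standalone sketch of the scheduling logic (hypothetical names, not this PR's code):

```python
import torch

class DelayedActor:
    # Sketch: run the real actor update once every `period` train steps and
    # emit a zero, batch-shaped loss on the skipped steps.
    def __init__(self, period=2):
        self._period = period
        self._count = 0

    def train_step(self, batch_size=4):
        self._count += 1
        if self._count % self._period == 0:
            return torch.ones(batch_size)  # stands in for the dqda actor loss
        return torch.zeros(batch_size)     # skipped step: zero loss

agent = DelayedActor(period=2)
losses = [agent.train_step() for _ in range(4)]  # skip, update, skip, update
```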
Review comment: Ideally, we might want to make these batch-ensemble related parameters transparent to the ddpg_algorithm? Basically, the ddpg_algorithm should not use batch-ensemble related parameters in the ideal case.

Reply: That's a good point. Currently DDPG needs use_batch_ensemble to do some post-processing when forwarding the critic networks during training. Let me think over whether there might be some alternative method to work around this.