Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch ensemble ddpg #1633

Open
wants to merge 5 commits into
base: pytorch
Choose a base branch
from
Open

Batch ensemble ddpg #1633

wants to merge 5 commits into from

Conversation

runjerry
Copy link
Contributor

Update actor_network, critic_network, and ddpg_algorithm to work with batch_ensemble layers.

@runjerry runjerry marked this pull request as ready for review March 29, 2024 22:02
@@ -139,6 +148,17 @@ def __init__(self,
gradient dqda element-wise between ``[-dqda_clipping, dqda_clipping]``.
Does not perform clipping if ``dqda_clipping == 0``.
action_l2 (float): weight of squared action l2-norm on actor loss.
use_batch_ensemble (bool): whether to use BatchEnsemble FC and Conv2D
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, we might should make these batch ensemble related parameters transparent to the ddpg_algorithm? Basically, the ddpg_algorithm should not use batch_ensemble related parameters in the ideal case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point. Currently ddpg needs the use_batch_ensemble to do some post processing when forwarding critic networks during training. Let me think it over if there might be some alternative methods to work around.

@@ -281,14 +318,39 @@ def _update_random_action(spec, noisy_action):
if self._rollout_random_action > 0:
nest.map_structure(_update_random_action, self._action_spec,
pred_step.output)
return pred_step

if self.need_full_rollout_state():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want the algorithm use the same ensemble_id during an entire episode. This means that it should store ensembled_id in state and use the same ensemble_id to call actor_network

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes, good point, I think that is the reason why I had to tweak the ddpg_algorithm_test to pass the toy unittest. Updated.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants