introduce her_algorithms wrapper to replace ad hoc changes to base (off-policy) algorithms
Le Horizon committed Jun 30, 2022
1 parent b9c4143 commit 6d5b225
Showing 8 changed files with 290 additions and 41 deletions.
10 changes: 4 additions & 6 deletions alf/algorithms/data_transformer.py
@@ -900,17 +900,15 @@ def transform_experience(self, experience: Experience):
result = alf.nest.transform_nest(
result, self._desired_goal_field, lambda _: relabeled_goal)
result = result.update_time_step_field('reward', relabeled_rewards)
info = info._replace(her=her_cond, future_distance=future_dist)
derived = {"her": her_cond, "future_distance": future_dist}
if alf.get_default_device() != buffer.device:
for f in accessed_fields:
result = alf.nest.transform_nest(
result, f, lambda t: convert_device(t))
info = convert_device(info)
info = info._replace(
her=info.her.unsqueeze(1).expand(result.reward.shape[:2]),
future_distance=info.future_distance.unsqueeze(1).expand(
result.reward.shape[:2]),
replay_buffer=buffer)
derived = convert_device(derived)
info = info._replace(replay_buffer=buffer)
info = info.set_derived(derived)
result = alf.data_structures.add_batch_info(result, info)
return result

7 changes: 1 addition & 6 deletions alf/algorithms/ddpg_algorithm.py
@@ -40,20 +40,15 @@
DdpgActorState = namedtuple("DdpgActorState", ['actor', 'critics'])
DdpgState = namedtuple("DdpgState", ['actor', 'critics'])
DdpgInfo = namedtuple(
"DdpgInfo",
[
"DdpgInfo", [
"reward",
"step_type",
"discount",
"action",
"action_distribution",
"actor_loss",
"critic",
# Optional fields for value target lower bounding or Hindsight relabeling.
# TODO: Extract these into a HerAlgorithm wrapper for easier adoption of HER.
"discounted_return",
"future_distance",
"her"
],
default_value=())
DdpgLossInfo = namedtuple('DdpgLossInfo', ('actor', 'critic'))
164 changes: 164 additions & 0 deletions alf/algorithms/her_algorithms.py
@@ -0,0 +1,164 @@
# Copyright (c) 2022 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HER Algorithms (Wrappers)."""
"""Classes defined here are used to transfer relevant info about the
sampled/replayed experience from HindsightDataTransformer all the way to
algorithm.calc_loss and the loss class.
Actual hindsight relabeling happens in HindsightDataTransformer.
For usage, see alf/examples/her_fetchpush_conf.py.
"""

import alf
from alf.algorithms.sac_algorithm import SacAlgorithm, SacInfo
from alf.algorithms.ddpg_algorithm import DdpgAlgorithm, DdpgInfo
from alf.data_structures import TimeStep
from alf.utils import common


def her_wrapper(alg_cls, alg_info):
"""A helper function to construct HerAlgo based on the base (off-policy) algorithm.
We mainly do two things here:
1. Create the new HerInfo namedtuple, containing a ``derived`` field together
with the existing fields of AlgInfo. The ``derived`` field is a dict, to be
populated with information derived from the Hindsight relabeling process.
This HerInfo structure stores training information collected from replay and
processed by the algorithm's train_step.
2. Create a new HerAlgo child class of the input base algorithm.
The new class additionally handles passing derived fields along the pipeline
for the loss function (e.g. LowerboundedTDLoss) to access.
"""
HerClsName = "Her" + alg_cls.__name__
# HerAlgo class inherits the base RL algorithm class
HerCls = type(HerClsName, (alg_cls, ), {})
HerCls.counter = 0

HerInfoName = "Her" + alg_info.__name__
# Unfortunately, the user has to ensure that the default_value of HerAlgInfo is
# exactly the same as that of AlgInfo; otherwise there could be bugs.
HerInfoCls = alf.data_structures.namedtuple(
HerInfoName, alg_info._fields + ("derived", ), default_value=())
alg_info.__name__ = HerInfoName

# NOTE: replay_buffer.py has similar functions for handling BatchInfo namedtuple.

# New __new__ for AlgInfo, so every time AlgInfo is called to create an instance,
# an HerAlgInfo instance (with the additional ``derived`` dict) is created and
# returned instead. This allows us to wrap an algorithm's AlgInfo class without
# changing any code in the original AlgInfo class, keeping HER code separate.
@common.add_method(alg_info)
def __new__(info_cls, **kwargs):
assert info_cls == alg_info
her_info = HerInfoCls(**kwargs)
# Set the default value; later code will check for this
her_info = her_info._replace(derived={})
return her_info

# New accessor methods for HerAlgInfo to access the ``derived`` dict.
@common.add_method(HerInfoCls)
def get_derived_field(self, field, default=None):
if default is None:
assert field in self.derived, f"field {field} not in HerAlgInfo.derived"
if field in self.derived:
return self.derived[field]
else:
return default

@common.add_method(HerInfoCls)
def get_derived(self):
return self.derived

@common.add_method(HerInfoCls)
def set_derived(self, new_dict):
assert self.derived == {}
return self._replace(derived=new_dict)

# New methods for HerAlg
@common.add_method(HerCls)
def __init__(self, **kwargs):
"""
Args:
kwargs: arguments passed to the constructor of the underlying algorithm.
"""
assert HerCls.counter == 0, f"HerCls {HerCls} already instantiated"
super(HerCls, self).__init__(**kwargs)
HerCls.counter += 1

@common.add_method(HerCls)
def preprocess_experience(self, inputs: TimeStep, rollout_info: alg_info,
batch_info):
"""Pass derived fields from batch_info into rollout_info"""
time_step, rollout_info = super(HerCls, self).preprocess_experience(
inputs, rollout_info, batch_info)
if hasattr(rollout_info, "derived") and batch_info.derived:
# Expand to the proper dimensions consistent with other experience fields
derived = alf.nest.map_structure(
lambda x: x.unsqueeze(1).expand(time_step.reward.shape[:2]),
batch_info.get_derived())
rollout_info = rollout_info.set_derived(derived)
return time_step, rollout_info

@common.add_method(HerCls)
def train_step(self, inputs: TimeStep, state, rollout_info: alg_info):
"""Pass derived fields from rollout_info into alg_step.info"""
alg_step = super(HerCls, self).train_step(inputs, state, rollout_info)
return alg_step._replace(
info=alg_step.info.set_derived(rollout_info.get_derived()))

return HerCls # End of her_wrapper function


# Create the actual wrapped HerAlgorithms
HerSacAlgorithm = her_wrapper(SacAlgorithm, SacInfo)
HerDdpgAlgorithm = her_wrapper(DdpgAlgorithm, DdpgInfo)
"""To help understand what's going on, here is the detailed data flow:
1. Replayer samples the experience with batch_info from replay_buffer.
2. HindsightDataTransformer samples and relabels the experience, and stores the derived info:
``her`` (whether the experience has been relabeled) and ``future_distance`` (the number of
time steps to the future achieved goal used to relabel the experience).
HindsightDataTransformer finally returns experience with experience.batch_info.derived
containing the derived information.
(NOTE: we cannot put HindsightDataTransformer into HerAlgo.preprocess_experience, because
preprocessing happens after the data transformations, while hindsight relabeling has to happen
before other data transformations such as observation normalization: hindsight accesses
replay_buffer data directly, which has not yet gone through the data transformers.
Maybe we could invoke HindsightDataTransformer automatically, e.g. by prepending it to
``TrainConfig.data_transformer_ctr`` in this file, but that may be too magical and is probably
best avoided.)
3. HerAlgo.preprocess_experience copies ``batch_info.derived`` over to ``rollout_info.derived``.
NOTE: We cannot copy from exp to rollout_info because the input to preprocess_experience is time_step,
not exp in algorithm.py:
.. code-block:: python
time_step, rollout_info = self.preprocess_experience(
experience.time_step, experience.rollout_info, batch_info)
4. HerAlgo.train_step copies ``exp.rollout_info.derived`` over to ``policy_step.info.derived``.
NOTE: we cannot just copy derived from exp into AlgInfo in train_step, because train_step accepts
time_step instead of exp as input:
.. code-block:: python
policy_step = self.train_step(exp.time_step, policy_state,
exp.rollout_info)
5. BaseAlgo.calc_loss will call LowerboundedTDLoss with HerBaseAlgoInfo.
"""
46 changes: 46 additions & 0 deletions alf/algorithms/her_algorithms_test.py
@@ -0,0 +1,46 @@
# Copyright (c) 2022 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import parameterized

import alf
from alf.algorithms.her_algorithms import HerSacAlgorithm, HerDdpgAlgorithm
from alf.algorithms.sac_algorithm import SacAlgorithm, SacInfo
from alf.algorithms.ddpg_algorithm import DdpgAlgorithm, DdpgInfo


class HerAlgorithmsTest(parameterized.TestCase, alf.test.TestCase):
def test_her_algo_name(self):
self.assertEqual("HerSacAlgorithm", HerSacAlgorithm.__name__)
self.assertEqual("HerDdpgAlgorithm", HerDdpgAlgorithm.__name__)

@parameterized.parameters([
(SacInfo, ),
(DdpgInfo, ),
])
def test_her_info(self, Info):
info = Info(reward=1)
self.assertEqual(1, info.reward)
# HerAlgInfo assumes the default field value to be (); this needs to be consistent with AlgInfo
self.assertEqual((), info.action)
self.assertEqual({}, info.get_derived())
ret = info.set_derived({"a": 1, "b": 2})
# info is immutable
self.assertEqual({}, info.get_derived())
# ret is the new instance with field "derived" replaced
self.assertEqual(1, ret.get_derived_field("a"))
self.assertEqual(2, ret.get_derived_field("b"))
# get nonexistent field with and without default
self.assertEqual("none", ret.get_derived_field("x", default="none"))
self.assertRaises(AssertionError, ret.get_derived_field, "x")
7 changes: 1 addition & 6 deletions alf/algorithms/sac_algorithm.py
@@ -54,8 +54,7 @@
"SacActorInfo", ["actor_loss", "neg_entropy"], default_value=())

SacInfo = namedtuple(
"SacInfo",
[
"SacInfo", [
"reward",
"step_type",
"discount",
@@ -65,11 +64,7 @@
"critic",
"alpha",
"log_pi",
# Optional fields for value target lower bounding or Hindsight relabeling.
# TODO: Extract these into a HerAlgorithm wrapper for easier adoption of HER.
"discounted_return",
"future_distance",
"her"
],
default_value=())

12 changes: 6 additions & 6 deletions alf/algorithms/td_loss.py
@@ -117,7 +117,7 @@ def compute_td_target(self,
dimension is the batch dimension.
Args:
info (namedtuple): experience collected from ``unroll()`` or
info (namedtuple): AlgInfo collected from ``unroll()`` or
a replay buffer. All tensors are time-major. ``info`` should
contain the following fields:
- reward:
@@ -301,7 +301,7 @@ def compute_td_target(self,
dimension is the batch dimension.
Args:
info (namedtuple): experience collected from ``unroll()`` or
info (namedtuple): AlgInfo collected from ``unroll()`` or
a replay buffer. All tensors are time-major. ``info`` should
contain the following fields:
- reward:
@@ -370,12 +370,12 @@ def compute_td_target(self,
torch.mean(value[:-1][:, episode_ended[0, :]]))

if self._lb_target_q > 0 and disc_ret != ():
her_cond = info.her
her_cond = info.get_derived_field("her", default=())
mask = torch.ones(returns.shape, dtype=torch.bool)
if her_cond != () and torch.any(~her_cond):
mask = ~her_cond[:-1]
disc_ret = disc_ret[
1:] # it's expanded in ddpg_algorithm, need to revert back.
1:] # it's expanded in Agent.preprocess_experience, need to revert back.
assert returns.shape == disc_ret.shape, "%s %s" % (returns.shape,
disc_ret.shape)
with alf.summary.scope(self._name):
@@ -391,9 +391,9 @@ def compute_td_target(self,

if self._improve_w_goal_return:
batch_length, batch_size = returns.shape[:2]
her_cond = info.her
her_cond = info.get_derived_field("her")
if her_cond != () and torch.any(her_cond):
dist = info.future_distance
dist = info.get_derived_field("future_distance")
if self._positive_reward:
goal_return = torch.pow(
self._gamma * torch.ones(her_cond.shape), dist)
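
For the positive-reward branch shown above, the goal-based return used for lower bounding is
simply the discount raised to the relabeled goal's distance, i.e. goal_return = gamma ** dist,
which corresponds (up to the discounting convention) to a sparse +1 reward received when the
relabeled goal is reached ``dist`` steps later. A tiny numeric sketch with made-up example values:

# With gamma = 0.98 and a relabeled goal 5 steps ahead, the goal-based return is
# 0.98 ** 5 ~= 0.904, which can then lower-bound the learned value target.
import torch
gamma, dist = 0.98, torch.tensor([5.0])
goal_return = torch.pow(gamma * torch.ones_like(dist), dist)  # tensor([0.9039...])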
3 changes: 3 additions & 0 deletions alf/examples/her_fetchpush_conf.py
@@ -16,6 +16,7 @@
from alf.algorithms.data_transformer import HindsightExperienceTransformer, \
ObservationNormalizer
from alf.algorithms.ddpg_algorithm import DdpgAlgorithm
from alf.algorithms.her_algorithms import HerDdpgAlgorithm
from alf.environments import suite_robotics
from alf.nest.utils import NestConcat

@@ -38,6 +39,8 @@

alf.config('DdpgAlgorithm', action_l2=0.05)

alf.config('Agent', rl_algorithm_cls=HerDdpgAlgorithm)

# Finer grain tensorboard summaries plus local action distribution
# TrainerConfig.summarize_action_distributions=True
# TrainerConfig.summary_interval=1
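
For illustration only (not part of the commit): a minimal sketch of how the pieces above fit
together in a HER config. The Agent line mirrors this conf file; the data-transformer line is an
assumption based on the ordering note in her_algorithms.py, and the exact TrainerConfig field
name may differ.

import alf
from alf.algorithms.data_transformer import HindsightExperienceTransformer, \
    ObservationNormalizer
from alf.algorithms.her_algorithms import HerDdpgAlgorithm

# Use the HER-wrapped algorithm as the RL algorithm inside Agent.
alf.config('Agent', rl_algorithm_cls=HerDdpgAlgorithm)

# Hindsight relabeling must run before other data transformers such as
# observation normalization, because it reads replay_buffer data directly.
# NOTE: 'data_transformer_ctor' is an assumed field name; check TrainerConfig.
alf.config(
    'TrainerConfig',
    data_transformer_ctor=[HindsightExperienceTransformer, ObservationNormalizer])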