Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lower bounded target q #1330

Open
wants to merge 8 commits into
base: pytorch
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 37 additions & 18 deletions alf/algorithms/data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,14 +736,25 @@ class HindsightExperienceTransformer(DataTransformer):
of the current timestep.
The exact field names can be provided via arguments to the class ``__init__``.

NOTE: The HindsightExperienceTransformer has to happen before any transformer which changes
reward or achieved_goal fields, e.g. observation normalizer, reward clipper, etc..
See `documentation <../../docs/notes/knowledge_base.rst#datatransformers>`_ for details.

To use this class, add it to any existing data transformers, e.g. use this config if
``ObservationNormalizer`` is an existing data transformer:

.. code-block:: python

ReplayBuffer.keep_episodic_info=True
HindsightExperienceTransformer.her_proportion=0.8
TrainerConfig.data_transformer_ctor=[@HindsightExperienceTransformer, @ObservationNormalizer]
alf.config('ReplayBuffer', keep_episodic_info=True)
alf.config(
'HindsightExperienceTransformer',
her_proportion=0.8
)
alf.config(
'TrainerConfig',
data_transformer_ctor=[
HindsightExperienceTransformer, ObservationNormalizer
])

See unit test for more details on behavior.
"""
Expand Down Expand Up @@ -820,9 +831,10 @@ def transform_experience(self, experience: Experience):
# relabel only these sampled indices
her_cond = torch.rand(batch_size) < her_proportion
(her_indices, ) = torch.where(her_cond)
has_her = torch.any(her_cond)

last_step_pos = start_pos[her_indices] + batch_length - 1
last_env_ids = env_ids[her_indices]
last_step_pos = start_pos + batch_length - 1
last_env_ids = env_ids
# Get x, y indices of LAST steps
dist = buffer.steps_to_episode_end(last_step_pos, last_env_ids)
if alf.summary.should_record_summaries():
Expand All @@ -831,22 +843,24 @@ def transform_experience(self, experience: Experience):
torch.mean(dist.type(torch.float32)))

# get random future state
future_idx = last_step_pos + (torch.rand(*dist.shape) *
(dist + 1)).to(torch.int64)
future_dist = (torch.rand(*dist.shape) * (dist + 1)).to(
torch.int64)
future_idx = last_step_pos + future_dist
future_ag = buffer.get_field(self._achieved_goal_field,
last_env_ids, future_idx).unsqueeze(1)

# relabel desired goal
result_desired_goal = alf.nest.get_field(result,
self._desired_goal_field)
relabed_goal = result_desired_goal.clone()
relabeled_goal = result_desired_goal.clone()
her_batch_index_tuple = (her_indices.unsqueeze(1),
torch.arange(batch_length).unsqueeze(0))
relabed_goal[her_batch_index_tuple] = future_ag
if has_her:
relabeled_goal[her_batch_index_tuple] = future_ag[her_indices]

# recompute rewards
result_ag = alf.nest.get_field(result, self._achieved_goal_field)
relabeled_rewards = self._reward_fn(result_ag, relabed_goal)
relabeled_rewards = self._reward_fn(result_ag, relabeled_goal)

non_her_or_fst = ~her_cond.unsqueeze(1) & (result.step_type !=
StepType.FIRST)
Expand Down Expand Up @@ -876,21 +890,26 @@ def transform_experience(self, experience: Experience):
alf.summary.scalar(
"replayer/" + buffer._name + ".reward_mean_before_relabel",
torch.mean(result.reward[her_indices][:-1]))
alf.summary.scalar(
"replayer/" + buffer._name + ".reward_mean_after_relabel",
torch.mean(relabeled_rewards[her_indices][:-1]))
if has_her:
alf.summary.scalar(
"replayer/" + buffer._name + ".reward_mean_after_relabel",
torch.mean(relabeled_rewards[her_indices][:-1]))
alf.summary.scalar("replayer/" + buffer._name + ".future_distance",
torch.mean(future_dist.float()))

result = alf.nest.transform_nest(
result, self._desired_goal_field, lambda _: relabed_goal)

result, self._desired_goal_field, lambda _: relabeled_goal)
result = result.update_time_step_field('reward', relabeled_rewards)

derived = {"is_her": her_cond, "future_distance": future_dist}
if alf.get_default_device() != buffer.device:
for f in accessed_fields:
result = alf.nest.transform_nest(
result, f, lambda t: convert_device(t))
result = alf.nest.transform_nest(
result, "batch_info.replay_buffer", lambda _: buffer)
info = convert_device(info)
derived = convert_device(derived)
info = info._replace(replay_buffer=buffer)
info = info.set_derived(derived)
result = alf.data_structures.add_batch_info(result, info)
return result


Expand Down
10 changes: 8 additions & 2 deletions alf/algorithms/ddpg_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,14 @@
DdpgState = namedtuple("DdpgState", ['actor', 'critics'])
DdpgInfo = namedtuple(
"DdpgInfo", [
"reward", "step_type", "discount", "action", "action_distribution",
"actor_loss", "critic", "discounted_return"
"reward",
"step_type",
"discount",
"action",
"action_distribution",
"actor_loss",
"critic",
"discounted_return",
],
default_value=())
DdpgLossInfo = namedtuple('DdpgLossInfo', ('actor', 'critic'))
Expand Down
160 changes: 160 additions & 0 deletions alf/algorithms/her_algorithms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright (c) 2022 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HER Algorithms (Wrappers)."""
"""Classes defined here are used to transfer relevant info about the
sampled/replayed experience from HindsightDataTransformer all the way to
algorithm.calc_loss and the loss class.

Actual hindsight relabeling happens in HindsightDataTransformer.

For usage, see alf/examples/her_fetchpush_conf.py.
"""

import alf
from alf.algorithms.sac_algorithm import SacAlgorithm, SacInfo
from alf.algorithms.ddpg_algorithm import DdpgAlgorithm, DdpgInfo
from alf.data_structures import TimeStep
from alf.utils import common


def her_wrapper(alg_cls, alg_info):
"""A helper function to construct HerAlgo based on the base (off-policy) algorithm.

We mainly do two things here:
1. Create the new HerInfo namedtuple, containing a ``derived`` field together
with the existing fields of AlgInfo. The ``derived`` field is a dict, to be
populated with information derived from the Hindsight relabeling process.
This HerInfo structure stores training information collected from replay and
processed by the algorithm's train_step.

2. Create a new HerAlgo child class of the input base algorithm.
The new class additionally handles passing derived fields along the pipeline
for the loss function (e.g. LowerboundedTDLoss) to access.
"""
HerClsName = "Her" + alg_cls.__name__
# HerAlgo class inherits the base RL algorithm class
HerCls = type(HerClsName, (alg_cls, ), {})
HerCls.counter = 0

HerInfoName = "Her" + alg_info.__name__
# Unfortunately, the user has to ensure that the default_value of HerAlgInfo has to be
# exactly the same as the AlgInfo, otherwise there could be bugs.
HerInfoCls = alf.data_structures.namedtuple(
HerInfoName, alg_info._fields + ("derived", ), default_value=())
alg_info.__name__ = HerInfoName

# NOTE: replay_buffer.py has similar functions for handling BatchInfo namedtuple.

# New __new__ for AlgInfo, so every time AlgInfo is called to create an instance,
# an HerAlgInfo instance (with the additional ``derived`` dict) is created and
# returned instead. This allows us to wrap an algorithm's AlgInfo class without
# changing any code in the original AlgInfo class, keeping HER code separate.
@common.add_method(alg_info)
def __new__(info_cls, **kwargs):
assert info_cls == alg_info
her_info = HerInfoCls(**kwargs)
# Set default value, later code will check for this
her_info = her_info._replace(derived={})
return her_info

# New accessor methods for HerAlgInfo to access the ``derived`` dict.
@common.add_method(HerInfoCls)
def get_derived_field(self, field):
assert field in self.derived, f"field {field} not in BatchInfo.derived"
return self.derived[field]

@common.add_method(HerInfoCls)
def get_derived(self):
return self.derived

@common.add_method(HerInfoCls)
def set_derived(self, new_dict):
assert self.derived == {}
return self._replace(derived=new_dict)

# New methods for HerAlg
@common.add_method(HerCls)
def __init__(self, **kwargs):
"""
Args:
kwargs: arguments passed to the constructor of the underlying algorithm.
"""
assert HerCls.counter == 0, f"HerCls {HerCls} already defined"
super(HerCls, self).__init__(**kwargs)
HerCls.counter += 1

@common.add_method(HerCls)
def preprocess_experience(self, inputs: TimeStep, rollout_info: alg_info,
batch_info):
"""Pass derived fields from batch_info into rollout_info"""
time_step, rollout_info = super(HerCls, self).preprocess_experience(
inputs, rollout_info, batch_info)
if hasattr(rollout_info, "derived") and batch_info.derived:
# Expand to the proper dimensions consistent with other experience fields
derived = alf.nest.map_structure(
lambda x: x.unsqueeze(1).expand(time_step.reward.shape[:2]),
batch_info.get_derived())
rollout_info = rollout_info.set_derived(derived)
return time_step, rollout_info

@common.add_method(HerCls)
def train_step(self, inputs: TimeStep, state, rollout_info: alg_info):
"""Pass derived fields from rollout_info into alg_step.info"""
alg_step = super(HerCls, self).train_step(inputs, state, rollout_info)
return alg_step._replace(
info=alg_step.info.set_derived(rollout_info.get_derived()))

return HerCls # End of her_wrapper function


# Create the actual wrapped HerAlgorithms
HerSacAlgorithm = her_wrapper(SacAlgorithm, SacInfo)
HerDdpgAlgorithm = her_wrapper(DdpgAlgorithm, DdpgInfo)
"""To help understand what's going on, here is the detailed data flow:

1. Replayer samples the experience with batch_info from replay_buffer.

2. HindsightDataTransformer samples and relabels the experience, stores the derived info containing
her: whether the experience has been relabeled, future_distance: the number of time steps to
the future achieved goal used to relabel the experience.
HindsightDataTransformer finally returns experience with experience.batch_info.derived
containing the derived information.

(NOTE: we cannot put HindsightDataTransformer into HerAlgo.preprocess_experience, as preprocessing
happens after data_transformations, but Hindsight relabeling has to happen before other data
transformations like observation normalization, because hindsight accesses replay_buffer data directly,
which has not gone through the data transformers.
Maybe we could invoke HindsightDataTransformer automatically, e.g. by preprending it to
``TrainConfig.data_transformer_ctr`` in this file. Maybe that's too magical, and should be avoided.)

3. HerAlgo.preprocess_experience copies ``batch_info.derived`` over to ``rollout_info.derived``.
NOTE: We cannot copy from exp to rollout_info because the input to preprocess_experience is time_step,
not exp in algorithm.py:

.. code-block:: python

time_step, rollout_info = self.preprocess_experience(
experience.time_step, experience.rollout_info, batch_info)

4. HerAlgo.train_step copies ``exp.rollout_info.derived`` over to ``policy_step.info.derived``.
NOTE: we cannot just copy derived from exp into AlgInfo in train_step, because train_step accepts
time_step instead of exp as input:

.. code-block:: python

policy_step = self.train_step(exp.time_step, policy_state,
exp.rollout_info)

5. BaseAlgo.calc_loss will call LowerboundedTDLoss with HerBaseAlgoInfo.
"""
46 changes: 46 additions & 0 deletions alf/algorithms/her_algorithms_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) 2022 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import parameterized

import alf
from alf.algorithms.her_algorithms import HerSacAlgorithm, HerDdpgAlgorithm
from alf.algorithms.sac_algorithm import SacAlgorithm, SacInfo
from alf.algorithms.ddpg_algorithm import DdpgAlgorithm, DdpgInfo


class HerAlgorithmsTest(parameterized.TestCase, alf.test.TestCase):
def test_her_algo_name(self):
self.assertEqual("HerSacAlgorithm", HerSacAlgorithm.__name__)
self.assertEqual("HerDdpgAlgorithm", HerDdpgAlgorithm.__name__)

@parameterized.parameters([
(SacInfo, ),
(DdpgInfo, ),
])
def test_her_info(self, Info):
info = Info(reward=1)
self.assertEqual(1, info.reward)
# HerAlgInfo assumes default field value to be (), need to be consistent with AlgInfo
self.assertEqual((), info.action)
self.assertEqual({}, info.get_derived())
ret = info.set_derived({"a": 1, "b": 2})
# info is immutable
self.assertEqual({}, info.get_derived())
# ret is the new instance with field "derived" replaced
self.assertEqual(1, ret.get_derived_field("a"))
self.assertEqual(2, ret.get_derived_field("b"))
# get nonexistent field with and without default
self.assertEqual("none", ret.get_derived_field("x", default="none"))
self.assertRaises(AssertionError, ret.get_derived_field, "x")
12 changes: 10 additions & 2 deletions alf/algorithms/sac_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,16 @@

SacInfo = namedtuple(
"SacInfo", [
"reward", "step_type", "discount", "action", "action_distribution",
"actor", "critic", "alpha", "log_pi", "discounted_return"
"reward",
"step_type",
"discount",
"action",
"action_distribution",
"actor",
"critic",
"alpha",
"log_pi",
"discounted_return",
],
default_value=())

Expand Down
Loading