CQL algorithm
Haichao-Zhang committed Nov 21, 2022
1 parent 0f8d0ec commit 4eab0a1
Showing 4 changed files with 481 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -35,6 +35,7 @@ Read the ALF documentation [here](https://alf.readthedocs.io/).
|[MuZero](alf/algorithms/muzero_algorithm.py)|Model-based RL|Schrittwieser et al. "Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model" [arXiv:1911.08265](https://arxiv.org/abs/1911.08265)|
|[BC](alf/algorithms/bc_algorithm.py)|Offline RL|Pomerleau "ALVINN: An Autonomous Land Vehicle in a Neural Network" [NeurIPS 1988](https://papers.nips.cc/paper/1988/hash/812b4ba287f5ee0bc9d43bbf5bbe87fb-Abstract.html) <br> Bain et al. "A framework for behavioural cloning" [Machine Intelligence 1999](http://www.cse.unsw.edu.au/~claude/papers/MI15.pdf)|
|[Causal BC](alf/algorithms/causal_bc_algorithm.py)|Offline RL|Swamy et al. "Causal Imitation Learning under Temporally Correlated Noise" [ICML2022](https://proceedings.mlr.press/v162/swamy22a/swamy22a.pdf)|
|[CQL](alf/algorithms/cql_algorithm.py)|Offline RL|Kumar, et al. "Conservative Q-Learning for Offline Reinforcement Learning" [arXiv:2006.04779](https://arxiv.org/abs/2006.04779)|
|[IQL](alf/algorithms/iql_algorithm.py)|Offline RL|Kostrikov, et al. "Offline Reinforcement Learning with Implicit Q-Learning" [arXiv:2110.06169](https://arxiv.org/abs/2110.06169)|
|[MERLIN](alf/algorithms/merlin_algorithm.py)|Unsupervised learning|Wayne et al. "Unsupervised Predictive Memory in a Goal-Directed Agent"[arXiv:1803.10760](https://arxiv.org/abs/1803.10760)|
|[MoNet](alf/algorithms/monet_algorithm.py)|Unsupervised learning|Burgess et al. "MONet: Unsupervised Scene Decomposition and Representation" [arXiv:1901.11390](https://arxiv.org/abs/1901.11390)|
375 changes: 375 additions & 0 deletions alf/algorithms/cql_algorithm.py
@@ -0,0 +1,375 @@
# Copyright (c) 2022 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Conservative Q-Learning Algorithm."""

import math
import torch
import torch.nn as nn

import alf
from alf.algorithms.config import TrainerConfig
from alf.algorithms.sac_algorithm import (SacAlgorithm, SacActionState,
SacCriticState, SacInfo)
from alf.data_structures import TimeStep, LossInfo, namedtuple
from alf.nest import nest
from alf.networks import ActorDistributionNetwork, CriticNetwork
from alf.networks import QNetwork
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from alf.utils import math_ops

CqlCriticInfo = namedtuple(
"CqlCriticInfo",
["critics", "target_critic", "min_q_loss", "alpha_prime_loss"])

CqlLossInfo = namedtuple(
'CqlLossInfo', ('actor', 'critic', 'alpha', "cql", "alpha_prime_loss"))


@alf.configurable
class CqlAlgorithm(SacAlgorithm):
r"""Cql algorithm, described in:
::
Kumar et al. "Conservative Q-Learning for Offline Reinforcement Learning",
arXiv:2006.04779
The main idea is to learn a Q-function with an additional regularizer that
penalizes the Q-values for out-of-distribution actions. It can be shown that
the expected value of a policy under this Q-function lower-bounds its true
value.
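For the ``H`` variant (cf. Eqn. (4) of the paper), the regularizer roughly
takes the form

.. math::

    \mathbb{E}_{s \sim \mathcal{D}}\Big[\log\sum_{a}\exp\big(Q(s, a)\big)
        - \mathbb{E}_{a \sim \mathcal{D}}\big[Q(s, a)\big]\Big],

where the log-sum-exp term is estimated below with actions sampled uniformly
and from the current policy.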
"""

def __init__(self,
observation_spec,
action_spec: BoundedTensorSpec,
reward_spec=TensorSpec(()),
actor_network_cls=ActorDistributionNetwork,
critic_network_cls=CriticNetwork,
q_network_cls=QNetwork,
reward_weights=None,
epsilon_greedy=None,
use_entropy_reward=False,
normalize_entropy_reward=False,
calculate_priority=False,
num_critic_replicas=2,
env=None,
config: TrainerConfig = None,
critic_loss_ctor=None,
target_entropy=None,
prior_actor_ctor=None,
target_kld_per_dim=3.,
initial_log_alpha=0.0,
max_log_alpha=None,
target_update_tau=0.05,
target_update_period=1,
dqda_clipping=None,
actor_optimizer=None,
critic_optimizer=None,
alpha_optimizer=None,
alpha_prime_optimizer=None,
debug_summaries=False,
cql_type="H",
cql_action_replica=10,
cql_temperature=1.0,
cql_regularization_weight=1.0,
cql_target_value_gap=-1,
initial_log_alpha_prime=0,
name="CqlAlgorithm"):
"""
Args:
observation_spec (nested TensorSpec): representing the observations.
action_spec (nested BoundedTensorSpec): representing the actions; can
be a mixture of discrete and continuous actions. The number of
continuous actions can be arbitrary while only one discrete
action is allowed currently. If it's a mixture, then it must be
a tuple/list ``(discrete_action_spec, continuous_action_spec)``.
reward_spec (TensorSpec): a rank-1 or rank-0 tensor spec representing
the reward(s).
actor_network_cls (Callable): is used to construct the actor network.
The constructed actor network will be called
to sample continuous actions. All of its output specs must be
continuous. Note that we don't need a discrete actor network
because a discrete action can simply be sampled from the Q values.
critic_network_cls (Callable): is used to construct the critic network
for estimating ``Q(s,a)`` given that the action is continuous.
q_network_cls (Callable): is used to construct the QNetwork for estimating
``Q(s,a)`` given that the action is discrete. Its output spec must be
consistent with the discrete action in ``action_spec``.
reward_weights (None|list[float]): this is only used when the reward is
multidimensional. In that case, the weighted sum of the q values
is used for training the actor if reward_weights is not None.
Otherwise, the sum of the q values is used.
epsilon_greedy (float): a floating value in [0,1], representing the
chance of action sampling instead of taking argmax. This can
help prevent a dead loop in some deterministic environments like
Breakout. Only used for evaluation. If None, its value is taken
from ``config.epsilon_greedy`` and then
``alf.get_config_value(TrainerConfig.epsilon_greedy)``.
use_entropy_reward (bool): whether to include entropy as reward
normalize_entropy_reward (bool): if True, normalize entropy reward
to reduce bias in episodic cases. Only used if
``use_entropy_reward==True``.
calculate_priority (bool): whether to calculate priority. This is
only useful if priority replay is enabled.
num_critic_replicas (int): number of critics to be used. Default is 2.
env (Environment): The environment to interact with. ``env`` is a
batched environment, which means that it runs multiple simulations
simultaneously. ``env`` only needs to be provided to the root
algorithm.
config (TrainerConfig): config for training. It only needs to be
provided to the algorithm which performs ``train_iter()`` by
itself.
critic_loss_ctor (None|OneStepTDLoss|MultiStepLoss): a critic loss
constructor. If ``None``, a default ``OneStepTDLoss`` will be used.
initial_log_alpha (float): initial value for variable ``log_alpha``.
max_log_alpha (float|None): if not None, ``log_alpha`` will be
capped at this value.
target_entropy (float|Callable|None): If a floating value, it's the
target average policy entropy, for updating ``alpha``. If a
callable function, then it will be called on the action spec to
calculate a target entropy. If ``None``, a default entropy will
be calculated. For the mixed action type, discrete action and
continuous action will have separate alphas and target entropies,
so this argument can be a 2-element list/tuple, where the first
is for discrete action and the second for continuous action.
prior_actor_ctor (Callable): If provided, it will be called using
``prior_actor_ctor(observation_spec, action_spec, debug_summaries=debug_summaries)``
to construct a prior actor. The output of the prior actor is
the distribution of the next action. Two prior actors are implemented:
``alf.algorithms.prior_actor.SameActionPriorActor`` and
``alf.algorithms.prior_actor.UniformPriorActor``.
target_kld_per_dim (float): ``alpha`` is dynamically adjusted so that
the KLD is about ``target_kld_per_dim * dim``.
target_update_tau (float): Factor for soft update of the target
networks.
target_update_period (int): Period for soft update of the target
networks.
dqda_clipping (float): when computing the actor loss, clips the
gradient dqda element-wise between
``[-dqda_clipping, dqda_clipping]``. Will not perform clipping if
``dqda_clipping == 0``.
actor_optimizer (torch.optim.optimizer): The optimizer for actor.
critic_optimizer (torch.optim.optimizer): The optimizer for critic.
alpha_optimizer (torch.optim.optimizer): The optimizer for alpha.
alpha_prime_optimizer (torch.optim.optimizer): The optimizer for
alpha_prime, which is the trade-off weight for the conservative
regularization term.
debug_summaries (bool): True if debug summaries should be created.
cql_type (str): the type of CQL formulation: ``H`` (Eqn. (4) in the CQL
paper) or ``rho``.
cql_action_replica (int): the number of actions to be generated for a
single observation.
cql_temperature (float): the temperature parameter for scaling Q
before applying the log-sum-exp operator.
cql_regularization_weight (float): the weight of the CQL regularization
term before it is added to the total loss.
cql_target_value_gap (float): the target value gap between the softmax
Q value and the average Q value. The ``alpha_prime`` parameter is
adjusted to match this target gap. A negative value corresponds
to not enabling auto-adjustment of ``alpha_prime``.
name (str): The name of this algorithm.
"""

super().__init__(
observation_spec=observation_spec,
action_spec=action_spec,
reward_spec=reward_spec,
actor_network_cls=actor_network_cls,
critic_network_cls=critic_network_cls,
q_network_cls=q_network_cls,
reward_weights=reward_weights,
epsilon_greedy=epsilon_greedy,
use_entropy_reward=use_entropy_reward,
normalize_entropy_reward=normalize_entropy_reward,
calculate_priority=calculate_priority,
num_critic_replicas=num_critic_replicas,
env=env,
config=config,
critic_loss_ctor=critic_loss_ctor,
target_entropy=target_entropy,
prior_actor_ctor=prior_actor_ctor,
target_kld_per_dim=target_kld_per_dim,
initial_log_alpha=initial_log_alpha,
max_log_alpha=max_log_alpha,
target_update_tau=target_update_tau,
target_update_period=target_update_period,
dqda_clipping=dqda_clipping,
actor_optimizer=actor_optimizer,
critic_optimizer=critic_optimizer,
alpha_optimizer=alpha_optimizer,
debug_summaries=debug_summaries,
name=name)

assert cql_type in {'H', 'rho'}, "unknown cql_type {}".format(cql_type)

self._cql_type = cql_type
self._cql_target_value_gap = cql_target_value_gap
self._cql_action_replica = cql_action_replica
self._cql_temperature = cql_temperature
self._cql_regularization_weight = cql_regularization_weight
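# ``mini_batch_length`` is needed to roll the policy actions along the
# time dimension when approximating a_next in ``_critic_train_step``.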
self._mini_batch_length = config.mini_batch_length

if self._cql_target_value_gap > 0:
self._log_alpha_prime = nn.Parameter(
torch.tensor(float(initial_log_alpha_prime)))
if alpha_prime_optimizer is not None:
self.add_optimizer(alpha_prime_optimizer,
[self._log_alpha_prime])

def _critic_train_step(self, inputs: TimeStep, state: SacCriticState,
rollout_info: SacInfo, action, action_distribution):

critic_state, critic_info = super()._critic_train_step(
inputs, state, rollout_info, action, action_distribution)
critics = critic_info.critics
target_critic = critic_info.target_critic

# ---- CQL specific regularizations ------
# repeat observation and action
# [B, d] -> [B, replica, d] -> [B * replica, d]
B = inputs.observation.shape[0]
rep_obs = inputs.observation.unsqueeze(1).repeat(
1, self._cql_action_replica, 1).view(B * self._cql_action_replica,
inputs.observation.shape[1])

# get random actions
random_actions = self._action_spec.sample(
(B * self._cql_action_replica, ))
critics_random_actions, critics_state = self._compute_critics(
self._critic_networks,
rep_obs,
random_actions,
state.critics,
replica_min=False,
apply_reward_weights=False)

if self._cql_type == "H":
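# For the ``H`` variant, the log-sum-exp over actions is estimated via
# importance sampling, so the log-density of the proposal is subtracted
# from each Q value. Assuming actions normalized to [-1, 1] per dimension,
# the uniform proposal has density 0.5 per dimension, i.e. log(0.5 ** d).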
random_log_probs = math.log(0.5**random_actions.shape[-1])
critics_random_actions = critics_random_actions - random_log_probs

# [B, action_replica, critic_replica]
critics_random_actions = critics_random_actions.reshape(
B, self._cql_action_replica, -1)
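# Next, sample actions from the current policy for the same repeated
# observations; their (log-prob corrected) Q values provide the second
# set of samples for the log-sum-exp estimate.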

current_action_distribution, current_actions, _, _ = self._predict_action(
rep_obs, SacActionState())

current_log_pi = nest.map_structure(lambda dist, a: dist.log_prob(a),
current_action_distribution,
current_actions)
current_log_pi = current_log_pi.unsqueeze(-1)

critics_current_actions, _ = self._compute_critics(
self._critic_networks,
rep_obs,
current_actions.detach(),
state.critics,
replica_min=False,
apply_reward_weights=False)

if self._cql_type == "H":
critics_current_actions = critics_current_actions - current_log_pi.detach(
)

critics_current_actions = critics_current_actions.reshape(
B, self._cql_action_replica, -1)

# This is not mentioned in the CQL paper, but in the official CQL
# implementation an additional regularization term based on
# Q(s_current, a_next) is also added. Here we use an approximation of
# a_next, since this term is only for regularization purposes.
next_actions = current_actions.reshape(
self._mini_batch_length, B // self._mini_batch_length, -1).roll(
-1, 0).reshape(-1, current_actions.shape[-1])

critics_next_actions, _ = self._compute_critics(
self._critic_networks,
rep_obs,
next_actions.detach(),
state.critics,
replica_min=False,
apply_reward_weights=False)

if self._cql_type == "H":
next_log_pi = current_log_pi.reshape(
self._mini_batch_length, B // self._mini_batch_length,
-1).roll(-1, 0).reshape(-1, current_log_pi.shape[-1])
critics_next_actions = critics_next_actions - next_log_pi.detach()

critics_next_actions = critics_next_actions.reshape(
B, self._cql_action_replica, -1)

if self._cql_type == "H":
cat_critics = torch.cat(
(critics_random_actions, critics_current_actions,
critics_next_actions),
dim=1)
else:
cat_critics = torch.cat(
(critics_random_actions, critics_current_actions,
critics.unsqueeze(1)),
dim=1)
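# The log-sum-exp over the gathered action samples is a soft maximum of Q
# over actions; subtracting the Q values of the dataset actions gives the
# conservative penalty that pushes Q down on out-of-distribution actions
# relative to in-distribution ones.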

min_q_loss = torch.logsumexp(
cat_critics / self._cql_temperature,
dim=1) * self._cql_temperature - critics

# [B, critic_replica] -> [B]
min_q_loss = min_q_loss.mean(-1) * self._cql_regularization_weight
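# When a target value gap is given, ``alpha_prime`` acts as a Lagrange
# multiplier (similar to the automatic entropy tuning in SAC): the dual
# loss below increases ``alpha_prime`` when the gap exceeds the target
# and decreases it otherwise.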

if self._cql_target_value_gap > 0:
alpha_prime = torch.clamp(
self._log_alpha_prime.exp(), min=0.0, max=1000000.0)
q_diff = (min_q_loss - self._cql_target_value_gap)
min_q_loss = alpha_prime.detach() * q_diff
alpha_prime_loss = -0.5 * (alpha_prime * q_diff.detach())
else:
alpha_prime_loss = torch.zeros(B)

info = CqlCriticInfo(
critics=critics,
target_critic=target_critic,
min_q_loss=min_q_loss,
alpha_prime_loss=alpha_prime_loss,
)

return critic_state, info

def calc_loss(self, info: SacInfo):
critic_loss = self._calc_critic_loss(info)
alpha_loss = info.alpha
actor_loss = info.actor
min_q_loss = info.critic.min_q_loss
alpha_prime_loss = info.critic.alpha_prime_loss

if self._debug_summaries and alf.summary.should_record_summaries():
with alf.summary.scope(self._name):
if self._cql_target_value_gap > 0:
alpha_prime = torch.clamp(
self._log_alpha_prime.exp(), min=0.0, max=1000000.0)
alf.summary.scalar("alpha_prime", alpha_prime)
alf.summary.scalar("min_q_loss", min_q_loss.mean())

loss = math_ops.add_ignore_empty(
actor_loss.loss,
critic_loss.loss + alpha_loss + min_q_loss + alpha_prime_loss)
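# The total loss is the standard SAC actor/critic/alpha losses plus the
# CQL conservative penalty and (optionally) the ``alpha_prime`` dual loss.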

return LossInfo(
loss=loss,
priority=critic_loss.priority,
extra=CqlLossInfo(
actor=actor_loss.extra,
critic=critic_loss.extra,
alpha=alpha_loss,
cql=min_q_loss,
alpha_prime_loss=alpha_prime_loss))
5 changes: 5 additions & 0 deletions alf/bin/train_play_test.py
@@ -694,6 +694,11 @@ def test_causal_bc_pendulum(self):
conf_file='./hybrid_rl/causal_bc_pendulum_conf.py',
extra_train_params=OFF_POLICY_TRAIN_PARAMS)

def test_cql_pendulum(self):
self._test(
conf_file='./hybrid_rl/cql_pendulum_conf.py',
extra_train_params=OFF_POLICY_TRAIN_PARAMS)

def test_iql_pendulum(self):
self._test(
conf_file='./hybrid_rl/iql_pendulum_conf.py',
