Make AMP implementation more generic
Toni-SM committed Jan 15, 2025
1 parent e5c6b81 commit 1e00479
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions skrl/agents/torch/amp/amp.py
@@ -13,6 +13,7 @@
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
from skrl.resources.schedulers.torch import KLAdaptiveLR


# fmt: off
@@ -256,6 +257,7 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
self.memory.create_tensor(name="actions", size=self.action_space, dtype=torch.float32)
self.memory.create_tensor(name="rewards", size=1, dtype=torch.float32)
self.memory.create_tensor(name="terminated", size=1, dtype=torch.bool)
self.memory.create_tensor(name="truncated", size=1, dtype=torch.bool)
self.memory.create_tensor(name="log_prob", size=1, dtype=torch.float32)
self.memory.create_tensor(name="values", size=1, dtype=torch.float32)
self.memory.create_tensor(name="returns", size=1, dtype=torch.float32)
@@ -383,7 +385,10 @@ def record_transition(
with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):
next_values, _, _ = self.value.act({"states": self._state_preprocessor(next_states)}, role="value")
next_values = self._value_preprocessor(next_values, inverse=True)
next_values *= infos["terminate"].view(-1, 1).logical_not()
if "terminate" in infos:
next_values *= infos["terminate"].view(-1, 1).logical_not() # compatibility with IsaacGymEnvs
else:
next_values *= terminated.view(-1, 1).logical_not()

self.memory.add_samples(
states=states,
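
The masking above is what makes the implementation generic: IsaacGymEnvs reports a single "terminate" flag through infos, while gymnasium-style environments return separate terminated/truncated signals, and only a true termination should zero the bootstrap value. A minimal sketch with hypothetical tensors (not the agent's actual buffers):

import torch

# hypothetical signals for a vectorized run with 4 sub-environments
next_values = torch.tensor([[0.7], [0.3], [0.9], [0.5]])        # V(s') from the value network
terminated = torch.tensor([[False], [True], [False], [False]])  # true episode end: do not bootstrap
truncated = torch.tensor([[False], [False], [True], [False]])   # time-limit cut: bootstrap value is deliberately kept

# zero the bootstrap value only where the episode truly ended
next_values = next_values * terminated.view(-1, 1).logical_not()
# -> [[0.7], [0.0], [0.9], [0.5]]
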
@@ -509,6 +514,7 @@ def compute_gae(
torch.maximum(1 - 1 / (1 + torch.exp(-amp_logits)), torch.tensor(0.0001, device=self.device))
)
style_reward *= self._discriminator_reward_scale
style_reward = style_reward.view(rewards.shape)

combined_rewards = self._task_reward_weight * rewards + self._style_reward_weight * style_reward
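
For reference, the style reward assembled above follows the standard AMP formulation, r_style = -log(max(1 - sigmoid(D(s, s')), 1e-4)), reshaped to match the task rewards before the weighted sum. A small standalone sketch with hypothetical logits and weights:

import torch

amp_logits = torch.tensor([[2.0], [-1.0], [0.0]])      # hypothetical discriminator outputs
rewards = torch.tensor([[0.1], [0.2], [0.3]])           # hypothetical task rewards
eps = torch.tensor(0.0001)

# 1 - 1 / (1 + exp(-x)) is 1 - sigmoid(x); clamp away from zero before the log
style_reward = -torch.log(torch.maximum(1 - torch.sigmoid(amp_logits), eps))
# higher logits ("looks like the reference motion") -> larger style reward

task_reward_weight, style_reward_weight = 0.5, 0.5      # hypothetical weights from the agent config
combined_rewards = task_reward_weight * rewards + style_reward_weight * style_reward
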

@@ -517,7 +523,7 @@ def compute_gae(
next_values = self.memory.get_tensor_by_name("next_values")
returns, advantages = compute_gae(
rewards=combined_rewards,
dones=self.memory.get_tensor_by_name("terminated"),
dones=self.memory.get_tensor_by_name("terminated") | self.memory.get_tensor_by_name("truncated"),
values=values,
next_values=next_values,
discount_factor=self._discount_factor,
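
Folding truncation into dones above breaks the lambda-weighted accumulation at every stored episode boundary, while the per-step next_values (already zeroed at true terminations when the transition was recorded) still bootstraps across time-limit cuts. A generic sketch of that convention; the compute_gae helper defined earlier in this file may differ in detail:

import torch

def gae_sketch(rewards, dones, values, next_values, discount_factor=0.99, lambda_coefficient=0.95):
    # hypothetical standalone helper; all tensors of shape (memory_size, num_envs, 1)
    advantage = 0
    advantages = torch.zeros_like(rewards)
    not_dones = dones.logical_not()
    for i in reversed(range(rewards.shape[0])):
        # next_values[i] is assumed to be already masked for termination (see record_transition above)
        advantage = rewards[i] - values[i] + discount_factor * (next_values[i] + lambda_coefficient * not_dones[i] * advantage)
        advantages[i] = advantage
    returns = advantages + values
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)  # standardize for the PPO-style update
    return returns, advantages
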
@@ -549,6 +555,7 @@ def compute_gae(

# learning epochs
for epoch in range(self._learning_epochs):
kl_divergences = []

# mini-batches loop
for batch_index, (
@@ -573,6 +580,12 @@ def compute_gae(
{"states": sampled_states, "taken_actions": sampled_actions}, role="policy"
)

# compute approximate KL divergence
with torch.no_grad():
ratio = next_log_prob - sampled_log_prob
kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean()
kl_divergences.append(kl_divergence)

# compute entropy loss
if self._entropy_loss_scale:
entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean()
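
The quantity appended to kl_divergences in this loop is the common low-variance estimator E[(r - 1) - log r] with r = pi_new(a|s) / pi_old(a|s), computed from the stored and freshly evaluated log-probabilities. A rough illustration with made-up numbers:

import torch

sampled_log_prob = torch.tensor([-1.20, -0.80, -2.10, -0.50])  # old policy, stored at rollout time
next_log_prob = torch.tensor([-1.10, -0.95, -2.00, -0.55])     # current policy, recomputed in the mini-batch

with torch.no_grad():
    ratio = next_log_prob - sampled_log_prob                   # log(pi_new / pi_old)
    kl = ((torch.exp(ratio) - 1) - ratio).mean()               # >= 0, roughly 0.5 * ratio**2 for small ratios
print(kl)  # small positive scalar, later averaged and fed to the scheduler
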
@@ -696,7 +709,15 @@ def compute_gae(

# update learning rate
if self._learning_rate_scheduler:
self.scheduler.step()
if isinstance(self.scheduler, KLAdaptiveLR):
kl = torch.tensor(kl_divergences, device=self.device).mean()
# reduce (collect from all workers/processes) KL in distributed runs
if config.torch.is_distributed:
torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
kl /= config.torch.world_size
self.scheduler.step(kl.item())
else:
self.scheduler.step()

# update AMP replay buffer
self.reply_buffer.add_samples(states=amp_states.view(-1, amp_states.shape[-1]))
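
With the per-epoch mean KL available (and averaged over workers in distributed runs), a KL-adaptive scheduler can move the learning rate toward a target divergence instead of following a fixed schedule. The exact thresholds and factors belong to skrl's KLAdaptiveLR; the sketch below only illustrates the usual heuristic with hypothetical values:

import torch

def kl_adaptive_lr_step(optimizer, kl, kl_threshold=0.008, lr_factor=1.5, min_lr=1e-6, max_lr=1e-2):
    # hypothetical standalone rule: shrink the learning rate when the policy moved too much, grow it when it barely moved
    for group in optimizer.param_groups:
        lr = group["lr"]
        if kl > 2.0 * kl_threshold:
            lr = max(lr / lr_factor, min_lr)
        elif kl < kl_threshold / 2.0:
            lr = min(lr * lr_factor, max_lr)
        group["lr"] = lr

# usage with a toy optimizer
optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=3e-4)
kl_adaptive_lr_step(optimizer, kl=0.02)      # KL above the upper threshold -> learning rate is reduced
print(optimizer.param_groups[0]["lr"])       # ~2e-4
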
