diff --git a/skrl/agents/torch/amp/amp.py b/skrl/agents/torch/amp/amp.py
index e5ad4570..45494430 100644
--- a/skrl/agents/torch/amp/amp.py
+++ b/skrl/agents/torch/amp/amp.py
@@ -13,6 +13,7 @@
 from skrl.agents.torch import Agent
 from skrl.memories.torch import Memory
 from skrl.models.torch import Model
+from skrl.resources.schedulers.torch import KLAdaptiveLR
 
 
 # fmt: off
@@ -256,6 +257,7 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
             self.memory.create_tensor(name="actions", size=self.action_space, dtype=torch.float32)
             self.memory.create_tensor(name="rewards", size=1, dtype=torch.float32)
             self.memory.create_tensor(name="terminated", size=1, dtype=torch.bool)
+            self.memory.create_tensor(name="truncated", size=1, dtype=torch.bool)
             self.memory.create_tensor(name="log_prob", size=1, dtype=torch.float32)
             self.memory.create_tensor(name="values", size=1, dtype=torch.float32)
             self.memory.create_tensor(name="returns", size=1, dtype=torch.float32)
@@ -383,7 +385,10 @@ def record_transition(
             with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):
                 next_values, _, _ = self.value.act({"states": self._state_preprocessor(next_states)}, role="value")
                 next_values = self._value_preprocessor(next_values, inverse=True)
-                next_values *= infos["terminate"].view(-1, 1).logical_not()
+                if "terminate" in infos:
+                    next_values *= infos["terminate"].view(-1, 1).logical_not()  # compatibility with IsaacGymEnvs
+                else:
+                    next_values *= terminated.view(-1, 1).logical_not()
 
             self.memory.add_samples(
                 states=states,
@@ -509,6 +514,7 @@ def compute_gae(
                 torch.maximum(1 - 1 / (1 + torch.exp(-amp_logits)), torch.tensor(0.0001, device=self.device))
             )
             style_reward *= self._discriminator_reward_scale
+            style_reward = style_reward.view(rewards.shape)
 
         combined_rewards = self._task_reward_weight * rewards + self._style_reward_weight * style_reward
 
@@ -517,7 +523,7 @@ def compute_gae(
         next_values = self.memory.get_tensor_by_name("next_values")
         returns, advantages = compute_gae(
             rewards=combined_rewards,
-            dones=self.memory.get_tensor_by_name("terminated"),
+            dones=self.memory.get_tensor_by_name("terminated") | self.memory.get_tensor_by_name("truncated"),
             values=values,
             next_values=next_values,
             discount_factor=self._discount_factor,
@@ -549,6 +555,7 @@ def compute_gae(
 
         # learning epochs
         for epoch in range(self._learning_epochs):
+            kl_divergences = []
 
             # mini-batches loop
             for batch_index, (
@@ -573,6 +580,12 @@ def compute_gae(
                         {"states": sampled_states, "taken_actions": sampled_actions}, role="policy"
                     )
 
+                    # compute approximate KL divergence
+                    with torch.no_grad():
+                        ratio = next_log_prob - sampled_log_prob
+                        kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean()
+                        kl_divergences.append(kl_divergence)
+
                     # compute entropy loss
                     if self._entropy_loss_scale:
                         entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean()
@@ -696,7 +709,15 @@ def compute_gae(
 
             # update learning rate
             if self._learning_rate_scheduler:
-                self.scheduler.step()
+                if isinstance(self.scheduler, KLAdaptiveLR):
+                    kl = torch.tensor(kl_divergences, device=self.device).mean()
+                    # reduce (collect from all workers/processes) KL in distributed runs
+                    if config.torch.is_distributed:
+                        torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+                        kl /= config.torch.world_size
+                    self.scheduler.step(kl.item())
+                else:
+                    self.scheduler.step()
 
         # update AMP replay buffer
         self.reply_buffer.add_samples(states=amp_states.view(-1, amp_states.shape[-1]))
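
For reference (not part of the patch), the sketch below isolates the mechanism the last two hunks add: an approximate KL divergence between the old and updated policy, mean((exp(r) - 1) - r) with r = new_log_prob - old_log_prob, collected per mini-batch and averaged per epoch to drive a KL-adaptive learning-rate update. The adapt_lr helper and its kl_threshold/factor values are illustrative assumptions, not skrl's KLAdaptiveLR implementation.

import torch

def approximate_kl(new_log_prob: torch.Tensor, old_log_prob: torch.Tensor) -> torch.Tensor:
    # same estimator as the patch: mean((exp(r) - 1) - r), r = new_log_prob - old_log_prob
    ratio = new_log_prob - old_log_prob
    return ((torch.exp(ratio) - 1) - ratio).mean()

def adapt_lr(optimizer: torch.optim.Optimizer, kl: float, kl_threshold: float = 0.008,
             factor: float = 1.5, min_lr: float = 1e-6, max_lr: float = 1e-2) -> None:
    # illustrative KL-adaptive rule: lower the lr when KL overshoots the target,
    # raise it when KL stays well below the target (threshold/factor are assumptions)
    for group in optimizer.param_groups:
        if kl > 2.0 * kl_threshold:
            group["lr"] = max(group["lr"] / factor, min_lr)
        elif kl < 0.5 * kl_threshold:
            group["lr"] = min(group["lr"] * factor, max_lr)

# usage: collect one KL estimate per mini-batch, then step the rule once per epoch
policy = torch.nn.Linear(8, 2)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)
kl_divergences = []
for _ in range(4):  # mini-batches
    old_log_prob = torch.randn(32)
    new_log_prob = old_log_prob + 0.01 * torch.randn(32)
    with torch.no_grad():
        kl_divergences.append(approximate_kl(new_log_prob, old_log_prob))
adapt_lr(optimizer, torch.stack(kl_divergences).mean().item())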