Commit
Add A2C mixed precision support in torch
Toni-SM committed Dec 15, 2024
1 parent 898a235 commit ea9ce13
Showing 2 changed files with 94 additions and 58 deletions.
75 changes: 46 additions & 29 deletions skrl/agents/torch/a2c/a2c.py
@@ -43,6 +43,8 @@
"rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward
"time_limit_bootstrap": False, # bootstrap at timeout termination (episode truncation)

"mixed_precision": False, # enable automatic mixed precision for higher performance

"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
@@ -142,6 +144,12 @@ def __init__(
self._rewards_shaper = self.cfg["rewards_shaper"]
self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"]

self._mixed_precision = self.cfg["mixed_precision"]

# set up automatic mixed precision
self._device_type = torch.device(device).type
self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision)

# set up optimizer and learning rate scheduler
if self.policy is not None and self.value is not None:
if self.policy is self.value:
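As a side note on the setup above: `torch.device(device).type` reduces a device spec such as `"cuda:0"` to its type string, and a `GradScaler` constructed with `enabled=False` is a transparent pass-through, so the same code path works whether mixed precision is on or off. A small standalone sketch (not skrl code):

```python
import torch

print(torch.device("cuda:0").type)  # "cuda"
print(torch.device("cpu").type)     # "cpu"

# with enabled=False the scaler is a no-op: scale() returns the loss unchanged
# and step()/update() behave like a plain optimizer.step()
scaler = torch.cuda.amp.GradScaler(enabled=False)
loss = torch.tensor(1.5, requires_grad=True)
print(scaler.scale(loss))  # unchanged: scale() is a pass-through when disabled
```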
Expand Down Expand Up @@ -211,8 +219,9 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens
return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy")

# sample stochastic actions
actions, log_prob, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy")
self._current_log_prob = log_prob
with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):
actions, log_prob, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy")
self._current_log_prob = log_prob

return actions, log_prob, outputs
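Worth noting for the hunk above: wrapping only the policy forward pass in `torch.autocast` lets eligible ops (matmuls, convolutions) run in reduced precision while the context manager is a no-op when `enabled=False`. A tiny illustration, independent of skrl:

```python
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(8, 16, device=device_type)
w = torch.randn(16, 4, device=device_type)

with torch.autocast(device_type=device_type):
    y = x @ w  # autocast-eligible op: float16 on CUDA, bfloat16 on CPU by default

print(y.dtype)
```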

@@ -261,8 +270,9 @@ def record_transition(
rewards = self._rewards_shaper(rewards, timestep, timesteps)

# compute values
values, _, _ = self.value.act({"states": self._state_preprocessor(states)}, role="value")
values = self._value_preprocessor(values, inverse=True)
with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):
values, _, _ = self.value.act({"states": self._state_preprocessor(states)}, role="value")
values = self._value_preprocessor(values, inverse=True)

# time-limit (truncation) bootstrapping
if self._time_limit_bootstrap:
@@ -375,13 +385,13 @@ def compute_gae(
return returns, advantages

# compute returns and advantages
with torch.no_grad():
with torch.no_grad(), torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):
self.value.train(False)
last_values, _, _ = self.value.act(
{"states": self._state_preprocessor(self._current_next_states.float())}, role="value"
)
self.value.train(True)
last_values = self._value_preprocessor(last_values, inverse=True)
last_values = self._value_preprocessor(last_values, inverse=True)

values = self.memory.get_tensor_by_name("values")
returns, advantages = compute_gae(
@@ -409,49 +419,56 @@ def compute_gae(
# mini-batches loop
for sampled_states, sampled_actions, sampled_log_prob, sampled_returns, sampled_advantages in sampled_batches:

sampled_states = self._state_preprocessor(sampled_states, train=True)
with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):

_, next_log_prob, _ = self.policy.act(
{"states": sampled_states, "taken_actions": sampled_actions}, role="policy"
)
sampled_states = self._state_preprocessor(sampled_states, train=True)

# compute approximate KL divergence for KLAdaptive learning rate scheduler
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
with torch.no_grad():
ratio = next_log_prob - sampled_log_prob
kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean()
kl_divergences.append(kl_divergence)
_, next_log_prob, _ = self.policy.act(
{"states": sampled_states, "taken_actions": sampled_actions}, role="policy"
)

# compute entropy loss
if self._entropy_loss_scale:
entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean()
else:
entropy_loss = 0
# compute approximate KL divergence for KLAdaptive learning rate scheduler
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
with torch.no_grad():
ratio = next_log_prob - sampled_log_prob
kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean()
kl_divergences.append(kl_divergence)

# compute entropy loss
if self._entropy_loss_scale:
entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean()
else:
entropy_loss = 0

# compute policy loss
policy_loss = -(sampled_advantages * next_log_prob).mean()
# compute policy loss
policy_loss = -(sampled_advantages * next_log_prob).mean()

# compute value loss
predicted_values, _, _ = self.value.act({"states": sampled_states}, role="value")
# compute value loss
predicted_values, _, _ = self.value.act({"states": sampled_states}, role="value")

value_loss = F.mse_loss(sampled_returns, predicted_values)
value_loss = F.mse_loss(sampled_returns, predicted_values)

# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
self.scaler.scale(policy_loss + entropy_loss + value_loss).backward()

if config.torch.is_distributed:
self.policy.reduce_parameters()
if self.policy is not self.value:
self.value.reduce_parameters()

if self._grad_norm_clip > 0:
self.scaler.unscale_(self.optimizer)
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
else:
nn.utils.clip_grad_norm_(
itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip
)
self.optimizer.step()

self.scaler.step(self.optimizer)
self.scaler.update()

# update cumulative losses
cumulative_policy_loss += policy_loss.item()
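The optimization-step changes above follow the canonical AMP recipe: scale the loss before `backward()`, unscale before gradient clipping, then let the scaler drive the optimizer step and update its scale factor. A condensed, generic sketch of that ordering (model and optimizer names are placeholders, not skrl code):

```python
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 4).to(device_type)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=(device_type == "cuda"))

for _ in range(3):
    inputs = torch.randn(32, 16, device=device_type)
    targets = torch.randn(32, 4, device=device_type)

    with torch.autocast(device_type=device_type, enabled=(device_type == "cuda")):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)

    optimizer.zero_grad()
    scaler.scale(loss).backward()   # backward on the scaled loss
    scaler.unscale_(optimizer)      # unscale so clipping sees true gradient norms
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
    scaler.step(optimizer)          # skips the step if inf/NaN gradients are found
    scaler.update()                 # adjusts the scale factor for the next iteration
```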
77 changes: 48 additions & 29 deletions skrl/agents/torch/a2c/a2c_rnn.py
@@ -43,6 +43,8 @@
"rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward
"time_limit_bootstrap": False, # bootstrap at timeout termination (episode truncation)

"mixed_precision": False, # enable automatic mixed precision for higher performance

"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
@@ -142,6 +144,12 @@ def __init__(
self._rewards_shaper = self.cfg["rewards_shaper"]
self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"]

self._mixed_precision = self.cfg["mixed_precision"]

# set up automatic mixed precision
self._device_type = torch.device(device).type
self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision)

# set up optimizer and learning rate scheduler
if self.policy is not None and self.value is not None:
if self.policy is self.value:
@@ -248,8 +256,11 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens
return self.policy.random_act({"states": self._state_preprocessor(states), **rnn}, role="policy")

# sample stochastic actions
actions, log_prob, outputs = self.policy.act({"states": self._state_preprocessor(states), **rnn}, role="policy")
self._current_log_prob = log_prob
with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):
actions, log_prob, outputs = self.policy.act(
{"states": self._state_preprocessor(states), **rnn}, role="policy"
)
self._current_log_prob = log_prob

if self._rnn:
self._rnn_final_states["policy"] = outputs.get("rnn", [])
@@ -301,9 +312,10 @@ def record_transition(
rewards = self._rewards_shaper(rewards, timestep, timesteps)

# compute values
rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {}
values, _, outputs = self.value.act({"states": self._state_preprocessor(states), **rnn}, role="value")
values = self._value_preprocessor(values, inverse=True)
with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):
rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {}
values, _, outputs = self.value.act({"states": self._state_preprocessor(states), **rnn}, role="value")
values = self._value_preprocessor(values, inverse=True)

# time-limit (truncation) bootstrapping
if self._time_limit_bootstrap:
@@ -446,14 +458,14 @@ def compute_gae(
return returns, advantages

# compute returns and advantages
with torch.no_grad():
with torch.no_grad(), torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):
self.value.train(False)
rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {}
last_values, _, _ = self.value.act(
{"states": self._state_preprocessor(self._current_next_states.float()), **rnn}, role="value"
)
self.value.train(True)
last_values = self._value_preprocessor(last_values, inverse=True)
last_values = self._value_preprocessor(last_values, inverse=True)

values = self.memory.get_tensor_by_name("values")
returns, advantages = compute_gae(
@@ -523,48 +535,55 @@ def compute_gae(
"terminated": sampled_dones,
}

sampled_states = self._state_preprocessor(sampled_states, train=True)
with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):

_, next_log_prob, _ = self.policy.act(
{"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="policy"
)
sampled_states = self._state_preprocessor(sampled_states, train=True)

# compute approximate KL divergence for KLAdaptive learning rate scheduler
if isinstance(self.scheduler, KLAdaptiveLR):
with torch.no_grad():
ratio = next_log_prob - sampled_log_prob
kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean()
kl_divergences.append(kl_divergence)
_, next_log_prob, _ = self.policy.act(
{"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="policy"
)

# compute entropy loss
if self._entropy_loss_scale:
entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean()
else:
entropy_loss = 0
# compute approximate KL divergence for KLAdaptive learning rate scheduler
if isinstance(self.scheduler, KLAdaptiveLR):
with torch.no_grad():
ratio = next_log_prob - sampled_log_prob
kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean()
kl_divergences.append(kl_divergence)

# compute entropy loss
if self._entropy_loss_scale:
entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean()
else:
entropy_loss = 0

# compute policy loss
policy_loss = -(sampled_advantages * next_log_prob).mean()
# compute policy loss
policy_loss = -(sampled_advantages * next_log_prob).mean()

# compute value loss
predicted_values, _, _ = self.value.act({"states": sampled_states, **rnn_value}, role="value")
# compute value loss
predicted_values, _, _ = self.value.act({"states": sampled_states, **rnn_value}, role="value")

value_loss = F.mse_loss(sampled_returns, predicted_values)
value_loss = F.mse_loss(sampled_returns, predicted_values)

# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
self.scaler.scale(policy_loss + entropy_loss + value_loss).backward()

if config.torch.is_distributed:
self.policy.reduce_parameters()
if self.policy is not self.value:
self.value.reduce_parameters()

if self._grad_norm_clip > 0:
self.scaler.unscale_(self.optimizer)
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
else:
nn.utils.clip_grad_norm_(
itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip
)
self.optimizer.step()

self.scaler.step(self.optimizer)
self.scaler.update()

# update cumulative losses
cumulative_policy_loss += policy_loss.item()
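One behavioural detail of the scaler-driven step used in both files: when the scaled gradients contain infs or NaNs, `scaler.step()` skips the optimizer update and `scaler.update()` lowers the loss scale, so occasional skipped iterations early in training are expected with mixed precision enabled. A self-contained way to observe this, shown only for illustration and not part of the commit:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(4, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

loss = model(torch.randn(8, 4, device=device)).sum()
scaler.scale(loss).backward()

scale_before = scaler.get_scale()
scaler.step(optimizer)
scaler.update()
# a drop in the scale factor across update() means the optimizer step was skipped
print("step skipped:", scaler.get_scale() < scale_before)
```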