Add AMP mixed precision support in torch
Toni-SM committed Dec 15, 2024
1 parent ea9ce13 commit f0997e0
Showing 1 changed file with 115 additions and 97 deletions.
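This commit threads PyTorch automatic mixed precision through the AMP agent: forward passes (action sampling, value and discriminator inference, loss computation) run inside torch.autocast, and the backward/optimizer path goes through a gradient scaler, all controlled by a new mixed_precision config flag. For orientation, here is a minimal sketch of the torch.autocast/GradScaler pattern the commit applies; this is an illustrative stand-alone example, not skrl code:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
enabled = device == "cuda"  # reduced precision only pays off on GPU here

model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.Adam(model.parameters())
scaler = torch.cuda.amp.GradScaler(enabled=enabled)  # all scaler calls become no-ops when disabled

inputs = torch.randn(4, 8, device=device)
targets = torch.randn(4, 1, device=device)

optimizer.zero_grad()
with torch.autocast(device_type=device, enabled=enabled):
    loss = torch.nn.functional.mse_loss(model(inputs), targets)  # forward pass in reduced precision
scaler.scale(loss).backward()  # scale the loss so float16 gradients do not underflow
scaler.step(optimizer)         # unscales gradients internally; skips the step on inf/NaN
scaler.update()                # adjusts the loss scale for the next iteration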
212 changes: 115 additions & 97 deletions skrl/agents/torch/amp/amp.py
@@ -60,6 +60,8 @@
     "rewards_shaper": None,  # rewards shaping function: Callable(reward, timestep, timesteps) -> reward
     "time_limit_bootstrap": False,  # bootstrap at timeout termination (episode truncation)
 
+    "mixed_precision": False,  # enable automatic mixed precision for higher performance
+
     "experiment": {
         "directory": "",  # experiment's parent directory
         "experiment_name": "",  # experiment name
@@ -204,6 +206,12 @@ def __init__(
         self._rewards_shaper = self.cfg["rewards_shaper"]
         self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"]
 
+        self._mixed_precision = self.cfg["mixed_precision"]
+
+        # set up automatic mixed precision
+        self._device_type = torch.device(device).type
+        self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision)
+
         # set up optimizer and learning rate scheduler
         if self.policy is not None and self.value is not None and self.discriminator is not None:
             self.optimizer = torch.optim.Adam(
@@ -308,8 +316,9 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tensor
             return self.policy.random_act({"states": states}, role="policy")
 
         # sample stochastic actions
-        actions, log_prob, outputs = self.policy.act({"states": states}, role="policy")
-        self._current_log_prob = log_prob
+        with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):
+            actions, log_prob, outputs = self.policy.act({"states": states}, role="policy")
+            self._current_log_prob = log_prob
 
         return actions, log_prob, outputs
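Note that rollout only needs the forward pass: act() never calls backward(), so the scaler is not involved until the update step. A small stand-alone illustration of what autocast changes here (hypothetical module, not skrl code):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
policy = torch.nn.Linear(8, 2).to(device)
states = torch.randn(4, 8, device=device)

with torch.no_grad(), torch.autocast(device_type=device, enabled=device == "cuda"):
    logits = policy(states)

# float16 on CUDA (the Linear matmul is autocast to half); float32 on CPU, where autocast is disabled
print(logits.dtype)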

@@ -361,18 +370,20 @@ def record_transition(
             if self._rewards_shaper is not None:
                 rewards = self._rewards_shaper(rewards, timestep, timesteps)
 
-            with torch.no_grad():
+            # compute values
+            with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):
                 values, _, _ = self.value.act({"states": self._state_preprocessor(states)}, role="value")
-            values = self._value_preprocessor(values, inverse=True)
+                values = self._value_preprocessor(values, inverse=True)
 
             # time-limit (truncation) bootstrapping
             if self._time_limit_bootstrap:
                 rewards += self._discount_factor * values * truncated
 
-            with torch.no_grad():
+            # compute next values
+            with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):
                 next_values, _, _ = self.value.act({"states": self._state_preprocessor(next_states)}, role="value")
-            next_values = self._value_preprocessor(next_values, inverse=True)
-            next_values *= infos["terminate"].view(-1, 1).logical_not()
+                next_values = self._value_preprocessor(next_values, inverse=True)
+                next_values *= infos["terminate"].view(-1, 1).logical_not()
 
             self.memory.add_samples(
                 states=states,
@@ -490,7 +501,7 @@ def compute_gae(
         rewards = self.memory.get_tensor_by_name("rewards")
         amp_states = self.memory.get_tensor_by_name("amp_states")
 
-        with torch.no_grad():
+        with torch.no_grad(), torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):
             amp_logits, _, _ = self.discriminator.act(
                 {"states": self._amp_state_preprocessor(amp_states)}, role="discriminator"
             )
@@ -554,120 +565,127 @@ def compute_gae(
                     _,
                 ) in enumerate(sampled_batches):
 
-                sampled_states = self._state_preprocessor(sampled_states, train=True)
-
-                _, next_log_prob, _ = self.policy.act(
-                    {"states": sampled_states, "taken_actions": sampled_actions}, role="policy"
-                )
-
-                # compute entropy loss
-                if self._entropy_loss_scale:
-                    entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean()
-                else:
-                    entropy_loss = 0
-
-                # compute policy loss
-                ratio = torch.exp(next_log_prob - sampled_log_prob)
-                surrogate = sampled_advantages * ratio
-                surrogate_clipped = sampled_advantages * torch.clip(
-                    ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip
-                )
-
-                policy_loss = -torch.min(surrogate, surrogate_clipped).mean()
-
-                # compute value loss
-                predicted_values, _, _ = self.value.act({"states": sampled_states}, role="value")
-
-                if self._clip_predicted_values:
-                    predicted_values = sampled_values + torch.clip(
-                        predicted_values - sampled_values, min=-self._value_clip, max=self._value_clip
-                    )
-                value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values)
-
-                # compute discriminator loss
-                if self._discriminator_batch_size:
-                    sampled_amp_states = self._amp_state_preprocessor(
-                        sampled_amp_states[0 : self._discriminator_batch_size], train=True
-                    )
-                    sampled_amp_replay_states = self._amp_state_preprocessor(
-                        sampled_replay_batches[batch_index][0][0 : self._discriminator_batch_size], train=True
-                    )
-                    sampled_amp_motion_states = self._amp_state_preprocessor(
-                        sampled_motion_batches[batch_index][0][0 : self._discriminator_batch_size], train=True
-                    )
-                else:
-                    sampled_amp_states = self._amp_state_preprocessor(sampled_amp_states, train=True)
-                    sampled_amp_replay_states = self._amp_state_preprocessor(
-                        sampled_replay_batches[batch_index][0], train=True
-                    )
-                    sampled_amp_motion_states = self._amp_state_preprocessor(
-                        sampled_motion_batches[batch_index][0], train=True
-                    )
-
-                sampled_amp_motion_states.requires_grad_(True)
-                amp_logits, _, _ = self.discriminator.act({"states": sampled_amp_states}, role="discriminator")
-                amp_replay_logits, _, _ = self.discriminator.act(
-                    {"states": sampled_amp_replay_states}, role="discriminator"
-                )
-                amp_motion_logits, _, _ = self.discriminator.act(
-                    {"states": sampled_amp_motion_states}, role="discriminator"
-                )
-
-                amp_cat_logits = torch.cat([amp_logits, amp_replay_logits], dim=0)
-
-                # discriminator prediction loss
-                discriminator_loss = 0.5 * (
-                    nn.BCEWithLogitsLoss()(amp_cat_logits, torch.zeros_like(amp_cat_logits))
-                    + torch.nn.BCEWithLogitsLoss()(amp_motion_logits, torch.ones_like(amp_motion_logits))
-                )
-
-                # discriminator logit regularization
-                if self._discriminator_logit_regularization_scale:
-                    logit_weights = torch.flatten(list(self.discriminator.modules())[-1].weight)
-                    discriminator_loss += self._discriminator_logit_regularization_scale * torch.sum(
-                        torch.square(logit_weights)
-                    )
-
-                # discriminator gradient penalty
-                if self._discriminator_gradient_penalty_scale:
-                    amp_motion_gradient = torch.autograd.grad(
-                        amp_motion_logits,
-                        sampled_amp_motion_states,
-                        grad_outputs=torch.ones_like(amp_motion_logits),
-                        create_graph=True,
-                        retain_graph=True,
-                        only_inputs=True,
-                    )
-                    gradient_penalty = torch.sum(torch.square(amp_motion_gradient[0]), dim=-1).mean()
-                    discriminator_loss += self._discriminator_gradient_penalty_scale * gradient_penalty
-
-                # discriminator weight decay
-                if self._discriminator_weight_decay_scale:
-                    weights = [
-                        torch.flatten(module.weight)
-                        for module in self.discriminator.modules()
-                        if isinstance(module, torch.nn.Linear)
-                    ]
-                    weight_decay = torch.sum(torch.square(torch.cat(weights, dim=-1)))
-                    discriminator_loss += self._discriminator_weight_decay_scale * weight_decay
-
-                discriminator_loss *= self._discriminator_loss_scale
+                with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):
+
+                    sampled_states = self._state_preprocessor(sampled_states, train=True)
+
+                    _, next_log_prob, _ = self.policy.act(
+                        {"states": sampled_states, "taken_actions": sampled_actions}, role="policy"
+                    )
+
+                    # compute entropy loss
+                    if self._entropy_loss_scale:
+                        entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean()
+                    else:
+                        entropy_loss = 0
+
+                    # compute policy loss
+                    ratio = torch.exp(next_log_prob - sampled_log_prob)
+                    surrogate = sampled_advantages * ratio
+                    surrogate_clipped = sampled_advantages * torch.clip(
+                        ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip
+                    )
+
+                    policy_loss = -torch.min(surrogate, surrogate_clipped).mean()
+
+                    # compute value loss
+                    predicted_values, _, _ = self.value.act({"states": sampled_states}, role="value")
+
+                    if self._clip_predicted_values:
+                        predicted_values = sampled_values + torch.clip(
+                            predicted_values - sampled_values, min=-self._value_clip, max=self._value_clip
+                        )
+                    value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values)
+
+                    # compute discriminator loss
+                    if self._discriminator_batch_size:
+                        sampled_amp_states = self._amp_state_preprocessor(
+                            sampled_amp_states[0 : self._discriminator_batch_size], train=True
+                        )
+                        sampled_amp_replay_states = self._amp_state_preprocessor(
+                            sampled_replay_batches[batch_index][0][0 : self._discriminator_batch_size], train=True
+                        )
+                        sampled_amp_motion_states = self._amp_state_preprocessor(
+                            sampled_motion_batches[batch_index][0][0 : self._discriminator_batch_size], train=True
+                        )
+                    else:
+                        sampled_amp_states = self._amp_state_preprocessor(sampled_amp_states, train=True)
+                        sampled_amp_replay_states = self._amp_state_preprocessor(
+                            sampled_replay_batches[batch_index][0], train=True
+                        )
+                        sampled_amp_motion_states = self._amp_state_preprocessor(
+                            sampled_motion_batches[batch_index][0], train=True
+                        )
+
+                    sampled_amp_motion_states.requires_grad_(True)
+                    amp_logits, _, _ = self.discriminator.act({"states": sampled_amp_states}, role="discriminator")
+                    amp_replay_logits, _, _ = self.discriminator.act(
+                        {"states": sampled_amp_replay_states}, role="discriminator"
+                    )
+                    amp_motion_logits, _, _ = self.discriminator.act(
+                        {"states": sampled_amp_motion_states}, role="discriminator"
+                    )
+
+                    amp_cat_logits = torch.cat([amp_logits, amp_replay_logits], dim=0)
+
+                    # discriminator prediction loss
+                    discriminator_loss = 0.5 * (
+                        nn.BCEWithLogitsLoss()(amp_cat_logits, torch.zeros_like(amp_cat_logits))
+                        + torch.nn.BCEWithLogitsLoss()(amp_motion_logits, torch.ones_like(amp_motion_logits))
+                    )
+
+                    # discriminator logit regularization
+                    if self._discriminator_logit_regularization_scale:
+                        logit_weights = torch.flatten(list(self.discriminator.modules())[-1].weight)
+                        discriminator_loss += self._discriminator_logit_regularization_scale * torch.sum(
+                            torch.square(logit_weights)
+                        )
+
+                    # discriminator gradient penalty
+                    if self._discriminator_gradient_penalty_scale:
+                        amp_motion_gradient = torch.autograd.grad(
+                            amp_motion_logits,
+                            sampled_amp_motion_states,
+                            grad_outputs=torch.ones_like(amp_motion_logits),
+                            create_graph=True,
+                            retain_graph=True,
+                            only_inputs=True,
+                        )
+                        gradient_penalty = torch.sum(torch.square(amp_motion_gradient[0]), dim=-1).mean()
+                        discriminator_loss += self._discriminator_gradient_penalty_scale * gradient_penalty
+
+                    # discriminator weight decay
+                    if self._discriminator_weight_decay_scale:
+                        weights = [
+                            torch.flatten(module.weight)
+                            for module in self.discriminator.modules()
+                            if isinstance(module, torch.nn.Linear)
+                        ]
+                        weight_decay = torch.sum(torch.square(torch.cat(weights, dim=-1)))
+                        discriminator_loss += self._discriminator_weight_decay_scale * weight_decay
+
+                    discriminator_loss *= self._discriminator_loss_scale
 
                 # optimization step
                 self.optimizer.zero_grad()
-                (policy_loss + entropy_loss + value_loss + discriminator_loss).backward()
+                self.scaler.scale(policy_loss + entropy_loss + value_loss + discriminator_loss).backward()
 
                 if config.torch.is_distributed:
                     self.policy.reduce_parameters()
                     self.value.reduce_parameters()
                     self.discriminator.reduce_parameters()
 
                 if self._grad_norm_clip > 0:
+                    self.scaler.unscale_(self.optimizer)
                     nn.utils.clip_grad_norm_(
                         itertools.chain(
                             self.policy.parameters(), self.value.parameters(), self.discriminator.parameters()
                         ),
                         self._grad_norm_clip,
                     )
-                self.optimizer.step()
+
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
 
                 # update cumulative losses
                 cumulative_policy_loss += policy_loss.item()
