correct metrics (#1514)
Co-authored-by: Clara Luise Pohland <[email protected]>
claralp and Clara Luise Pohland authored Apr 8, 2024
1 parent 4dca169 commit 85f5fd2
Showing 1 changed file with 43 additions and 43 deletions.
trl/trainer/kto_trainer.py (86 changes: 43 additions & 43 deletions)
@@ -973,12 +973,12 @@ def kto_loss(
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
The KL tensor contains the detached KL divergence estimate between the policy and reference models.
"""
- KL = (policy_KL_logps - reference_KL_logps).mean().detach()
- KL = self.accelerator.gather(KL).mean().clamp(min=0)
+ kl = (policy_KL_logps - reference_KL_logps).mean().detach()
+ kl = self.accelerator.gather(kl).mean().clamp(min=0)

if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
chosen_logratios = policy_chosen_logps - reference_chosen_logps
- chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - KL))
+ chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
chosen_rewards = self.beta * chosen_logratios.detach()
else:
# lists can't be empty -- if they are, then accelerate.gather will hang
@@ -987,7 +987,7 @@ def kto_loss(

if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
rejected_logratios = policy_rejected_logps - reference_rejected_logps
- rejected_losses = 1 - F.sigmoid(self.beta * (KL - rejected_logratios))
+ rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))
rejected_rewards = self.beta * rejected_logratios.detach()
else:
# lists can't be empty -- if they are, then accelerate.gather will hang
@@ -999,13 +999,12 @@ def kto_loss(
0,
)

- return losses, chosen_rewards, rejected_rewards, KL
+ return losses, chosen_rewards, rejected_rewards, kl

def get_batch_loss_metrics(
self,
model,
batch: Dict[str, Union[List, torch.LongTensor]],
- train_eval: Literal["train", "eval"] = "train",
):
"""Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
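Aside from the KL to kl rename, the loss itself is untouched by this commit. For orientation only (not part of the diff), here is a minimal standalone sketch of the per-example terms that kto_loss computes, using toy tensors and an assumed beta of 0.1; the desirable/undesirable weighting and the final torch.cat are omitted:

import torch
import torch.nn.functional as F

beta = 0.1
kl = torch.tensor(0.05)  # gathered, clamped KL estimate (already detached)

policy_chosen_logps = torch.tensor([-12.0, -15.0])
reference_chosen_logps = torch.tensor([-13.0, -14.0])
policy_rejected_logps = torch.tensor([-20.0])
reference_rejected_logps = torch.tensor([-18.0])

chosen_logratios = policy_chosen_logps - reference_chosen_logps
rejected_logratios = policy_rejected_logps - reference_rejected_logps

# Desirable examples are pushed above the KL baseline, undesirable ones below it.
chosen_losses = 1 - F.sigmoid(beta * (chosen_logratios - kl))
rejected_losses = 1 - F.sigmoid(beta * (kl - rejected_logratios))

chosen_rewards = beta * chosen_logratios.detach()
rejected_rewards = beta * rejected_logratios.detach()
print(chosen_losses, rejected_losses, chosen_rewards, rejected_rewards)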
@@ -1047,7 +1046,7 @@ def get_batch_loss_metrics(
reference_KL_logps,
) = self.forward(self.ref_model, batch)

- losses, chosen_rewards, rejected_rewards, KL = self.kto_loss(
+ losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
policy_chosen_logps,
policy_rejected_logps,
policy_KL_logps,
@@ -1059,33 +1058,20 @@ def get_batch_loss_metrics(
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)

- all_num_chosen = self.accelerator.gather(num_chosen)
- all_num_rejected = self.accelerator.gather(num_rejected)
-
- prefix = "eval_" if train_eval == "eval" else ""
-
- if all_num_chosen.sum().item() > 0:
- metrics[f"{prefix}rewards/chosen"] = (
- (self.accelerator.gather(chosen_rewards.mean()) * all_num_chosen).nansum() / all_num_chosen.sum()
- ).item()
- metrics[f"{prefix}logps/chosen"] = (
- (self.accelerator.gather(policy_chosen_logps.mean()) * all_num_chosen).nansum() / all_num_chosen.sum()
- ).item()
-
- if all_num_rejected.sum().item() > 0:
- metrics[f"{prefix}rewards/rejected"] = (
- (self.accelerator.gather(rejected_rewards.mean()) * all_num_rejected).nansum() / all_num_rejected.sum()
- ).item()
- metrics[f"{prefix}logps/rejected"] = (
- (self.accelerator.gather(policy_rejected_logps.mean()) * all_num_rejected).nansum()
- / all_num_rejected.sum()
- ).item()
-
- metrics[f"{prefix}kl"] = KL.item()
- if all_num_chosen.sum().item() > 0 and all_num_rejected.sum().item() > 0:
- metrics[f"{prefix}rewards/margins"] = (
- metrics[f"{prefix}rewards/chosen"] - metrics[f"{prefix}rewards/rejected"]
- )
+ all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
+ all_num_rejected = self.accelerator.gather(num_rejected).sum().item()
+
+ if all_num_chosen > 0:
+ metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
+ metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
+ metrics["count/chosen"] = all_num_chosen
+
+ if all_num_rejected > 0:
+ metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
+ metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
+ metrics["count/rejected"] = all_num_rejected
+
+ metrics["kl"] = kl.item()

return losses.nanmean(), metrics
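This hunk is the heart of the fix: instead of gathering per-process means and re-weighting them, each process now contributes raw per-batch sums (rewards, logps) plus example counts, and the division into averages is deferred to log(). A minimal sketch (not part of the diff, no accelerate involved) of why sum-then-divide recovers the exact global mean when processes hold unequal numbers of chosen examples:

import torch

# Two ranks hold different numbers of chosen rewards in the same step.
shard0 = torch.tensor([1.0, 2.0, 3.0])  # rank 0: 3 chosen examples
shard1 = torch.tensor([10.0])           # rank 1: 1 chosen example

# What the new code accumulates per step:
rewards_chosen_sum = shard0.nansum() + shard1.nansum()  # ~ gather(chosen_rewards.nansum()).nansum()
count_chosen = shard0.numel() + shard1.numel()          # ~ gather(num_chosen).sum()

# Dividing once (later, in log()) yields the exact mean over all 4 examples.
print(rewards_chosen_sum.item() / count_chosen)  # 4.0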

@@ -1103,7 +1089,7 @@ def compute_loss(
compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

with compute_loss_context_manager():
- loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)

# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
loss = loss.to(self.args.device)
@@ -1117,10 +1103,7 @@ def compute_loss(

def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
for key, value in metrics.items():
- if isinstance(value, list):
- self._stored_metrics[train_eval][key].extend(value)
- else:
- self._stored_metrics[train_eval][key].append(value)
+ self._stored_metrics[train_eval][key].append(value)

def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
@@ -1193,7 +1176,7 @@ def prediction_step(

prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
with torch.no_grad(), prediction_context_manager():
- loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)

# force log the metrics
if self.accelerator.is_main_process:
@@ -1204,8 +1187,8 @@ def prediction_step(

# logits for the chosen and rejected samples from model
logits_dict = {
"eval_logits/chosen": metrics["eval_logits/chosen"],
"eval_logits/rejected": metrics["eval_logits/rejected"],
"eval_logits/chosen": metrics["logits/chosen"],
"eval_logits/rejected": metrics["logits/rejected"],
}
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
@@ -1279,9 +1262,26 @@ def log(self, logs: Dict[str, float]) -> None:
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
+ # accumulate average metrics from sums and lengths
+ for split in ["chosen", "rejected"]:
+ if f"count/{split}" in self._stored_metrics[train_eval]:
+ count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
+ logs[f"{train_eval}/rewards/{split}"] = (
+ torch.Tensor(self._stored_metrics[train_eval][f"rewards/{split}_sum"]).sum().item() / count_sum
+ )
+ logs[f"{train_eval}/logps/{split}"] = (
+ torch.Tensor(self._stored_metrics[train_eval][f"logps/{split}_sum"]).sum().item() / count_sum
+ )
+ for key in [f"count/{split}", f"rewards/{split}_sum", f"logps/{split}_sum"]:
+ del self._stored_metrics[train_eval][key]
+ # calculate reward margin
+ if f"{train_eval}/rewards/chosen" in logs and f"{train_eval}/rewards/rejected" in logs:
+ logs[f"{train_eval}/rewards/margins"] = (
+ logs[f"{train_eval}/rewards/chosen"] - logs[f"{train_eval}/rewards/rejected"]
+ )
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
- logs[key] = torch.tensor(metrics).mean().item()
+ logs[f"{train_eval}/{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs)
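The stored sums and counts are thus folded into interval averages exactly once, at log time, and the reward margin is derived from those averages. A toy sketch of that bookkeeping (not part of the diff), assuming the stored-metrics lists hold one entry per training step:

import torch

# Hypothetical contents of self._stored_metrics["train"] after two steps.
stored = {
    "rewards/chosen_sum": [2.0, 4.0],
    "logps/chosen_sum": [-30.0, -60.0],
    "count/chosen": [2.0, 4.0],
}

count = torch.Tensor(stored["count/chosen"]).sum().item()                         # 6.0
rewards_chosen = torch.Tensor(stored["rewards/chosen_sum"]).sum().item() / count  # 1.0
logps_chosen = torch.Tensor(stored["logps/chosen_sum"]).sum().item() / count      # -15.0
print(rewards_chosen, logps_chosen)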
