Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
shisahni_LinkedIn committed Nov 8, 2024
1 parent 804a1cc commit 0a081be
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 23 deletions.
6 changes: 4 additions & 2 deletions src/liger_kernel/chunked_loss/orpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ def forward(ctx, _input, weight, target, bias=None, ignore_index=-100, beta=0.1,
Args:
_input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
bias (torch.Tensor, optional): Bias tensor. Shape: (hidden_size,).
target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
ignore_index (int): Index to ignore for loss computation.
compiled (bool): Whether to use compiled mode for chunk accumulation.
compiled (bool): Whether to use torch compile for chunk accumulation.
"""
# TODO: Tune CHUNK_SIZE to fully utilize the GPU
CHUNK_SIZE = 1

def _compute_orpo_loss(input_chunk, weight, target_chunk, bias=None):
Expand Down
52 changes: 31 additions & 21 deletions test/chunked_loss/test_orpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import torch.nn as nn
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
import pytest
from test.utils import assert_verbose_allclose, set_seed

# set random seed globally
set_seed()


class HF_ORPO_Loss:
Expand Down Expand Up @@ -138,15 +142,15 @@ def get_batch_loss_metrics(
policy_chosen_logps, policy_rejected_logps
)
# full ORPO loss
loss = policy_nll_loss #- losses.mean()
loss = policy_nll_loss - losses.mean()
return loss


@pytest.mark.parametrize(
"B, T, H, V",
[
(8, 128, 1024, 4096),
(4, 47, 31, 123), # random shape
(3, 47, 31, 123), # random shape
],
)
@pytest.mark.parametrize(
Expand All @@ -159,33 +163,39 @@ def get_batch_loss_metrics(
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)])
def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta):
# Define input tensors
_tensor = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar
_input1 = _tensor.detach().clone().requires_grad_(True)
_input2 = _tensor.detach().clone().requires_grad_(True)
B = 2 * B # orpo loss requires B to be even

_input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar
input1 = _input.detach().clone().requires_grad_(True)
input2 = _input.detach().clone().requires_grad_(True)

target = torch.randint(0, V, (B, T,), device="cuda", dtype=torch.long)
# Assign some random number of elements as ignore_index
num_elements_to_assign = torch.randint(
1, B * T // 2, (1,)
).item() # Random number of elements to set to ignore_index
).item()
indices_to_assign = torch.randperm(B * T)[
:num_elements_to_assign
] # Randomly select indices
]
target.view(-1)[indices_to_assign] = ignore_index

weight = torch.randn(V, H, device="cuda", dtype=dtype)
bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None
# Initialize HF ORPO Loss
hf_orpo_loss = HF_ORPO_Loss(ignore_index=ignore_index, beta=beta)
# Compute the ORPO loss
loss1 = hf_orpo_loss.get_batch_loss_metrics(_input1, weight, target, bias)
# Compute the ORPO loss using the LigerFusedLinearORPOFunction
loss2 = LigerFusedLinearORPOFunction.apply(_input2, weight, target, bias, ignore_index, beta, True)
# Compare the two losses
assert torch.allclose(loss1, loss2, atol=atol, rtol=rtol)
# Compute the gradients
_weight = torch.randn(V, H, device="cuda", dtype=dtype)
weight1 = _weight.detach().clone().requires_grad_(True)
weight2 = _weight.detach().clone().requires_grad_(True)

_bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None
bias1 = _bias.detach().clone().requires_grad_(True) if bias else None
bias2 = _bias.detach().clone().requires_grad_(True) if bias else None

loss1 = HF_ORPO_Loss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics(input1, weight1, target, bias1)
loss2 = LigerFusedLinearORPOFunction.apply(input2, weight2, target, bias2, ignore_index, beta, True)

assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)

loss1.backward()
loss2.backward()
# Compare the gradients
assert torch.allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol)

assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol)
if bias:
assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol)

0 comments on commit 0a081be

Please sign in to comment.