Commit 5f8913d

Implement new approach with mark_dirty()

Tcc0403 committed Nov 4, 2024
1 parent fed2a10
Showing 4 changed files with 57 additions and 29 deletions.
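
This commit replaces the old _input.add_(0) trick with PyTorch's explicit in-place declaration: a custom autograd.Function that mutates one of its inputs inside forward() calls ctx.mark_dirty() on that tensor so autograd's in-place correctness checks know about the mutation. A minimal, self-contained sketch of the pattern (a toy Function, not the Liger kernel):

    import torch

    class InplaceScale(torch.autograd.Function):
        # Toy Function: doubles its input in place and declares the mutation.
        @staticmethod
        def forward(ctx, x):
            x.mul_(2)          # in-place mutation of an input tensor
            ctx.mark_dirty(x)  # tell autograd this input was modified in place
            return x

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output * 2

    x = torch.randn(4, requires_grad=True)
    y = InplaceScale.apply(x * 1.0)  # non-leaf copy; a leaf requiring grad may not be mutated in place
    y.sum().backward()               # x.grad is 2 everywhere

A tensor passed to mark_dirty() also has to be returned from forward(), which is why the Function below now returns (loss, _input) instead of just the loss.
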
51 changes: 34 additions & 17 deletions src/liger_kernel/ops/cross_entropy.py
@@ -13,6 +13,8 @@ def liger_cross_entropy_kernel(
     Y_stride,
     loss_ptr,
     loss_stride,
+    dX_ptr,
+    dX_stride,
     n_cols,
     n_non_ignore,
     ignore_index,
@@ -49,6 +51,7 @@ def liger_cross_entropy_kernel(
 
     # 2. locate the start index
     X_ptr += program_id * X_stride
+    dX_ptr += program_id * dX_stride
 
     if y == ignore_index:
         # set all X_ptr as 0
@@ -106,15 +109,15 @@ def liger_cross_entropy_kernel(
 
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
-        X_block = tl.load(
+        dX_block = tl.load(
             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
         )
         if reduction == "mean":
-            X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
+            dX_block = (tl.exp(dX_block - m) / d - eps) / (n_non_ignore)
         else:
-            X_block = tl.exp(X_block - m) / d - eps
+            dX_block = tl.exp(dX_block - m) / d - eps
 
-        tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+        tl.store(dX_ptr + X_offsets, dX_block, mask=X_offsets < n_cols)
 
     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -145,14 +148,14 @@ def liger_cross_entropy_kernel(
         loss = loss / n_non_ignore
 
     # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
-    X_y = tl.load(X_ptr + y)
+    dX_y = tl.load(dX_ptr + y)
     if reduction == "mean":
-        X_y += -(1 - label_smoothing) / (n_non_ignore)
+        dX_y += -(1 - label_smoothing) / (n_non_ignore)
     else:
-        X_y += -(1 - label_smoothing)
+        dX_y += -(1 - label_smoothing)
 
     tl.store(loss_ptr, loss)
-    tl.store(X_ptr + y, X_y)
+    tl.store(dX_ptr + y, dX_y)
 
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
Expand All @@ -161,7 +164,9 @@ def liger_cross_entropy_kernel(
MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning


def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
def cross_entropy_forward(
_input, target, ignore_index, label_smoothing, reduction, inplace
):
BT, V = _input.shape
n_rows = BT

@@ -178,10 +183,7 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
     if target.stride(-1) != 1:
         target = target.contiguous()
 
-    # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
-    # Explicitly declare an in-place operation is performed by adding a numerical value of 0 to the input in-place
-    # https://pytorch.org/docs/stable/autograd.html#in-place-correctness-checks
-    _input.add_(0)
+    dX = _input if inplace else torch.empty_like(_input)
 
     liger_cross_entropy_kernel[(n_rows,)](
         X_ptr=_input,
Expand All @@ -190,6 +192,8 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
Y_stride=target.stride(-1), # always 1
loss_ptr=loss_1d,
loss_stride=loss_1d.stride(-1), # always 1
dX_ptr=dX,
dX_stride=dX.stride(-2),
n_cols=V,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
@@ -237,7 +241,13 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
 
     @staticmethod
     def forward(
-        ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction="mean"
+        ctx,
+        _input,
+        target,
+        ignore_index=-100,
+        label_smoothing=0.0,
+        reduction="mean",
+        inplace=True,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
Expand All @@ -254,16 +264,21 @@ def forward(
tensor: The computed loss.
"""
loss, _input = cross_entropy_forward(
_input, target, ignore_index, label_smoothing, reduction
_input, target, ignore_index, label_smoothing, reduction, inplace
)
# TODO: investigation
# If we don't detach the _input tensor, the memory will double
# Not sure why but seems that there will be a time both grad and value exist but in different location
ctx.save_for_backward(_input.detach())
return loss

print(f"{inplace=}")
if inplace:
ctx.mark_dirty(_input)
ctx.mark_non_differentiable(_input)
return loss, _input

@staticmethod
def backward(ctx, grad_output):
def backward(ctx, grad_output, grad_output2):
"""
The backward pass of the Liger Cross Entropy loss.
Expand All @@ -274,6 +289,7 @@ def backward(ctx, grad_output):
Returns:
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
"""
del grad_output2
(_input,) = ctx.saved_tensors
_input = cross_entropy_backward(_input, grad_output)
return (
Expand All @@ -282,4 +298,5 @@ def backward(ctx, grad_output):
None,
None,
None,
None,
)
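
Because forward() now returns (loss, _input), backward() receives one incoming gradient per forward output, hence the extra grad_output2 parameter; the dirtied _input is also marked non-differentiable, so that second gradient carries nothing useful and is discarded. A small sketch of the same two-output pattern (again a toy Function, not the Liger code):

    import torch

    class TwoOutputs(torch.autograd.Function):
        # Toy Function returning a differentiable loss plus a non-differentiable auxiliary tensor.
        @staticmethod
        def forward(ctx, x):
            aux = x.detach() * 2
            ctx.mark_non_differentiable(aux)  # no gradient flows back through aux
            ctx.save_for_backward(x)
            return x.sum(), aux

        @staticmethod
        def backward(ctx, grad_loss, grad_aux):  # one gradient argument per forward output
            del grad_aux                         # placeholder gradient for the non-differentiable output
            (x,) = ctx.saved_tensors
            return grad_loss * torch.ones_like(x)

    x = torch.randn(3, requires_grad=True)
    loss, _ = TwoOutputs.apply(x)
    loss.backward()  # x.grad is 1 everywhere
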
2 changes: 2 additions & 0 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -87,6 +87,8 @@ def fused_linear_cross_entropy_forward(
         Y_stride=target_chunk.stride(-1),  # always 1
         loss_ptr=loss_1d_slice,
         loss_stride=loss_1d_slice.stride(-1),  # always 1
+        dX_ptr=logits_chunk,
+        dX_stride=logits_chunk.stride(-2),
         n_cols=V,
         n_non_ignore=n_non_ignore,
         ignore_index=ignore_index,
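
In the fused linear path the kernel is pointed at the same tensor for inputs and gradients (dX_ptr=logits_chunk), so each chunk's gradients overwrite its logits rather than allocating a second (chunk_size, V) buffer. A rough illustration of the memory trade-off, with made-up sizes:

    import torch

    chunk_size, V = 256, 32_000  # hypothetical chunk and vocabulary sizes
    logits_chunk = torch.randn(chunk_size, V)

    # Non-aliased variant: a second (chunk_size, V) buffer, roughly doubling peak memory per chunk.
    grad_separate = torch.empty_like(logits_chunk)

    # Aliased variant (what the fused path does): gradients are written over the logits in place.
    grad_aliased = logits_chunk
    assert grad_aliased.data_ptr() == logits_chunk.data_ptr()
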
29 changes: 19 additions & 10 deletions src/liger_kernel/transformers/cross_entropy.py
@@ -1,21 +1,30 @@
-from torch.nn import CrossEntropyLoss
+import torch
 
 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
 
 
-class LigerCrossEntropyLoss(CrossEntropyLoss):
-    def __init__(self, *args, **kwargs):
-        super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)
-        assert (self.label_smoothing >= 0) and (
-            self.label_smoothing <= 1
+class LigerCrossEntropyLoss(torch.nn.Module):
+    def __init__(self, ignore_index=-100, label_smoothing=0.0, reduction="mean"):
+        super().__init__()
+        assert (label_smoothing >= 0) and (
+            label_smoothing <= 1
         ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
-        assert self.reduction in {
+        assert reduction in {
             "mean",
             "sum",
             "none",
         }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"
+        self.ignore_index = ignore_index
+        self.label_smoothing = label_smoothing
+        self.reduction = reduction
 
-    def forward(self, _input, target):
-        return LigerCrossEntropyFunction.apply(
-            _input, target, self.ignore_index, self.label_smoothing, self.reduction
+    def forward(self, _input, target, inplace):
+        loss, _ = LigerCrossEntropyFunction.apply(
+            _input,
+            target,
+            ignore_index=self.ignore_index,
+            label_smoothing=self.label_smoothing,
+            reduction=self.reduction,
+            inplace=inplace,
         )
+        return loss
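
A usage sketch for the revised module (shapes are made up, and a CUDA device is assumed since the Triton kernel runs on the GPU). Note that in this revision forward() takes inplace as an explicit argument, and inplace=True overwrites the logits with their gradients:

    import torch
    from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

    loss_fn = LigerCrossEntropyLoss(ignore_index=-100, label_smoothing=0.0, reduction="mean")

    hidden = torch.randn(8, 4096, device="cuda", requires_grad=True)
    lm_head = torch.randn(32000, 4096, device="cuda", requires_grad=True)
    logits = hidden @ lm_head.t()  # a non-leaf activation, as produced by a typical LM head
    targets = torch.randint(0, 32000, (8,), device="cuda")

    loss = loss_fn(logits, targets, inplace=True)  # logits buffer is reused for gradients
    loss.backward()
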
4 changes: 2 additions & 2 deletions test/transformers/test_cross_entropy.py
@@ -149,8 +149,8 @@ def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
 
     target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long)
 
-    y1 = liger_cross_entropy(x1, target, 0)
-    y2 = LigerCrossEntropyFunction.apply(x2, target, 0)
+    y1, _ = liger_cross_entropy(x1, target, 0)
+    y2, _ = LigerCrossEntropyFunction.apply(x2, target, 0)
 
     assert torch.allclose(y1, y2, atol=atol, rtol=rtol)

