Support FusedLinearCrossEntropy for Gemma2 (#320)
## Summary
Resolves #127.

Fuse softcapping into the `cross_entropy` kernel so it can be called by the fused linear cross entropy (FLCE) function.
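
For context, Gemma2 applies final-logit softcapping before the loss: logits are squashed into `(-softcap, +softcap)` via `softcap * tanh(logits / softcap)`. Below is a minimal, non-fused PyTorch sketch of the behavior the kernel now fuses; shapes and the softcap value are illustrative, not taken from this PR.

```python
import torch
import torch.nn.functional as F

def softcapped_cross_entropy(logits: torch.Tensor, target: torch.Tensor, softcap: float) -> torch.Tensor:
    # Gemma2-style softcapping: squash logits into (-softcap, +softcap), then take CE
    capped = softcap * torch.tanh(logits / softcap)
    return F.cross_entropy(capped, target)

logits = torch.randn(8, 32000, requires_grad=True)   # (batch, vocab), illustrative sizes
target = torch.randint(0, 32000, (8,))
loss = softcapped_cross_entropy(logits, target, softcap=30.0)
loss.backward()
```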

## Testing Done
The current monkey patch for Gemma2 cannot pass the convergence test without FLCE either, so the test is commented out for now.

- Hardware Type: 
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Byron Hsu <[email protected]>
Co-authored-by: Shao Tang <[email protected]>
3 people authored Nov 8, 2024
1 parent e7c55da commit 2d3eb94
Showing 9 changed files with 489 additions and 67 deletions.
63 changes: 46 additions & 17 deletions src/liger_kernel/ops/cross_entropy.py
@@ -1,8 +1,21 @@
import operator
from typing import Optional

import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import element_mul_kernel, is_hip
from liger_kernel.ops.utils import compare_version, element_mul_kernel, is_hip

if compare_version("triton", operator.ge, "3.0.0"):
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import tanh
except ModuleNotFoundError:
# for working with NGC containers
from triton.language.extra.cuda.libdevice import tanh
else:
from triton.language.math import tanh

_TRUE = tl.constexpr(1)
_FALSE = tl.constexpr(0)
@@ -23,8 +36,10 @@ def liger_cross_entropy_kernel(
lse_square_scale: tl.constexpr,
label_smoothing: tl.constexpr,
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
softcap,
RETURN_Z_LOSS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HAS_SOFTCAPPING: tl.constexpr,
):
"""
This kernel computes both cross entropy loss and the gradient of the input.
@@ -45,7 +60,9 @@
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
reduction (str): The string for the reduction to apply
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
BLOCK_SIZE (int): The block size for Triton operations.
HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
"""

# https://github.com/triton-lang/triton/issues/1058
@@ -78,6 +95,8 @@
ori_X_y = tl.load(
X_ptr + y
) # we need to store the original value of X_y for the loss calculation
if HAS_SOFTCAPPING:
ori_X_y = softcap * tanh(ori_X_y / softcap)

# Label smoothing is a general case of normal cross entropy
# See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
@@ -89,6 +108,8 @@
X_block = tl.load(
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
)
if HAS_SOFTCAPPING:
X_block = softcap * tanh(X_block / softcap)
block_max = tl.max(X_block)
if label_smoothing > 0:
# scale X beforehand to avoid overflow
@@ -122,15 +143,24 @@
X_block = tl.load(
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
)
if HAS_SOFTCAPPING:
intermediate = tanh(X_block / softcap)
X_block = softcap * intermediate
# softmax(x_i)
X_block = tl.exp(X_block - m) / d
# derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
X_block += 2 * lse_square_scale * lse * X_block
# smoothing term
X_block += -eps
# special handle dx_y
X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
# reduction scale
if reduction == "mean":
X_block = X_block / (n_non_ignore)
# chain rule
# d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
if HAS_SOFTCAPPING:
X_block = X_block * (1 - intermediate * intermediate)

tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

@@ -151,7 +181,7 @@ def liger_cross_entropy_kernel(
# H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
# = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
# By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
# = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
# = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
# Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
# pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
# See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
@@ -168,17 +198,9 @@ def liger_cross_entropy_kernel(
z_loss = z_loss / n_non_ignore
loss = loss / n_non_ignore

# 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
X_y = tl.load(X_ptr + y)
if reduction == "mean":
X_y += -(1 - label_smoothing) / (n_non_ignore)
else:
X_y += -(1 - label_smoothing)

tl.store(loss_ptr, loss)
if RETURN_Z_LOSS == _TRUE:
tl.store(z_loss_ptr, z_loss)
tl.store(X_ptr + y, X_y)


# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
@@ -200,6 +222,7 @@ def cross_entropy_forward(
lse_square_scale,
label_smoothing,
reduction,
softcap,
return_z_loss,
):
if not isinstance(return_z_loss, int):
@@ -247,8 +270,10 @@ def cross_entropy_forward(
lse_square_scale=lse_square_scale,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap if softcap is not None else 0.0,
RETURN_Z_LOSS=return_z_loss,
BLOCK_SIZE=BLOCK_SIZE,
HAS_SOFTCAPPING=True if softcap is not None else False,
# TODO: 32 seems to give the best performance
# Performance is quite sensitive to num_warps
num_warps=32 if not is_hip() else 16,
@@ -296,13 +321,14 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
_input,
target,
ignore_index=-100,
lse_square_scale=0.0,
label_smoothing=0.0,
reduction="mean",
return_z_loss=False,
_input: torch.Tensor,
target: torch.Tensor,
ignore_index: int = -100,
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
reduction: str = "mean",
softcap: Optional[float] = None,
return_z_loss: bool = False,
):
"""
The forward pass of the Liger Cross Entropy loss.
@@ -315,6 +341,7 @@ def forward(
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
Returns:
@@ -327,6 +354,7 @@ def forward(
lse_square_scale,
label_smoothing,
reduction,
softcap,
return_z_loss,
)
# TODO: investigation
@@ -362,4 +390,5 @@ def backward(ctx, grad_output, grad_ouput2):
None,
None,
None,
None,
)
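
For reference, the `(1 - intermediate * intermediate)` chain-rule factor in the gradient path above follows directly from differentiating the capping function; this is standard calculus, not anything Liger-specific:

$$
\frac{d}{dx}\left[s\,\tanh\!\left(\frac{x}{s}\right)\right]
= s\cdot\frac{1}{s}\,\operatorname{sech}^2\!\left(\frac{x}{s}\right)
= 1-\tanh^2\!\left(\frac{x}{s}\right),
\qquad s=\texttt{softcap},
$$

which is why the kernel caches `intermediate = tanh(X_block / softcap)` and reuses it when rescaling the stored gradients.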
7 changes: 6 additions & 1 deletion src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -24,6 +24,7 @@ def fused_linear_cross_entropy_forward(
lse_square_scale=0.0,
label_smoothing=0.0,
reduction="mean",
softcap=None,
):
dtype = _input.dtype
device = _input.device
@@ -95,7 +96,9 @@ def fused_linear_cross_entropy_forward(
lse_square_scale=lse_square_scale,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap if softcap is not None else 0.0,
RETURN_Z_LOSS=0, # False
HAS_SOFTCAPPING=True if softcap is not None else False,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
)
@@ -207,6 +210,7 @@ def forward(
lse_square_scale=0.0,
label_smoothing=0.0,
reduction="mean",
softcap=None,
):
"""
Fusing the last linear layer with cross-entropy loss
@@ -234,6 +238,7 @@ def forward(
lse_square_scale,
label_smoothing,
reduction,
softcap,
)
# downcast to dtype and store for backward
ctx.save_for_backward(
@@ -250,4 +255,4 @@ def backward(ctx, grad_output):
grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
grad_output, grad_input, grad_weight, grad_bias
)
return (grad_input, grad_weight, None, grad_bias, None, None, None, None)
return (grad_input, grad_weight, None, grad_bias, None, None, None, None, None)
44 changes: 27 additions & 17 deletions src/liger_kernel/transformers/cross_entropy.py
@@ -1,41 +1,51 @@
import torch.nn as nn
from typing import Optional

import torch

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction


class LigerCrossEntropyLoss(nn.Module):
class LigerCrossEntropyLoss(torch.nn.Module):
def __init__(
self,
ignore_index=-100,
lse_square_scale=0.0,
label_smoothing=0.0,
reduction="mean",
return_z_loss=False,
ignore_index: int = -100,
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
reduction: str = "mean",
softcap: Optional[float] = None,
return_z_loss: bool = False,
):
super().__init__()
assert (label_smoothing >= 0) and (
label_smoothing <= 1
), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
assert (label_smoothing >= 0) and (
label_smoothing <= 1
), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
assert reduction in {
"mean",
"sum",
"none",
}, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
assert (
softcap is None or softcap > 0
), f"softcap must greater than 0.0 or None. Got: {softcap}"
self.ignore_index = ignore_index
self.lse_square_scale = lse_square_scale
self.label_smoothing = label_smoothing
self.reduction = reduction
self.softcap = softcap
self.return_z_loss = return_z_loss

assert (self.label_smoothing >= 0) and (
self.label_smoothing <= 1
), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
assert self.reduction in {
"mean",
"sum",
"none",
}, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"

def forward(self, _input, target):
def forward(self, _input: torch.Tensor, target: torch.Tensor):
loss, z_loss = LigerCrossEntropyFunction.apply(
_input,
target,
self.ignore_index,
self.lse_square_scale,
self.label_smoothing,
self.reduction,
self.softcap,
self.return_z_loss,
)
if not self.return_z_loss:
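
A quick usage sketch of the updated module with the new `softcap` argument; the device, shapes, and softcap value are illustrative, and a CUDA device is assumed since the loss dispatches a Triton kernel:

```python
import torch
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

loss_fn = LigerCrossEntropyLoss(softcap=30.0)  # softcap=None (the default) disables capping
logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
target = torch.randint(0, 32000, (8,), device="cuda")
loss = loss_fn(logits, target)
loss.backward()
```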
33 changes: 23 additions & 10 deletions src/liger_kernel/transformers/fused_linear_cross_entropy.py
@@ -1,26 +1,38 @@
import torch.nn as nn
from typing import Optional

import torch

from liger_kernel.ops.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyFunction,
)


class LigerFusedLinearCrossEntropyLoss(nn.Module):
class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
def __init__(
self,
ignore_index=-100,
label_smoothing=0.0,
reduction="mean",
lse_square_scale=0.0,
ignore_index: int = -100,
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
reduction: str = "mean",
softcap: Optional[float] = None,
):
super().__init__()
assert (label_smoothing >= 0) and (
label_smoothing <= 1
), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
assert reduction in {
"mean",
"sum",
"none",
}, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
assert (
softcap is None or softcap > 0
), f"softcap must greater than 0.0 or None. Got: {softcap}"
self.ignore_index = ignore_index
self.lse_square_scale = lse_square_scale
self.label_smoothing = label_smoothing
self.reduction = reduction
self.lse_square_scale = lse_square_scale
assert (self.label_smoothing >= 0) and (
self.label_smoothing <= 1
), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
self.softcap = softcap

def forward(self, lin_weight, _input, target, bias=None):
return LigerFusedLinearCrossEntropyFunction.apply(
@@ -32,4 +44,5 @@ def forward(self, lin_weight, _input, target, bias=None):
self.lse_square_scale,
self.label_smoothing,
self.reduction,
self.softcap,
)
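
A corresponding sketch for the fused-linear path, which takes the `lm_head` weight and flattened hidden states instead of materialized logits; note that the weight comes first in `forward()`. All values are illustrative and a CUDA device is assumed:

```python
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

flce = LigerFusedLinearCrossEntropyLoss(softcap=30.0)
hidden = torch.randn(8, 4096, device="cuda", requires_grad=True)              # (batch * seq_len, hidden)
lm_head_weight = torch.randn(32000, 4096, device="cuda", requires_grad=True)  # (vocab, hidden)
target = torch.randint(0, 32000, (8,), device="cuda")
loss = flce(lm_head_weight, hidden, target)
loss.backward()
```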