diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 455abc67..8cc116a0 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -1,8 +1,21 @@ +import operator +from typing import Optional + import torch import triton import triton.language as tl -from liger_kernel.ops.utils import element_mul_kernel, is_hip +from liger_kernel.ops.utils import compare_version, element_mul_kernel, is_hip + +if compare_version("triton", operator.ge, "3.0.0"): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import tanh + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import tanh +else: + from triton.language.math import tanh _TRUE = tl.constexpr(1) _FALSE = tl.constexpr(0) @@ -23,8 +36,10 @@ def liger_cross_entropy_kernel( lse_square_scale: tl.constexpr, label_smoothing: tl.constexpr, reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time + softcap, RETURN_Z_LOSS: tl.constexpr, BLOCK_SIZE: tl.constexpr, + HAS_SOFTCAPPING: tl.constexpr, ): """ This kernel computes both cross entropy loss and the gradient of the input. @@ -45,7 +60,9 @@ def liger_cross_entropy_kernel( lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1. reduction (str): The string for the reduction to apply + softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). BLOCK_SIZE (int): The block size for Triton operations. + HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. 
""" # https://github.com/triton-lang/triton/issues/1058 @@ -78,6 +95,8 @@ def liger_cross_entropy_kernel( ori_X_y = tl.load( X_ptr + y ) # we need to store the original value of X_y for the loss calculation + if HAS_SOFTCAPPING: + ori_X_y = softcap * tanh(ori_X_y / softcap) # Label smoothing is a general case of normal cross entropy # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310 @@ -89,6 +108,8 @@ def liger_cross_entropy_kernel( X_block = tl.load( X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf") ) + if HAS_SOFTCAPPING: + X_block = softcap * tanh(X_block / softcap) block_max = tl.max(X_block) if label_smoothing > 0: # scale X beforehand to avoid overflow @@ -122,15 +143,24 @@ def liger_cross_entropy_kernel( X_block = tl.load( X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf") ) + if HAS_SOFTCAPPING: + intermediate = tanh(X_block / softcap) + X_block = softcap * intermediate # softmax(x_i) X_block = tl.exp(X_block - m) / d # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) X_block += 2 * lse_square_scale * lse * X_block # smoothing term X_block += -eps + # special handle dx_y + X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) # reduction scale if reduction == "mean": X_block = X_block / (n_non_ignore) + # chain rule + # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) + if HAS_SOFTCAPPING: + X_block = X_block * (1 - intermediate * intermediate) tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) @@ -151,7 +181,7 @@ def liger_cross_entropy_kernel( # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: - # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd)) + # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd)) # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087 @@ -168,17 +198,9 @@ def liger_cross_entropy_kernel( z_loss = z_loss / n_non_ignore loss = loss / n_non_ignore - # 6. 
Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N` - X_y = tl.load(X_ptr + y) - if reduction == "mean": - X_y += -(1 - label_smoothing) / (n_non_ignore) - else: - X_y += -(1 - label_smoothing) - tl.store(loss_ptr, loss) if RETURN_Z_LOSS == _TRUE: tl.store(z_loss_ptr, z_loss) - tl.store(X_ptr + y, X_y) # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 @@ -200,6 +222,7 @@ def cross_entropy_forward( lse_square_scale, label_smoothing, reduction, + softcap, return_z_loss, ): if not isinstance(return_z_loss, int): @@ -247,8 +270,10 @@ def cross_entropy_forward( lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, + softcap=softcap if softcap is not None else 0.0, RETURN_Z_LOSS=return_z_loss, BLOCK_SIZE=BLOCK_SIZE, + HAS_SOFTCAPPING=True if softcap is not None else False, # TODO: 32 seems to give the best performance # Performance is quite sensitive to num_warps num_warps=32 if not is_hip() else 16, @@ -296,13 +321,14 @@ class LigerCrossEntropyFunction(torch.autograd.Function): @staticmethod def forward( ctx, - _input, - target, - ignore_index=-100, - lse_square_scale=0.0, - label_smoothing=0.0, - reduction="mean", - return_z_loss=False, + _input: torch.Tensor, + target: torch.Tensor, + ignore_index: int = -100, + lse_square_scale: float = 0.0, + label_smoothing: float = 0.0, + reduction: str = "mean", + softcap: Optional[float] = None, + return_z_loss: bool = False, ): """ The forward pass of the Liger Cross Entropy loss. @@ -315,6 +341,7 @@ def forward( lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. reduction (str): The reduction to apply to the output: "none" | "mean | "sum". + softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap). return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). 
Default: `False` Returns: @@ -327,6 +354,7 @@ def forward( lse_square_scale, label_smoothing, reduction, + softcap, return_z_loss, ) # TODO: investigation @@ -362,4 +390,5 @@ def backward(ctx, grad_output, grad_ouput2): None, None, None, + None, ) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 34016ee4..f053b918 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -24,6 +24,7 @@ def fused_linear_cross_entropy_forward( lse_square_scale=0.0, label_smoothing=0.0, reduction="mean", + softcap=None, ): dtype = _input.dtype device = _input.device @@ -95,7 +96,9 @@ def fused_linear_cross_entropy_forward( lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, + softcap=softcap if softcap is not None else 0.0, RETURN_Z_LOSS=0, # False + HAS_SOFTCAPPING=True if softcap is not None else False, BLOCK_SIZE=BLOCK_SIZE, num_warps=32 if not is_hip() else 16, ) @@ -207,6 +210,7 @@ def forward( lse_square_scale=0.0, label_smoothing=0.0, reduction="mean", + softcap=None, ): """ Fusing the last linear layer with cross-entropy loss @@ -234,6 +238,7 @@ def forward( lse_square_scale, label_smoothing, reduction, + softcap, ) # downcast to dtype and store for backward ctx.save_for_backward( @@ -250,4 +255,4 @@ def backward(ctx, grad_output): grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward( grad_output, grad_input, grad_weight, grad_bias ) - return (grad_input, grad_weight, None, grad_bias, None, None, None, None) + return (grad_input, grad_weight, None, grad_bias, None, None, None, None, None) diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py index f612f6f4..7bd27edd 100644 --- a/src/liger_kernel/transformers/cross_entropy.py +++ b/src/liger_kernel/transformers/cross_entropy.py @@ -1,34 +1,40 @@ -import torch.nn as nn +from typing import Optional + +import torch from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction -class LigerCrossEntropyLoss(nn.Module): +class LigerCrossEntropyLoss(torch.nn.Module): def __init__( self, - ignore_index=-100, - lse_square_scale=0.0, - label_smoothing=0.0, - reduction="mean", - return_z_loss=False, + ignore_index: int = -100, + lse_square_scale: float = 0.0, + label_smoothing: float = 0.0, + reduction: str = "mean", + softcap: Optional[float] = None, + return_z_loss: bool = False, ): super().__init__() + assert (label_smoothing >= 0) and ( + label_smoothing <= 1 + ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}" + assert reduction in { + "mean", + "sum", + "none", + }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}" + assert ( + softcap is None or softcap > 0 + ), f"softcap must be greater than 0.0 or None. Got: {softcap}" self.ignore_index = ignore_index self.lse_square_scale = lse_square_scale self.label_smoothing = label_smoothing self.reduction = reduction + self.softcap = softcap self.return_z_loss = return_z_loss - assert (self.label_smoothing >= 0) and ( - self.label_smoothing <= 1 - ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}" - assert self.reduction in { - "mean", - "sum", - "none", - }, f"reduction must be one of 'mean', 'sum', or 'none'. 
Got: {self.reduction}" - - def forward(self, _input, target): + def forward(self, _input: torch.Tensor, target: torch.Tensor): loss, z_loss = LigerCrossEntropyFunction.apply( _input, target, @@ -36,6 +42,7 @@ def forward(self, _input, target): self.lse_square_scale, self.label_smoothing, self.reduction, + self.softcap, self.return_z_loss, ) if not self.return_z_loss: diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py index fa6b37a9..7df79d30 100644 --- a/src/liger_kernel/transformers/fused_linear_cross_entropy.py +++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py @@ -1,26 +1,38 @@ -import torch.nn as nn +from typing import Optional + +import torch from liger_kernel.ops.fused_linear_cross_entropy import ( LigerFusedLinearCrossEntropyFunction, ) -class LigerFusedLinearCrossEntropyLoss(nn.Module): +class LigerFusedLinearCrossEntropyLoss(torch.nn.Module): def __init__( self, - ignore_index=-100, - label_smoothing=0.0, - reduction="mean", - lse_square_scale=0.0, + ignore_index: int = -100, + lse_square_scale: float = 0.0, + label_smoothing: float = 0.0, + reduction: str = "mean", + softcap: Optional[float] = None, ): super().__init__() + assert (label_smoothing >= 0) and ( + label_smoothing <= 1 + ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}" + assert reduction in { + "mean", + "sum", + "none", + }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}" + assert ( + softcap is None or softcap > 0 + ), f"softcap must be greater than 0.0 or None. Got: {softcap}" self.ignore_index = ignore_index + self.lse_square_scale = lse_square_scale self.label_smoothing = label_smoothing self.reduction = reduction - self.lse_square_scale = lse_square_scale - assert (self.label_smoothing >= 0) and ( - self.label_smoothing <= 1 - ), f"label_smoothing must be between 0.0 and 1.0. 
Got: {self.label_smoothing}" + self.softcap = softcap def forward(self, lin_weight, _input, target, bias=None): return LigerFusedLinearCrossEntropyFunction.apply( @@ -32,4 +44,5 @@ def forward(self, lin_weight, _input, target, bias=None): self.lse_square_scale, self.label_smoothing, self.reduction, + self.softcap, ) diff --git a/src/liger_kernel/transformers/model/gemma2.py b/src/liger_kernel/transformers/model/gemma2.py new file mode 100644 index 00000000..8ce5aa69 --- /dev/null +++ b/src/liger_kernel/transformers/model/gemma2.py @@ -0,0 +1,277 @@ +import logging +from typing import Optional, Tuple, Union + +import torch +from torch.nn import CrossEntropyLoss +from transformers.cache_utils import HybridCache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.gemma2.modeling_gemma2 import ( + _CONFIG_FOR_DOC, + GEMMA2_INPUTS_DOCSTRING, +) +from transformers.utils import ( + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) + +from liger_kernel.transformers.fused_linear_cross_entropy import ( + LigerFusedLinearCrossEntropyLoss, +) + +logger = logging.getLogger(__name__) + + +def lce_forward_deprecated( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, GemmaForCausalLM + >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + + if self.training and self.config._attn_implementation != "eager": + logger.warning_once( + "It is strongly recommended to train Gemma2 models with the `eager` attention implementation " + f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." 
+ ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + loss = None + logits = None + + if self.training and (labels is not None): + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # flatten + + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + + lce = LigerFusedLinearCrossEntropyLoss( + softcap=self.config.final_logit_softcapping + ) + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + + else: + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states) + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, GemmaForCausalLM + + >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + + if self.training and self.config._attn_implementation != "eager": + logger.warning_once( + "It is strongly recommended to train Gemma2 models with the `eager` attention implementation " + f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + logits = None + loss = None + # if in training mode, don't materialize logits + if self.training and (labels is not None): + # We do the same thing as ForCausalLMLoss but using Liger FLCE + + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # flatten tokens + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + + reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean" + lce = LigerFusedLinearCrossEntropyLoss( + softcap=self.config.final_logit_softcapping, + reduction=reduction, + ) + + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + if reduction == "sum": + loss /= loss_kwargs["num_items_in_batch"] + + else: # if in inference mode materialize logits + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + if not 
return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index ca199ad8..fb1a8db9 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -14,6 +14,10 @@ from liger_kernel.transformers.model.gemma import ( lce_forward_deprecated as gemma_lce_forward_deprecated, ) +from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward +from liger_kernel.transformers.model.gemma2 import ( + lce_forward_deprecated as gemma2_lce_forward_deprected, +) from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward from liger_kernel.transformers.model.llama import ( lce_forward_deprecated as llama_lce_forward_deprecated, @@ -252,7 +256,7 @@ def apply_liger_kernel_to_mistral( Apply Liger kernels to replace original implementation in HuggingFace Mistral models Args: - rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + rope (bool): Whether to apply Liger's rotary position embedding. Default is False. cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True. fused_linear_cross_entropy (bool): Whether to apply Liger's fused linear cross entropy loss. Default is True. @@ -445,7 +449,8 @@ def apply_liger_kernel_to_gemma( def apply_liger_kernel_to_gemma2( rope: bool = True, - cross_entropy: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, rms_norm: bool = True, geglu: bool = True, model: PreTrainedModel = None, @@ -456,12 +461,19 @@ def apply_liger_kernel_to_gemma2( Args: rope (bool): Whether to apply Liger's rotary position embedding. Default is True. - cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True. model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None. """ + assert not ( + cross_entropy and fused_linear_cross_entropy + ), "cross_entropy and fused_linear_cross_entropy cannot both be True." 
from transformers.models.gemma2 import modeling_gemma2 from transformers.models.gemma2.modeling_gemma2 import Gemma2Model @@ -479,6 +491,12 @@ def apply_liger_kernel_to_gemma2( modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2 if cross_entropy: modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss + if fused_linear_cross_entropy: + if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward + else: + logger.warning(TRANSFORMER_DEPRECATION_WARNING) + modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected if geglu: modeling_gemma2.Gemma2MLP = LigerGEGLUMLP diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index 72be62c0..e4c1b552 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -410,13 +410,8 @@ def run_mini_model( else: kwargs["swiglu"] = True - model_support_flce = "gemma2" not in model_name - - if model_support_flce: - kwargs["fused_linear_cross_entropy"] = True - kwargs["cross_entropy"] = False - else: - kwargs["cross_entropy"] = True + kwargs["fused_linear_cross_entropy"] = True + kwargs["cross_entropy"] = False MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) else: diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 3ca0e7fc..a505e6fc 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -172,6 +172,29 @@ def _test_correctness_with_label_smoothing_with_ignore_index_once( assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) +def _test_correctness_with_softcap_once( + target_ce, B, T, V, softcap, reduction, scalar, dtype, atol, rtol +): + + torch_ce = CrossEntropyLoss(reduction=reduction) + + _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + # upcasting to match liger's casting strategy + _input = _tensor.to(torch.float32).detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + + # downcasting to original dtype + output = torch_ce(softcap * torch.tanh(_input / softcap), target).to(dtype) + output2 = target_ce(_input2, target) + + assert torch.allclose(output, output2, atol=atol, rtol=rtol) + + output.backward() + output2.backward() + + def _test_correctness_with_z_loss_once( target_ce, B, @@ -196,7 +219,6 @@ def _test_correctness_with_z_loss_once( _input2 = _tensor.detach().clone().requires_grad_(True) target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) - if return_z_loss: output, z_output = torch_ce(_input, target) output2, z_output2 = target_ce(_input2, target) @@ -271,11 +293,6 @@ def _test_correctness_with_z_loss_with_other_params_once( output.backward() output2.backward() - print(_input.grad) - print(_input2.grad) - - print(f"{(_input.grad - _input2.grad).sum()=}") - assert_verbose_allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) @@ -303,7 +320,15 @@ def _test_correctness_not_last_layer_once( assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) -def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): +def _test_correctness_functional( + B, + T, + V, + scalar, + dtype, + atol, + rtol, +): _input = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar @@ -312,8 +337,10 @@ def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): target = torch.randint(0, V, (B 
* T,), device="cuda", dtype=torch.long) - y1, y1_z = liger_cross_entropy(x1, target, 0, 1e-4, 0.1, "mean", True) - y2, y2_z = LigerCrossEntropyFunction.apply(x2, target, 0, 1e-4, 0.1, "mean", True) + y1, y1_z = liger_cross_entropy(x1, target, 0, 1e-4, 0.1, "mean", 30.0, True) + y2, y2_z = LigerCrossEntropyFunction.apply( + x2, target, 0, 1e-4, 0.1, "mean", 30.0, True + ) assert torch.allclose(y1, y2, atol=atol, rtol=rtol) assert torch.allclose(y1_z, y2_z, atol=atol, rtol=rtol) @@ -478,6 +505,39 @@ def test_correctness_with_label_smoothing_with_ignore_index_once( ) +@pytest.mark.parametrize( + "B, T, V, softcap", + [ + (2, 4096, 32000, 30.0), # llama2, mistral + # weird shapes + (3, 423, 32000, 30.0), + ], +) +@pytest.mark.parametrize("reduction", ["sum", "mean"]) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + pytest.param( + 1.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + (1.0, torch.float32, 1e-8, 1e-6), + ], +) +def test_correctness_with_softcap_once( + B, T, V, softcap, reduction, scalar, dtype, atol, rtol +): + liger_ce = LigerCrossEntropyLoss(softcap=softcap, reduction=reduction) + _test_correctness_with_softcap_once( + liger_ce, B, T, V, softcap, reduction, scalar, dtype, atol, rtol + ) + + @pytest.mark.parametrize( "B, T, V", [ diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index 2be9c9d1..881330c5 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -1,5 +1,6 @@ from test.transformers.test_cross_entropy import CrossEntropyWithZLoss from test.utils import assert_verbose_allclose, set_seed +from typing import Optional import pytest import torch @@ -41,6 +42,7 @@ def __init__( lse_square_scale: float = 0.0, label_smoothing: float = 0.0, reduction: str = "mean", + softcap: Optional[float] = None, ): super().__init__() self.lin = torch.nn.Linear( @@ -52,9 +54,12 @@ def __init__( label_smoothing=label_smoothing, reduction=reduction, ) + self.softcap = softcap def forward(self, x, y): logits = self.lin(x).to(torch.float32) + if self.softcap is not None and self.softcap != 0.0: + logits = self.softcap * torch.tanh(logits / self.softcap) return self.ce_loss(logits, y) @@ -69,6 +74,7 @@ def __init__( lse_square_scale: float = 0.0, label_smoothing: float = 0.0, reduction: str = "mean", + softcap: Optional[float] = None, ): super().__init__() self.lin = torch.nn.Linear( @@ -79,6 +85,7 @@ def __init__( lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, + softcap=softcap, ) def forward(self, x, y): @@ -108,10 +115,15 @@ def forward(self, x, y): ) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize( - "label_smoothing, ignore_index, lse_square_scale", + "label_smoothing, ignore_index, lse_square_scale, softcap", [ - (0, -100, 0), - (0.1, 42, 1e-4), # Pass non-default values once to ensure all params work along + (0, -100, 0, None), + ( + 0.1, + 42, + 1e-4, + 30.0, + ), # Pass non-default values once to ensure all params work along ], ) def test_correctness( @@ -126,6 +138,7 @@ def test_correctness( label_smoothing, ignore_index, reduction, + softcap, atol, rtol, ): @@ -138,6 +151,7 @@ def test_correctness( label_smoothing=label_smoothing, ignore_index=ignore_index, reduction=reduction, + softcap=softcap, dtype=dtype, ).to(device) liger_lm_head_ce = LigerLMHeadCE( @@ -148,6 
+162,7 @@ def test_correctness( label_smoothing=label_smoothing, ignore_index=ignore_index, reduction=reduction, + softcap=softcap, dtype=dtype, ).to(device)
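
Editorial note, appended for readers rather than as part of the patch: the soft-capping the kernel applies is x -> softcap * tanh(x / softcap), and the factor it folds into the gradient is the chain-rule term 1 - tanh^2(x / softcap). A minimal PyTorch sketch (illustrative shapes, runs on CPU) that checks that factor against autograd:

import torch

def softcap_fn(x: torch.Tensor, softcap: float) -> torch.Tensor:
    # smoothly squashes logits into the range (-softcap, +softcap)
    return softcap * torch.tanh(x / softcap)

softcap = 30.0
x = torch.randn(8, 128, dtype=torch.float64, requires_grad=True)

softcap_fn(x, softcap).sum().backward()

# d/dx [softcap * tanh(x / softcap)] = 1 - tanh(x / softcap) ** 2
expected = 1 - torch.tanh(x / softcap) ** 2
assert torch.allclose(x.grad, expected)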
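
A sketch of how the new softcap argument is exercised at the module level, mirroring _test_correctness_with_softcap_once above; it assumes a CUDA device, an installed liger_kernel, and illustrative shapes and tolerances:

import torch
from torch.nn import CrossEntropyLoss
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

B, T, V, softcap = 2, 512, 32000, 30.0
_tensor = torch.randn(B * T, V, device="cuda", dtype=torch.float32)
_input = _tensor.detach().clone().requires_grad_(True)   # torch reference
_input2 = _tensor.detach().clone().requires_grad_(True)  # Liger (the kernel reuses this buffer for the gradient)
target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long)

# reference: apply soft-capping by hand, then plain cross entropy
ref_loss = CrossEntropyLoss()(softcap * torch.tanh(_input / softcap), target)

# Liger: soft-capping happens inside the fused Triton kernel
liger_loss = LigerCrossEntropyLoss(softcap=softcap)(_input2, target)

assert torch.allclose(ref_loss, liger_loss, atol=1e-6, rtol=1e-5)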
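
And a sketch of the end-to-end Gemma2 path this change enables. The checkpoint name comes from the docstrings above; the dtype, device, and the re-export of apply_liger_kernel_to_gemma2 from liger_kernel.transformers are assumptions (if the re-export is unavailable, import it from liger_kernel.transformers.monkey_patch):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from liger_kernel.transformers import apply_liger_kernel_to_gemma2

# Patch before loading so Gemma2ForCausalLM.forward is replaced with the
# fused-linear-cross-entropy version; during training the full logits are never
# materialized and softcap is taken from config.final_logit_softcapping.
apply_liger_kernel_to_gemma2(fused_linear_cross_entropy=True, cross_entropy=False)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",  # recommended for Gemma2 training, per the warning above
).to("cuda")
model.train()

batch = tokenizer("What is your favorite condiment?", return_tensors="pt").to("cuda")
out = model(**batch, labels=batch["input_ids"])  # loss computed via LigerFusedLinearCrossEntropyLoss
print(out.loss)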