Add GraniteRMSNorm
NielsRogge committed Aug 28, 2024
1 parent 5c1027b commit 5726650
Showing 2 changed files with 29 additions and 4 deletions.
31 changes: 28 additions & 3 deletions src/transformers/models/granite/modeling_granite.py
@@ -30,6 +30,7 @@
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
@@ -99,6 +100,30 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
    return causal_mask


+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Granite
+class GraniteRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        GraniteRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+ALL_LAYERNORM_LAYERS.append(GraniteRMSNorm)


class GraniteRotaryEmbedding(nn.Module):
    def __init__(self, config: GraniteConfig):
        super().__init__()
@@ -534,8 +559,8 @@ def __init__(self, config: GraniteConfig, layer_idx: int):
        self.self_attn = GRANITE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

        self.mlp = GraniteMLP(config)
-        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.residual_multiplier = config.residual_multiplier

@@ -749,7 +774,7 @@ def __init__(self, config: GraniteConfig):
        self.layers = nn.ModuleList(
            [GraniteDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
-        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        self.embedding_multiplier = config.embedding_multiplier
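Note on the new module: GraniteRMSNorm.forward upcasts the activations to float32, rescales them by the reciprocal root mean square over the hidden dimension, casts back to the input dtype, and only then applies the learned weight. Below is a minimal standalone sketch of the same computation (illustrative only; hidden_size, eps, and the random input are made-up values, not taken from the Granite config):

import torch

# Illustrative sizes, not the real Granite configuration.
hidden_size, eps = 8, 1e-6
weight = torch.ones(hidden_size)  # learned scale, initialised to ones as in __init__
x = torch.randn(2, 4, hidden_size)

# Same steps as GraniteRMSNorm.forward: mean of squares over the last dim,
# multiply by the reciprocal square root, then apply the weight.
variance = x.pow(2).mean(-1, keepdim=True)
out = weight * (x * torch.rsqrt(variance + eps))
print(out.shape)  # torch.Size([2, 4, 8])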
2 changes: 1 addition & 1 deletion src/transformers/pytorch_utils.py
@@ -24,7 +24,7 @@
from .utils import is_torch_xla_available, logging


-ALL_LAYERNORM_LAYERS = [nn.LayerNorm, nn.RMSNorm]
+ALL_LAYERNORM_LAYERS = [nn.LayerNorm]

logger = logging.get_logger(__name__)

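Note on the pytorch_utils.py change: ALL_LAYERNORM_LAYERS is the registry transformers uses to recognise normalisation layers (the Trainer, for instance, uses it to exclude their parameters from weight decay). Removing nn.RMSNorm from the default list and instead appending GraniteRMSNorm in modeling_granite.py keeps Granite's norm layers registered, presumably without requiring nn.RMSNorm, which only exists in recent PyTorch releases. A quick check, assuming a transformers build that includes this commit:

import torch.nn as nn
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.models.granite.modeling_granite import GraniteRMSNorm

# Importing modeling_granite runs ALL_LAYERNORM_LAYERS.append(GraniteRMSNorm),
# so both the generic LayerNorm and the Granite-specific norm are registered.
print(nn.LayerNorm in ALL_LAYERNORM_LAYERS)    # True
print(GraniteRMSNorm in ALL_LAYERNORM_LAYERS)  # True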
