Make nn configs more flexible
epwalsh committed Nov 22, 2024
1 parent 0bcc840 commit 98960f7
Showing 8 changed files with 162 additions and 164 deletions.
50 changes: 26 additions & 24 deletions src/olmo_core/nn/attention.py
@@ -50,11 +50,11 @@ class AttentionConfig(Config):
     """
     n_heads: int = 16
     n_kv_heads: Optional[int] = None
-    bias: bool = True
+    bias: Optional[bool] = None
     rope: Optional[RoPEConfig] = None
     clip_qkv: Optional[float] = None
     qk_norm: Optional[LayerNormConfig] = None
-    dropout: float = 0.0
+    dropout: Optional[float] = None
     use_flash: Optional[bool] = None
     dtype: DType = DType.float32

@@ -72,29 +72,25 @@ def build(
         """
         kwargs = self.as_dict(exclude_none=True, recurse=False)
         kwargs.pop("name")
-        kwargs["dtype"] = kwargs["dtype"].as_pt()
         kwargs.update(
-            dict(
-                d_model=d_model,
-                init_device=init_device,
-                cache=cache,
-            )
+            dtype=kwargs.pop("dtype").as_pt(),
+            d_model=d_model,
+            init_device=init_device,
+            cache=cache,
         )

-        if self.name == "default":
-            return Attention(**kwargs)
-        elif self.name == "fused":
-            kwargs.pop("use_flash", None)
-            return FusedAttention(**kwargs)
-        elif self.name == "normalized":
-            bias = kwargs.pop("bias")
-            if bias:
-                raise OLMoConfigurationError(f"'bias' is invalid for '{self.name}' attention")
-            if kwargs.pop("dropout") > 0.0:
-                raise OLMoConfigurationError(f"'dropout' is invalid for '{self.name}' attention")
-            return NormalizedAttention(**kwargs)
-        else:
-            raise NotImplementedError(self.name)
+        try:
+            if self.name == "default":
+                return Attention(**kwargs)
+            elif self.name == "fused":
+                kwargs.pop("use_flash", None)
+                return FusedAttention(**kwargs)
+            elif self.name == "normalized":
+                return NormalizedAttention(**kwargs)
+            else:
+                raise NotImplementedError(self.name)
+        except TypeError as e:
+            raise OLMoConfigurationError(f"invalid options for '{self.name}', {e}") from e


 class Attention(nn.Module):
@@ -312,6 +308,7 @@ def __init__(
         n_heads: int,
         n_kv_heads: Optional[int] = None,
         rope: Optional[RoPEConfig] = None,
+        qk_norm: Optional[LayerNormConfig] = None,
         use_flash: bool = False,
         dtype: torch.dtype = torch.float32,
         init_device: str = "cpu",
@@ -322,6 +319,7 @@ def __init__(
             n_heads=n_heads,
             n_kv_heads=n_kv_heads,
             rope=rope,
+            qk_norm=qk_norm,
             use_flash=use_flash,
             bias=False,
             dtype=dtype,
@@ -364,11 +362,15 @@ def forward(
         # (batch_size, seq_len, n_kv_heads * head_dim)
         q, k, v = self.w_q(x), self.w_k(x), self.w_v(x)

+        if self.q_norm is not None and self.k_norm is not None:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
+
         sq = (self.sq * (self.sq_init_value / self.sq_init_scaling)).view(1, 1, -1)
-        q = sq * l2_normalize(q)
+        q = sq * q

         sk = (self.sk * (self.sk_init_value / self.sk_init_scaling)).view(1, 1, -1)
-        k = sk * l2_normalize(k)
+        k = sk * k

         # shape: (batch_size, seq_len, n_heads, head_dim)
         q = q.view(B, T, self.n_heads, self.head_dim)
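
The upshot of the AttentionConfig.build() rewrite above: options left at None are stripped out before construction, and any option a given attention variant does not accept is reported as an OLMoConfigurationError (wrapping the underlying TypeError) instead of being guarded by per-variant checks. Below is a minimal, self-contained sketch of that pattern; the Toy* classes and ConfigurationError are illustrative stand-ins, not part of the repository.

from dataclasses import asdict, dataclass
from typing import Optional


class ConfigurationError(Exception):
    """Illustrative stand-in for OLMoConfigurationError."""


class ToyNormalizedAttention:
    # Deliberately accepts no 'bias' or 'dropout' keyword, like NormalizedAttention.
    def __init__(self, *, d_model: int, n_heads: int):
        self.d_model, self.n_heads = d_model, n_heads


@dataclass
class ToyAttentionConfig:
    # None means "not set"; only explicitly-set fields are forwarded to the module.
    n_heads: int = 16
    bias: Optional[bool] = None
    dropout: Optional[float] = None

    def build(self, d_model: int) -> ToyNormalizedAttention:
        kwargs = {k: v for k, v in asdict(self).items() if v is not None}
        try:
            return ToyNormalizedAttention(d_model=d_model, **kwargs)
        except TypeError as e:
            raise ConfigurationError(f"invalid options for normalized attention, {e}") from e


ToyAttentionConfig(n_heads=8).build(d_model=256)  # fine: bias/dropout were never set
try:
    ToyAttentionConfig(n_heads=8, bias=True).build(d_model=256)
except ConfigurationError as e:
    print(e)  # the unexpected 'bias' keyword becomes a clear configuration error
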
33 changes: 15 additions & 18 deletions src/olmo_core/nn/feed_forward.py
@@ -1,5 +1,6 @@
 import math
 from dataclasses import dataclass
+from typing import Optional

 import torch
 import torch.nn as nn
@@ -38,27 +39,23 @@ class FeedForwardConfig(Config):

     hidden_size: int
     name: FeedForwardType = FeedForwardType.default
-    bias: bool = True
+    bias: Optional[bool] = None
     dtype: DType = DType.float32

     def build(self, d_model: int, init_device: str = "cpu") -> "FeedForward":
-        if self.name == FeedForwardType.default:
-            return FeedForward(
-                d_model=d_model,
-                hidden_size=self.hidden_size,
-                bias=self.bias,
-                dtype=self.dtype.as_pt(),
-                init_device=init_device,
-            )
-        else:
-            if self.bias:
-                raise OLMoConfigurationError(f"'bias' is invalid for '{self.name}' feed-forward")
-            return NormalizedFeedForward(
-                d_model=d_model,
-                hidden_size=self.hidden_size,
-                dtype=self.dtype.as_pt(),
-                init_device=init_device,
-            )
+        kwargs = self.as_dict(exclude_none=True)
+        kwargs.pop("name")
+        kwargs.update(d_model=d_model, init_device=init_device, dtype=kwargs.pop("dtype").as_pt())
+
+        try:
+            if self.name == FeedForwardType.default:
+                return FeedForward(**kwargs)
+            elif self.name == FeedForwardType.normalized:
+                return NormalizedFeedForward(**kwargs)
+            else:
+                raise NotImplementedError(self.name)
+        except TypeError as e:
+            raise OLMoConfigurationError(f"invalid options for '{self.name}', {e}") from e


 class FeedForward(nn.Module):
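
A detail shared by these rewritten build() methods is the kwargs.pop("dtype").as_pt() idiom: the config stores a serializable DType enum and only converts it to a concrete torch.dtype at the moment the module is constructed. A small self-contained sketch of the idiom, using a toy enum rather than olmo_core's actual DType:

from enum import Enum

import torch


class ToyDType(Enum):
    # Illustrative stand-in for olmo_core's DType; only the as_pt() name mirrors the diff.
    float32 = "float32"
    bfloat16 = "bfloat16"

    def as_pt(self) -> torch.dtype:
        return {"float32": torch.float32, "bfloat16": torch.bfloat16}[self.value]


kwargs = {"hidden_size": 1024, "bias": False, "dtype": ToyDType.bfloat16}
# Pop the enum and replace it with the concrete torch dtype before forwarding to the module.
kwargs.update(d_model=256, dtype=kwargs.pop("dtype").as_pt())
assert kwargs["dtype"] is torch.bfloat16
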
66 changes: 43 additions & 23 deletions src/olmo_core/nn/layer_norm.py
@@ -1,12 +1,15 @@
 from dataclasses import dataclass
+from typing import Optional

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from ..config import Config, DType, StrEnum
+from ..exceptions import OLMoConfigurationError
+from .functional import l2_normalize

-__all__ = ["LayerNormType", "LayerNormConfig", "LayerNorm", "RMSNorm", "FusedRMSNorm"]
+__all__ = ["LayerNormType", "LayerNormConfig", "LayerNorm", "RMSNorm", "FusedRMSNorm", "L2Norm"]


 class LayerNormType(StrEnum):
@@ -21,6 +24,7 @@ class LayerNormType(StrEnum):
     default = "default"
     rms = "rms"
     fused_rms = "fused_rms"
+    l2_norm = "l2_norm"


 @dataclass
@@ -37,11 +41,11 @@ class LayerNormConfig(Config):
     - "rms" ➡️ :class:`RMSNorm`
     - "fused_rms" ➡️ :class:`FusedRMSNorm`
     """
-    eps: float = 1e-5
-    elementwise_affine: bool = True
-    bias: bool = True
-    full_precision: bool = True
-    dtype: DType = DType.float32
+    eps: Optional[float] = None
+    elementwise_affine: Optional[bool] = None
+    bias: Optional[bool] = None
+    full_precision: Optional[bool] = None
+    dtype: Optional[DType] = None

     def build(self, size: int, init_device: str = "cpu") -> "LayerNorm":
         """
@@ -51,23 +55,22 @@ def build(self, size: int, init_device: str = "cpu") -> "LayerNorm":
         """
         kwargs = self.as_dict(exclude_none=True)
         kwargs.pop("name")
-        dtype = kwargs["dtype"].as_pt()
-        kwargs.update(
-            dict(
-                size=size,
-                init_device=init_device,
-                dtype=dtype,
-            )
-        )
-
-        if self.name == LayerNormType.default:
-            return LayerNorm(**kwargs)
-        elif self.name == LayerNormType.rms:
-            return RMSNorm(**kwargs)
-        elif self.name == LayerNormType.fused_rms:
-            return FusedRMSNorm(**kwargs)
-        else:
-            raise NotImplementedError(self.name)
+        if (dtype := kwargs.pop("dtype", None)) is not None:
+            kwargs.update(dtype=dtype.as_pt())
+
+        try:
+            if self.name == LayerNormType.default:
+                return LayerNorm(size=size, init_device=init_device, **kwargs)
+            elif self.name == LayerNormType.rms:
+                return RMSNorm(size=size, init_device=init_device, **kwargs)
+            elif self.name == LayerNormType.fused_rms:
+                return FusedRMSNorm(size=size, init_device=init_device, **kwargs)
+            elif self.name == LayerNormType.l2_norm:
+                return L2Norm(size=size, **kwargs)
+            else:
+                raise NotImplementedError(self.name)
+        except TypeError as e:
+            raise OLMoConfigurationError(f"invalid options for '{self.name}', {e}") from e


 class LayerNorm(nn.Module):
@@ -245,3 +248,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             None if self.bias is None else self.bias.type_as(x),
             eps=self.eps,
         ).to(og_dtype)
+
+
+class L2Norm(LayerNorm):
+    """
+    A variant of layer norm that just normalizes the last dimension of the input by its L2 norm,
+    as done in nGPT.
+    """
+
+    def __init__(
+        self,
+        *,
+        size: int,
+    ):
+        super().__init__(size=size, elementwise_affine=False, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return l2_normalize(x)
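
The new L2Norm carries no learnable parameters; it only rescales the last dimension of its input to unit L2 norm, following nGPT. The body of l2_normalize is not shown in this diff, so the helper below is an assumed, conventional definition of that operation rather than the repository's implementation.

import torch


def l2_normalize_sketch(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # Scale each vector along the last dimension to unit L2 norm (eps guards against zeros).
    return x / x.norm(dim=-1, keepdim=True).clamp_min(eps)


x = torch.randn(2, 4, 8)
y = l2_normalize_sketch(x)
print(y.norm(dim=-1))  # approximately all ones -- no learned scale or bias involved
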
44 changes: 20 additions & 24 deletions src/olmo_core/nn/lm_head.py
@@ -37,32 +37,28 @@ class LMHeadConfig(Config):

     name: LMHeadType = LMHeadType.default
     layer_norm: Optional[LayerNormConfig] = None
-    bias: bool = True
+    bias: Optional[bool] = None
     dtype: DType = DType.float32

     def build(self, *, d_model: int, vocab_size: int, init_device: str = "cpu") -> "LMHead":
-        if self.name == LMHeadType.default:
-            return LMHead(
-                d_model=d_model,
-                vocab_size=vocab_size,
-                layer_norm=self.layer_norm,
-                dtype=self.dtype.as_pt(),
-                bias=self.bias,
-                init_device=init_device,
-            )
-        elif self.name == LMHeadType.normalized:
-            if self.bias:
-                raise OLMoConfigurationError(f"'bias' is invalid for '{self.name}' LM head")
-            if self.layer_norm is not None:
-                raise OLMoConfigurationError(f"'layer_norm' is invalid for '{self.name}' LM head")
-            return NormalizedLMHead(
-                d_model=d_model,
-                vocab_size=vocab_size,
-                dtype=self.dtype.as_pt(),
-                init_device=init_device,
-            )
-        else:
-            raise NotImplementedError(self.name)
+        kwargs = self.as_dict(exclude_none=True, recurse=False)
+        kwargs.pop("name")
+        kwargs.update(
+            d_model=d_model,
+            vocab_size=vocab_size,
+            init_device=init_device,
+            dtype=kwargs.pop("dtype").as_pt(),
+        )
+
+        try:
+            if self.name == LMHeadType.default:
+                return LMHead(**kwargs)
+            elif self.name == LMHeadType.normalized:
+                return NormalizedLMHead(**kwargs)
+            else:
+                raise NotImplementedError(self.name)
+        except TypeError as e:
+            raise OLMoConfigurationError(f"invalid options for '{self.name}', {e}") from e


 class LMHead(nn.Module):
@@ -75,7 +71,7 @@ def __init__(
         *,
         d_model: int,
         vocab_size: int,
-        layer_norm: Optional[LayerNormConfig],
+        layer_norm: Optional[LayerNormConfig] = None,
         dtype: torch.dtype = torch.float32,
         bias: bool = True,
         init_device: str = "cpu",
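
With this change, callers pick an LM head variant by name and leave any option that variant does not support unset; an unsupported option now surfaces as an OLMoConfigurationError at build time rather than through bespoke checks. A usage sketch, assuming LMHeadConfig and LMHeadType are importable from olmo_core.nn.lm_head as this file suggests, with illustrative sizes:

from olmo_core.nn.lm_head import LMHeadConfig, LMHeadType

# For the normalized head, leave 'bias' and 'layer_norm' unset (None); setting either would
# now be rejected as an OLMoConfigurationError when build() forwards it to NormalizedLMHead.
head = LMHeadConfig(name=LMHeadType.normalized).build(
    d_model=1024, vocab_size=50304, init_device="cpu"
)
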
