
Commit

fix!
epwalsh committed Nov 22, 2024
1 parent 06964a7 commit 991c6eb
Showing 8 changed files with 127 additions and 79 deletions.
39 changes: 38 additions & 1 deletion src/olmo_core/nn/attention.py
@@ -58,6 +58,41 @@ class AttentionConfig(Config):
use_flash: Optional[bool] = None
dtype: DType = DType.float32

def num_params(self, d_model: int) -> int:
n_heads = self.n_heads
n_kv_heads = self.n_kv_heads or n_heads
head_dim = d_model // n_heads
bias = self.bias if self.bias is not None else self.name != AttentionType.normalized

params = 0

# Block attention Q projection.
params += d_model * d_model
if bias:
params += d_model

# Block attention KV projections.
params += 2 * d_model * n_kv_heads * head_dim
if bias:
params += 2 * n_kv_heads * head_dim

# Block attention QK norm.
if self.qk_norm is not None:
params += 2 * self.qk_norm.num_params(d_model)

# Block attention out.
params += d_model * d_model
if bias:
params += d_model

# Block QK scaling factors.
if self.name == AttentionType.normalized:
head_dim = d_model // n_heads
params += n_heads * head_dim
params += n_kv_heads * head_dim

return params

def build(
self,
d_model: int,
@@ -90,7 +125,9 @@ def build(
else:
raise NotImplementedError(self.name)
except TypeError as e:
raise OLMoConfigurationError(f"invalid options for '{self.name}', {e}") from e
raise OLMoConfigurationError(
f"invalid options for '{self.name}' {self.__class__.__name__}, {e}"
) from e


class Attention(nn.Module):
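
As a sanity check of the new AttentionConfig.num_params accounting, the bias-free case reduces to the Q, KV, and output projections. A sketch with made-up sizes, not part of this diff:

# Hand-check of AttentionConfig.num_params with hypothetical sizes (no bias, no QK norm, default attention).
d_model, n_heads, n_kv_heads = 4096, 32, 8
head_dim = d_model // n_heads                  # 128
q_proj = d_model * d_model                     # 16,777,216
kv_proj = 2 * d_model * n_kv_heads * head_dim  # 8,388,608
out_proj = d_model * d_model                   # 16,777,216
total = q_proj + kv_proj + out_proj            # 41,943,040
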
19 changes: 18 additions & 1 deletion src/olmo_core/nn/feed_forward.py
@@ -42,6 +42,21 @@ class FeedForwardConfig(Config):
bias: Optional[bool] = None
dtype: DType = DType.float32

def num_params(self, d_model: int) -> int:
bias = self.bias if self.bias is not None else self.name != FeedForwardType.normalized

params = 0

params += 3 * d_model * self.hidden_size
if bias:
params += 2 * self.hidden_size + d_model

# w1 + w3 scaling factors
if self.name == FeedForwardType.normalized:
params += 2 * self.hidden_size

return params

def build(self, d_model: int, init_device: str = "cpu") -> "FeedForward":
kwargs = self.as_dict(exclude_none=True)
kwargs.pop("name")
@@ -55,7 +70,9 @@ def build(self, d_model: int, init_device: str = "cpu") -> "FeedForward":
else:
raise NotImplementedError(self.name)
except TypeError as e:
raise OLMoConfigurationError(f"invalid options for '{self.name}', {e}") from e
raise OLMoConfigurationError(
f"invalid options for '{self.name}' {self.__class__.__name__}, {e}"
) from e


class FeedForward(nn.Module):
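
The feed-forward count follows the same pattern. A minimal hand-check with hypothetical sizes, assuming the three matrices are the gated w1/w2/w3 of a SwiGLU-style FeedForward (an assumption, not spelled out in this diff):

# Hand-check of FeedForwardConfig.num_params with hypothetical sizes.
d_model, hidden_size = 4096, 11008
weights = 3 * d_model * hidden_size  # 135,266,304
biases = 2 * hidden_size + d_model   # 26,112, only counted when bias is enabled
scales = 2 * hidden_size             # extra w1/w3 scaling factors for the normalized variant
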
18 changes: 17 additions & 1 deletion src/olmo_core/nn/layer_norm.py
@@ -47,6 +47,20 @@ class LayerNormConfig(Config):
full_precision: Optional[bool] = None
dtype: Optional[DType] = None

def num_params(self, size: int) -> int:
elementwise_affine = (
self.elementwise_affine
if self.elementwise_affine is not None
else self.name != LayerNormType.l2_norm
)
bias = self.bias if self.bias is not None else self.name != LayerNormType.l2_norm
ln_params = 0
if elementwise_affine:
ln_params += size
if bias:
ln_params += size
return ln_params

def build(self, size: int, init_device: str = "cpu") -> "LayerNorm":
"""
Construct the corresponding LayerNorm class.
@@ -70,7 +84,9 @@ def build(self, size: int, init_device: str = "cpu") -> "LayerNorm":
else:
raise NotImplementedError(self.name)
except TypeError as e:
raise OLMoConfigurationError(f"invalid options for '{self.name}', {e}") from e
raise OLMoConfigurationError(
f"invalid options for '{self.name}' {self.__class__.__name__}, {e}"
) from e


class LayerNorm(nn.Module):
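
A sketch of what the new LayerNormConfig.num_params should report, assuming the import path matches the file layout and the remaining config fields keep the None defaults shown above (not part of this diff):

from olmo_core.nn.layer_norm import LayerNormConfig, LayerNormType

# L2 norm infers the affine weight and bias off, so it contributes nothing;
# forcing both back on contributes a weight vector plus a bias vector.
assert LayerNormConfig(name=LayerNormType.l2_norm).num_params(4096) == 0
assert LayerNormConfig(name=LayerNormType.l2_norm, elementwise_affine=True, bias=True).num_params(4096) == 2 * 4096
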
21 changes: 20 additions & 1 deletion src/olmo_core/nn/lm_head.py
@@ -40,6 +40,23 @@ class LMHeadConfig(Config):
bias: Optional[bool] = None
dtype: DType = DType.float32

def num_params(self, d_model: int, vocab_size: int) -> int:
bias = self.bias if self.bias is not None else self.name != LMHeadType.normalized

params = 0
if self.layer_norm is not None:
params += self.layer_norm.num_params(d_model)

params += d_model * vocab_size
if bias:
params += vocab_size

# Final scaling factor.
if self.name == LMHeadType.normalized:
params += vocab_size

return params

def build(self, *, d_model: int, vocab_size: int, init_device: str = "cpu") -> "LMHead":
kwargs = self.as_dict(exclude_none=True, recurse=False)
kwargs.pop("name")
@@ -58,7 +75,9 @@ def build(self, *, d_model: int, vocab_size: int, init_device: str = "cpu") -> "
else:
raise NotImplementedError(self.name)
except TypeError as e:
raise OLMoConfigurationError(f"invalid options for '{self.name}', {e}") from e
raise OLMoConfigurationError(
f"invalid options for '{self.name}' {self.__class__.__name__}, {e}"
) from e


class LMHead(nn.Module):
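
The LM head count is the final norm plus the output projection. A quick hand-check with hypothetical sizes, assuming a weight-only norm and no bias (not part of this diff):

# Hand-check of LMHeadConfig.num_params with hypothetical sizes.
d_model, vocab_size = 4096, 50257
norm = d_model                 # 4,096
logits = d_model * vocab_size  # 205,852,672
total = norm + logits          # 205,856,768, plus vocab_size more for bias or the normalized variant
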
27 changes: 16 additions & 11 deletions src/olmo_core/nn/rope.py
@@ -7,6 +7,7 @@
import torch.nn as nn

from ..config import Config, StrEnum
from ..exceptions import OLMoConfigurationError
from .buffer_cache import BufferCache

__all__ = [
@@ -95,17 +96,21 @@ def build(
"""
kwargs = self.as_dict(exclude_none=True, recurse=False)
kwargs.pop("name")
kwargs["head_shape"] = head_shape
kwargs["cache"] = cache

if self.name == "default":
return RotaryEmbedding(**kwargs)
elif self.name == "fused":
return FusedRotaryEmbedding(**kwargs)
elif self.name == "complex":
return ComplexRotaryEmbedding(**kwargs)
else:
raise NotImplementedError(self.name)
kwargs.update(head_shape=head_shape, cache=cache)

try:
if self.name == "default":
return RotaryEmbedding(**kwargs)
elif self.name == "fused":
return FusedRotaryEmbedding(**kwargs)
elif self.name == "complex":
return ComplexRotaryEmbedding(**kwargs)
else:
raise NotImplementedError(self.name)
except TypeError as e:
raise OLMoConfigurationError(
f"invalid options for '{self.name}' {self.__class__.__name__}, {e}"
) from e


class RotaryEmbeddingBase(nn.Module):
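
The RoPE builder now follows the same error-wrapping convention as the other config classes. A minimal standalone sketch of the pattern using hypothetical stand-in names (nothing below is from olmo_core):

class OLMoConfigurationError(Exception):  # stand-in for olmo_core.exceptions.OLMoConfigurationError
    pass

def _demo_module(d_model: int):  # stand-in for a module constructor such as RotaryEmbedding
    return d_model

class DemoConfig:
    name = "demo"

    def build(self, **kwargs):
        try:
            return _demo_module(**kwargs)
        except TypeError as e:
            raise OLMoConfigurationError(
                f"invalid options for '{self.name}' {self.__class__.__name__}, {e}"
            ) from e

# DemoConfig().build(d_model=64) succeeds, while DemoConfig().build(foo=1) now raises
# OLMoConfigurationError ("invalid options for 'demo' DemoConfig, ...") instead of a bare TypeError.
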
4 changes: 3 additions & 1 deletion src/olmo_core/nn/transformer/block.py
@@ -110,7 +110,9 @@ def build(
else:
raise NotImplementedError(self.name)
except TypeError as e:
raise OLMoConfigurationError(f"invalid options for '{self.name}', {e}") from e
raise OLMoConfigurationError(
f"invalid options for '{self.name}' {self.__class__.__name__}, {e}"
) from e


class TransformerBlockBase(nn.Module):
73 changes: 10 additions & 63 deletions src/olmo_core/nn/transformer/config.py
@@ -228,92 +228,39 @@ def num_params(self) -> int:
The total number of parameters that a model from this config would have.
"""

def layer_norm_params(layer_norm: LayerNormConfig) -> int:
ln_params = 0
if layer_norm.elementwise_affine:
ln_params += self.d_model
if layer_norm.bias:
ln_params += self.d_model
return ln_params

num_params = 0

# Embedding params.
num_params += self.d_model * self.vocab_size

block_params = 0

n_heads = self.block.attention.n_heads
n_kv_heads = self.block.attention.n_kv_heads or n_heads
head_dim = self.d_model // n_heads

# Block attn and MLP scaling factors.
if self.block.name == TransformerBlockType.normalized:
block_params += 2 * self.d_model

# Block attention Q projection.
block_params += self.d_model * self.d_model
if self.block.attention.bias:
block_params += self.d_model

# Block attention KV projections.
block_params += 2 * self.d_model * n_kv_heads * head_dim
if self.block.attention.bias:
block_params += 2 * n_kv_heads * head_dim

# Block attention QK norm.
if self.block.attention.qk_norm is not None:
block_params += 2 * layer_norm_params(self.block.attention.qk_norm)

# Block attention out.
block_params += self.d_model * self.d_model
if self.block.attention.bias:
block_params += self.d_model
# Block attention params.
block_params += self.block.attention.num_params(self.d_model)

# Block attention norm.
if self.block.layer_norm is not None:
block_params += layer_norm_params(self.block.layer_norm)

# Block QK scaling factors.
if self.block.attention.name == AttentionType.normalized:
head_dim = self.d_model // self.block.attention.n_heads
block_params += self.block.attention.n_heads * head_dim
block_params += (
self.block.attention.n_kv_heads or self.block.attention.n_heads
) * head_dim
block_params += self.block.layer_norm.num_params(self.d_model)

# Block feed forward.
if "moe" not in self.block.name:
assert self.block.feed_forward is not None
block_params += 3 * self.d_model * self.block.feed_forward.hidden_size
if self.block.feed_forward.bias:
block_params += 2 * self.block.feed_forward.hidden_size + self.d_model
# w1 + w3 scaling factors
if self.block.feed_forward.name == FeedForwardType.normalized:
block_params += 2 * self.block.feed_forward.hidden_size
else:
assert self.block.feed_forward_moe is not None
if self.block.feed_forward is not None:
block_params += self.block.feed_forward.num_params(self.d_model)
elif self.block.feed_forward_moe is not None:
block_params += self.block.feed_forward_moe.num_params(self.d_model)

# Block feed forward norm.
if self.block.layer_norm is not None:
block_params += layer_norm_params(self.block.layer_norm)
block_params += self.block.layer_norm.num_params(self.d_model)

# All block params.
num_params += self.n_layers * block_params

# Final layer norm.
if self.lm_head.layer_norm is not None:
num_params += layer_norm_params(self.lm_head.layer_norm)

# Final FF out.
num_params += self.d_model * self.vocab_size
if self.lm_head.bias:
num_params += self.vocab_size

# Final scaling factor.
if self.name == TransformerType.normalized:
num_params += self.vocab_size
# LM head.
num_params += self.lm_head.num_params(self.d_model, self.vocab_size)

return num_params

@@ -706,7 +653,7 @@ def ngpt_like(
dtype=dtype,
),
feed_forward=FeedForwardConfig(
name=FeedForwardType.normalized, hidden_size=hidden_size, bias=False, dtype=dtype
name=FeedForwardType.normalized, hidden_size=hidden_size, dtype=dtype
),
)

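With the refactor, the transformer total is just the embedding plus n_layers copies of the per-block count plus the LM head, each delegated to its sub-config. A rough hand-check with made-up sizes, assuming untied embeddings, no biases, weight-only norms, and no grouped-query attention (not part of this diff):

# Rough hand-check of the refactored total with hypothetical sizes.
d_model, vocab_size, n_layers, hidden_size = 1024, 50257, 16, 2816
attn = 4 * d_model * d_model              # Q, K, V, and out projections
ff = 3 * d_model * hidden_size            # SwiGLU weights
norms = 2 * d_model                       # attention norm + feed-forward norm
block = attn + ff + norms
embedding = d_model * vocab_size
lm_head = d_model + d_model * vocab_size  # final norm weight + output projection
total = embedding + n_layers * block + lm_head  # 308,481,024 for these numbers
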
5 changes: 5 additions & 0 deletions src/test/nn/transformer/model_test.py
@@ -1,3 +1,5 @@
import logging

import pytest
import torch
import torch.nn as nn
@@ -7,6 +9,8 @@

from ...utils import GPU_MARKS

log = logging.getLogger(__name__)


@pytest.mark.parametrize(
"init_device, device",
@@ -17,6 +21,7 @@
)
def test_small_llama2_config_builder(init_device, device):
config = TransformerConfig.llama2_271M(vocab_size=50257)
log.info(config)
model = config.build(init_device=init_device, device=torch.device(device))

# Make sure num params estimate is correct.
