fix model builders
epwalsh committed Nov 27, 2024
1 parent bd28f43 · commit 3ce2f25
Showing 5 changed files with 17 additions and 12 deletions.
src/olmo_core/nn/transformer/config.py: 21 changes (13 additions & 8 deletions)
@@ -293,7 +293,7 @@ def num_flops_per_token(self, seq_len: int) -> int:
         return flop_per_token

     @classmethod
-    def olmo_190M(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
+    def olmo2_190M(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
         return cls.llama_like(
             d_model=768,
             hidden_size_multiplier=1.5,
@@ -304,10 +304,11 @@ def olmo_190M(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
             qk_norm=kwargs.pop("qk_norm", True),
             rope_theta=kwargs.pop("rope_theta", 500_000),
             layer_norm_eps=1e-6,
+            **kwargs,
         )

     @classmethod
-    def olmo_370M(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
+    def olmo2_370M(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
         return cls.llama_like(
             d_model=1024,
             hidden_size_multiplier=1.4,
@@ -318,10 +318,11 @@ def olmo_370M(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
             qk_norm=kwargs.pop("qk_norm", True),
             rope_theta=kwargs.pop("rope_theta", 500_000),
             layer_norm_eps=1e-6,
+            **kwargs,
         )

     @classmethod
-    def olmo_600M(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
+    def olmo2_600M(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
         return cls.llama_like(
             d_model=1344,
             hidden_size_multiplier=1.5,
@@ -332,10 +334,11 @@ def olmo_600M(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
             qk_norm=kwargs.pop("qk_norm", True),
             rope_theta=kwargs.pop("rope_theta", 500_000),
             layer_norm_eps=1e-6,
+            **kwargs,
         )

     @classmethod
-    def olmo_760M(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
+    def olmo2_760M(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
         return cls.llama_like(
             d_model=1536,
             hidden_size_multiplier=1.5,
@@ -346,10 +349,11 @@ def olmo_760M(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
             qk_norm=kwargs.pop("qk_norm", True),
             rope_theta=kwargs.pop("rope_theta", 500_000),
             layer_norm_eps=1e-6,
+            **kwargs,
         )

     @classmethod
-    def olmo_1B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
+    def olmo2_1B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
         """
         A 1B OLMo model config.
         """
@@ -363,7 +367,7 @@ def olmo_1B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
         )

     @classmethod
-    def olmo_3B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
+    def olmo2_3B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
         return cls.llama_like(
             d_model=3328,
             hidden_size_multiplier=1.4,
@@ -374,10 +378,11 @@ def olmo_3B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
             qk_norm=kwargs.pop("qk_norm", True),
             rope_theta=kwargs.pop("rope_theta", 500_000),
             layer_norm_eps=1e-6,
+            **kwargs,
         )

     @classmethod
-    def olmo_7B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
+    def olmo2_7B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
         """
         A 7B OLMo model config.
         """
@@ -391,7 +396,7 @@ def olmo_7B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
         )

     @classmethod
-    def olmo_13B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
+    def olmo2_13B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
         """
         A 13B OLMo model config.
         """
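The config.py changes above follow one pattern: each builder is renamed from olmo_* to olmo2_*, and the sizes that were not already doing so now forward the leftover **kwargs to cls.llama_like, so caller-supplied overrides are no longer silently dropped. The snippet below is a minimal, self-contained sketch of that pop-with-default plus kwargs-forwarding idiom; make_config and llama_like_stub are illustrative stand-ins rather than olmo_core APIs, and the returned dict just keeps the example runnable without the library.

# Illustrative sketch only: make_config stands in for a builder such as
# TransformerConfig.olmo2_190M, and llama_like_stub for cls.llama_like.
def llama_like_stub(d_model: int, vocab_size: int, qk_norm: bool, rope_theta: int, **extra) -> dict:
    # The real factory returns a TransformerConfig; a plain dict keeps this self-contained.
    return {"d_model": d_model, "vocab_size": vocab_size,
            "qk_norm": qk_norm, "rope_theta": rope_theta, **extra}

def make_config(vocab_size: int, **kwargs) -> dict:
    return llama_like_stub(
        d_model=768,
        vocab_size=vocab_size,
        qk_norm=kwargs.pop("qk_norm", True),           # overridable default
        rope_theta=kwargs.pop("rope_theta", 500_000),  # overridable default
        **kwargs,  # the fix: remaining options (e.g. n_layers) now reach the factory
    )

assert make_config(50_304, n_layers=12)["n_layers"] == 12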
src/scripts/train/OLMo2-13B.py: 2 changes (1 addition & 1 deletion)
@@ -16,7 +16,7 @@


 def build_model_config(common: CommonComponents) -> TransformerConfig:
-    return TransformerConfig.olmo_13B(
+    return TransformerConfig.olmo2_13B(
         vocab_size=common.tokenizer.padded_vocab_size(),
         compile=True,
         dp_config=TransformerDataParallelConfig(
src/scripts/train/OLMo2-1B.py: 2 changes (1 addition & 1 deletion)
@@ -12,7 +12,7 @@


 def build_model_config(common: CommonComponents) -> TransformerConfig:
-    return TransformerConfig.olmo_1B(
+    return TransformerConfig.olmo2_1B(
         vocab_size=common.tokenizer.padded_vocab_size(),
         compile=True,
         dp_config=TransformerDataParallelConfig(
src/scripts/train/OLMo2-7B.py: 2 changes (1 addition & 1 deletion)
@@ -20,7 +20,7 @@


 def build_model_config(common: CommonComponents) -> TransformerConfig:
-    return TransformerConfig.olmo_7B(
+    return TransformerConfig.olmo2_7B(
         vocab_size=common.tokenizer.padded_vocab_size(),
         compile=True,
         dp_config=TransformerDataParallelConfig(
src/scripts/train/OLMoE-1B-7B.py: 2 changes (1 addition & 1 deletion)
@@ -24,7 +24,7 @@


 def build_model_config(common: CommonComponents) -> TransformerConfig:
-    model_config = TransformerConfig.olmo_1B(
+    model_config = TransformerConfig.olmo2_1B(
         vocab_size=common.tokenizer.padded_vocab_size(),
         n_layers=16,
         n_heads=16,
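Together with the script updates above, a call site after this commit looks roughly like the sketch below. The import path is assumed from the file location src/olmo_core/nn/transformer/config.py, and the literal vocab size is a placeholder; the actual scripts pass common.tokenizer.padded_vocab_size(), as shown in the diffs.

# Assumed import path; verify against the installed olmo_core version.
from olmo_core.nn.transformer import TransformerConfig

config = TransformerConfig.olmo2_1B(
    vocab_size=100_352,  # placeholder; real scripts use common.tokenizer.padded_vocab_size()
    n_layers=16,         # extra kwargs like these are forwarded, matching the OLMoE script
    n_heads=16,
)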
