diff --git a/src/scripts/train/OLMo2-ladder.py b/src/scripts/train/OLMo2-ladder.py index f099bd07..92d891ee 100644 --- a/src/scripts/train/OLMo2-ladder.py +++ b/src/scripts/train/OLMo2-ladder.py @@ -36,7 +36,7 @@ class BaselineModelLadder(ModelLadder): } def get_model_config(self, *, size: ModelSize) -> TransformerConfig: - return getattr(TransformerConfig, f"olmo_{size}")( + return getattr(TransformerConfig, f"olmo2_{size}")( vocab_size=self.tokenizer.padded_vocab_size(), init_seed=self.init_seed, compile=True,