diff --git a/docs/source/examples/huggingface.rst b/docs/source/examples/huggingface.rst
index 68d88446..01baa543 100644
--- a/docs/source/examples/huggingface.rst
+++ b/docs/source/examples/huggingface.rst
@@ -7,7 +7,7 @@ One way to do this would be to manually apply a data parallel wrapper (like DDP
 Instead we recommend converting your HuggingFace checkpoint into a format that can be loaded into an equivalent OLMo-core :class:`~olmo_core.nn.transformer.Transformer` model, when possible, using the functions provided by :mod:`olmo_core.distributed.checkpoint`.
 
-Below is an example that shows how to convert a Llama-3.2 checkpoint on HuggingFace into the right format for OLMo-core.
+Below is an example that shows how to convert an OLMo2 or Llama-3 checkpoint on HuggingFace into the right format for OLMo-core.
 It would be straight forward to adapt this script to convert in the other direction as well.
 
 .. seealso::
diff --git a/src/examples/huggingface/convert_checkpoint.py b/src/examples/huggingface/convert_checkpoint.py
index 34ed07e1..6559ee43 100644
--- a/src/examples/huggingface/convert_checkpoint.py
+++ b/src/examples/huggingface/convert_checkpoint.py
@@ -1,8 +1,8 @@
 """
-Example script showing how you could convert model weights on HuggingFace for a Llama-3.2 model
-into a format that can be loaded by OLMo-core for fine-tuning.
+Example script showing how you could convert model weights on HuggingFace for an OLMo2 or Llama-3.*
+model into a format that can be loaded by OLMo-core for fine-tuning.
 
-Note that this script is architecture-dependent, meaning it may only work for Llama-3.2 models on
+Note that this script is architecture-dependent, meaning it may only work for OLMo2/Llama models on
 HuggingFace.
 """
@@ -20,14 +20,38 @@
 
 log = logging.getLogger(__name__)
 
-HF_MODEL = "meta-llama/Llama-3.2-1B"
+HF_MODEL = "allenai/OLMo-2-1124-7B"
+# HF_MODEL = "allenai/OLMo-2-1124-7B-Instruct"
+# HF_MODEL = "allenai/OLMo-2-1124-13B-Instruct"
+# HF_MODEL = "meta-llama/Llama-3.2-1B"
+# HF_MODEL = "allenai/OLMo-2-1124-13B"
+
 SAVE_PATH = f"/tmp/checkpoints/{HF_MODEL}"
 SAVE_OVERWRITE = False
 
 TOKENIZER_CONFIG = TokenizerConfig.from_hf(HF_MODEL)
-MODEL_CONFIG = TransformerConfig.llama3_1B(
-    TOKENIZER_CONFIG.vocab_size, fused_ops=False, use_flash=False, rope_scaling=RoPEScalingConfig()
-)
+MODEL_CONFIG: TransformerConfig
+if HF_MODEL == "meta-llama/Llama-3.2-1B":
+    MODEL_CONFIG = TransformerConfig.llama3_1B(
+        TOKENIZER_CONFIG.vocab_size,
+        fused_ops=False,
+        use_flash=False,
+        rope_scaling=RoPEScalingConfig(),
+    )
+elif HF_MODEL.startswith("allenai/OLMo-2-1124-7B"):
+    MODEL_CONFIG = TransformerConfig.olmo2_7B(
+        TOKENIZER_CONFIG.vocab_size,
+        fused_ops=False,
+        use_flash=False,
+    )
+elif HF_MODEL.startswith("allenai/OLMo-2-1124-13B"):
+    MODEL_CONFIG = TransformerConfig.olmo2_13B(
+        TOKENIZER_CONFIG.vocab_size,
+        fused_ops=False,
+        use_flash=False,
+    )
+else:
+    raise NotImplementedError(HF_MODEL)
 
 
 def convert_checkpoint() -> AutoModelForCausalLM:
@@ -78,15 +102,27 @@ def convert_checkpoint() -> AutoModelForCausalLM:
             f"model.layers.{block}.mlp.up_proj.weight"
         )
 
-        # Attention layer norm.
-        new_state_dict[f"blocks.{block}.attention_norm.weight"] = state_dict.pop(
-            f"model.layers.{block}.input_layernorm.weight"
-        )
-
-        # MLP layer norm.
-        new_state_dict[f"blocks.{block}.feed_forward_norm.weight"] = state_dict.pop(
-            f"model.layers.{block}.post_attention_layernorm.weight"
-        )
+        # Layer norms.
+        if "Llama" in HF_MODEL:
+            new_state_dict[f"blocks.{block}.feed_forward_norm.weight"] = state_dict.pop(
+                f"model.layers.{block}.post_attention_layernorm.weight"
+            )
+            new_state_dict[f"blocks.{block}.attention_norm.weight"] = state_dict.pop(
+                f"model.layers.{block}.input_layernorm.weight"
+            )
+        else:
+            new_state_dict[f"blocks.{block}.attention_norm.weight"] = state_dict.pop(
+                f"model.layers.{block}.post_attention_layernorm.weight"
+            )
+            new_state_dict[f"blocks.{block}.feed_forward_norm.weight"] = state_dict.pop(
+                f"model.layers.{block}.post_feedforward_layernorm.weight"
+            )
+            new_state_dict[f"blocks.{block}.attention.q_norm.weight"] = state_dict.pop(
+                f"model.layers.{block}.self_attn.q_norm.weight"
+            )
+            new_state_dict[f"blocks.{block}.attention.k_norm.weight"] = state_dict.pop(
+                f"model.layers.{block}.self_attn.k_norm.weight"
+            )
 
         assert len(state_dict) == 0
@@ -97,22 +133,26 @@ def validate_conversion(hf_model):
-    log.info("Loading converted checkpoint for validation...")
-
     device = get_default_device()
-    model = MODEL_CONFIG.build(device=device, max_seq_len=131072).eval()
-    load_model_and_optim_state(SAVE_PATH, model)
+    B, T = 1, 120
+    input_ids = torch.randint(0, TOKENIZER_CONFIG.vocab_size, (B, T)).to(device)
 
     hf_model = hf_model.to(device).eval()
+    with torch.no_grad():
+        hf_logits, *_ = hf_model(input_ids=input_ids, return_dict=False)
 
-    B, T = 1, 120
-    input_ids = torch.randint(0, TOKENIZER_CONFIG.vocab_size, (B, T)).to(device)
+    del hf_model
+
+    model = MODEL_CONFIG.build(device=device, max_seq_len=131072).eval()
+
+    log.info("Loading converted checkpoint for validation...")
+    load_model_and_optim_state(SAVE_PATH, model)
 
     with torch.no_grad():
         logits = model(input_ids=input_ids)
-        hf_logits, *_ = hf_model(input_ids=input_ids, return_dict=False)
-        torch.testing.assert_close(hf_logits, logits)
+
+    torch.testing.assert_close(hf_logits, logits)
 
     log.info("Conversion successful")
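For readers adapting this example, the snippet below is a minimal sketch (not part of the patch) of how the converted checkpoint might be consumed on the OLMo-core side afterwards. It reuses ``MODEL_CONFIG``, ``TOKENIZER_CONFIG``, and ``SAVE_PATH`` from the script above; the ``from convert_checkpoint import ...`` line and the ``olmo_core.utils.get_default_device`` import path are assumptions for illustration, not something this patch establishes.

.. code-block:: python

    # Minimal usage sketch (assumptions noted above): load the converted checkpoint
    # back into an OLMo-core Transformer and run a tiny smoke test.
    import torch

    from olmo_core.distributed.checkpoint import load_model_and_optim_state
    from olmo_core.utils import get_default_device  # assumed import path

    # Hypothetical import of the constants defined in convert_checkpoint.py above.
    from convert_checkpoint import MODEL_CONFIG, SAVE_PATH, TOKENIZER_CONFIG

    device = get_default_device()

    # Build the same OLMo-core model config that was used during conversion.
    model = MODEL_CONFIG.build(device=device, max_seq_len=131072).eval()

    # Load the converted weights written by convert_checkpoint().
    load_model_and_optim_state(SAVE_PATH, model)

    # Smoke test: a small random batch should produce logits without error.
    input_ids = torch.randint(0, TOKENIZER_CONFIG.vocab_size, (1, 8), device=device)
    with torch.no_grad():
        logits = model(input_ids=input_ids)
    print(tuple(logits.shape))

From here the model can be treated like any other OLMo-core ``Transformer``, since the checkpoint is now in the native :mod:`olmo_core.distributed.checkpoint` format described in the documentation change above.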