From 257e918ed6ec9c85d4ada04493f52230b90d5e41 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Wed, 22 Jan 2025 12:15:41 -0700 Subject: [PATCH] ESM-2 to NeMo checkpoint conversion (#537) Adds a conversion script to convert Hugging Face ESM-2 checkpoints to the NeMo2 checkpoint format --------- Signed-off-by: Peter St. John --- .../src/bionemo/core/data/resources/esm2.yaml | 70 ++--- .../src/bionemo/esm2/model/convert.py | 179 +++++++++++++ .../src/bionemo/esm2/testing/__init__.py | 14 + .../src/bionemo/esm2/testing/compare.py | 99 +++++++ .../tests/bionemo/esm2/model/test_convert.py | 54 ++++ .../tests/bionemo/esm2/model/test_model.py | 249 +++++++----------- .../src/bionemo/llm/model/config.py | 2 + .../testing/megatron_parallel_state_utils.py | 30 ++- 8 files changed, 497 insertions(+), 200 deletions(-) create mode 100644 sub-packages/bionemo-esm2/src/bionemo/esm2/model/convert.py create mode 100644 sub-packages/bionemo-esm2/src/bionemo/esm2/testing/__init__.py create mode 100644 sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py create mode 100644 sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml b/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml index d7749aa783..ddc5033b3e 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml +++ b/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml @@ -7,23 +7,32 @@ description: > A pretrained 650M parameter ESM2 model. See https://ngc.nvidia.com/catalog/models/nvidia:clara:esm2nv650m. -- tag: nv_3b:2.1 - ngc: "nvidia/clara/esm2nv3b:2.1" +- tag: 8m:2.0 + ngc: nvidia/clara/esm2nv8m:2.0 ngc_registry: model - pbss: "s3://general-purpose/esm2/checkpoints/3b/esm2_3b_checkpoint.tar.gz" - sha256: a79327a4054bf8d1d7075e1b3c961dbc503da02d72ed15f707d9cbbd49d181b6 # pragma: allowlist secret + pbss: s3://general-purpose/esm2/checkpoints/converted/8m/esm2_hf_converted_8m_checkpoint.tar.gz + sha256: 2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16 # pragma: allowlist secret owner: Peter St John description: > - An ESM-2 3B model pre-trained on NVIDIA's train/test data split. + A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t6_8M_UR50D model. -- tag: nv_650m:2.1 - ngc: "nvidia/clara/esm2nv650m:2.1" +- tag: 650m:2.0 + ngc: nvidia/clara/esm2nv650m:2.0 ngc_registry: model - pbss: "s3://general-purpose/esm2/checkpoints/650m/esm2_650m_checkpoint.tar.gz" - sha256: b83e9b5d62f1499b443817c5cd0facd3bdd4013a51a897e05e17228bf650befe # pragma: allowlist secret - owner: Peter St John + pbss: "s3://bionemo-ci/models/esm2_650M_nemo2.tar.gz" + sha256: 0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997 # pragma: allowlist secret + owner: Farhad Ramezanghorbani description: > - An ESM-2 650M model pre-trained on NVIDIA's train/test data split. + A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t33_650M_UR50D model. + +- tag: 3b:2.0 + ngc: nvidia/clara/esm2nv3b:2.0 + ngc_registry: model + pbss: "s3://bionemo-ci/models/esm2_3B_nemo2.tar.gz" + sha256: a2248cfed1ef39f83bd32a0e08b84c0a8f39325d383e2d92767022ff7f5260ed # pragma: allowlist secret + owner: Farhad Ramezanghorbani + description: > + A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t36_3B_UR50D model. # - tag: nv_8m:2.1 # ngc: "nvidia/clara/esm2nv8m:2.1" @@ -34,38 +43,29 @@ # description: > # An ESM-2 8M model pre-trained on NVIDIA's train/test data split.
-- tag: 8m:2.0 - ngc: "nvidia/clara/esm2nv8m:2.0" +- tag: nv_650m:2.1 + ngc: "nvidia/clara/esm2nv650m:2.1" ngc_registry: model - pbss: "s3://general-purpose/esm2/checkpoints/converted/8m/esm2_hf_converted_8m_checkpoint.tar.gz" - sha256: 2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16 # pragma: allowlist secret + pbss: "s3://general-purpose/esm2/checkpoints/650m/esm2_650m_checkpoint.tar.gz" + sha256: b83e9b5d62f1499b443817c5cd0facd3bdd4013a51a897e05e17228bf650befe # pragma: allowlist secret owner: Peter St John description: > - The original 8M parameter ESM2 model weights converted to the NeMo2 checkpoint format. - -- tag: 650m:2.0 - ngc: nvidia/clara/esm2nv650m:2.0 - ngc_registry: model - pbss: "s3://bionemo-ci/models/esm2_650M_nemo2.tar.gz" - sha256: 0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997 # pragma: allowlist secret - owner: Farhad Ramezanghorbani - description: > - The original 650M parameter ESM2 model weights converted to the NeMo2 checkpoint format. + An ESM-2 650M model pre-trained on NVIDIA's train/test data split. -- tag: 3b:2.0 - ngc: nvidia/clara/esm2nv3b:2.0 +- tag: nv_3b:2.1 + ngc: "nvidia/clara/esm2nv3b:2.1" ngc_registry: model - pbss: "s3://bionemo-ci/models/esm2_3B_nemo2.tar.gz" - sha256: a2248cfed1ef39f83bd32a0e08b84c0a8f39325d383e2d92767022ff7f5260ed # pragma: allowlist secret - owner: Farhad Ramezanghorbani + pbss: "s3://general-purpose/esm2/checkpoints/3b/esm2_3b_checkpoint.tar.gz" + sha256: a79327a4054bf8d1d7075e1b3c961dbc503da02d72ed15f707d9cbbd49d181b6 # pragma: allowlist secret + owner: Peter St John description: > - The original 3B parameter ESM2 model c converted to the NeMo2 checkpoint format. + An ESM-2 3B model pre-trained on NVIDIA's train/test data split. - tag: fulldata_esm2_pretrain:2.0 ngc: nvidia/clara/esm2_pretrain_nemo2_data:1.0 ngc_registry: resource pbss: "s3://general-purpose/esm2/pretrain/2024_03.tar.gz" - sha256: 404d0ad8de58fa8aae96f8d9f54263a088bc7e4f7d668215afbe04c28416151b # pragma: allowlist secret + sha256: 404d0ad8de58fa8aae96f8d9f54263a088bc7e4f7d668215afbe04c28416151b # pragma: allowlist secret owner: Peter St John description: Full data for ESM2 pretraining. @@ -73,14 +73,14 @@ ngc: nvidia/clara/esm2_pretrain_nemo2_testdata:1.0 ngc_registry: resource pbss: "s3://general-purpose/esm2/pretrain/2024_03_sanity.tar.gz" - sha256: 006911f92bbc0ded7ea302bbdbfab4c694b409e699c32fd49de1c527a99dba3e # pragma: allowlist secret + sha256: 006911f92bbc0ded7ea302bbdbfab4c694b409e699c32fd49de1c527a99dba3e # pragma: allowlist secret owner: Peter St John description: Test data for ESM2 pretraining. - tag: esm2_inference_testdata:2.0 - ngc: nvidia/clara/esm2_inference_testdata:2.0 # TODO: upload to NGC + ngc: nvidia/clara/esm2_inference_testdata:2.0 # TODO: upload to NGC ngc_registry: resource pbss: "s3://bionemo-ci/test_data/esm2/artificial_protein_sequences.csv" - sha256: 14ae3acfbf82218bc9e3e53d21a5b0594ba7c0369e169c9f1034e3fe4378d175 # pragma: allowlist secret + sha256: 14ae3acfbf82218bc9e3e53d21a5b0594ba7c0369e169c9f1034e3fe4378d175 # pragma: allowlist secret owner: Farhad Ramezanghorbani description: Test data for ESM2 inference. diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/convert.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/convert.py new file mode 100644 index 0000000000..06be1fa0a1 --- /dev/null +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/convert.py @@ -0,0 +1,179 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from pathlib import Path + +import torch +from nemo.lightning import io, teardown +from nemo.lightning.pytorch.utils import dtype_from_hf +from transformers import AutoConfig as HFAutoConfig +from transformers import AutoModelForMaskedLM + +from bionemo.esm2.data.tokenizer import BioNeMoESMTokenizer, get_tokenizer +from bionemo.esm2.model.model import ESM2Config +from bionemo.llm.lightning import BionemoLightningModule +from bionemo.llm.model.biobert.lightning import biobert_lightning_module + + +@io.model_importer(BionemoLightningModule, "hf") +class HFESM2Importer(io.ModelConnector[AutoModelForMaskedLM, BionemoLightningModule]): + """Converts a Hugging Face ESM-2 model to a NeMo ESM-2 model.""" + + def init(self) -> BionemoLightningModule: + """Initialize the converted model.""" + return biobert_lightning_module(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + """Applies the transformation. + + Largely inspired by + https://docs.nvidia.com/nemo-framework/user-guide/latest/nemo-2.0/features/hf-integration.html + """ + source = AutoModelForMaskedLM.from_pretrained(str(self), trust_remote_code=True, torch_dtype="auto") + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + print(f"Converted ESM-2 model to Nemo, model saved to {output_path}") + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source, target): + """Converting HF state dict to NeMo state dict.""" + mapping = { + # "esm.encoder.layer.0.attention.self.rotary_embeddings.inv_freq": "rotary_pos_emb.inv_freq", + "esm.encoder.layer.*.attention.output.dense.weight": "encoder.layers.*.self_attention.linear_proj.weight", + "esm.encoder.layer.*.attention.output.dense.bias": "encoder.layers.*.self_attention.linear_proj.bias", + "esm.encoder.layer.*.attention.LayerNorm.weight": "encoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "esm.encoder.layer.*.attention.LayerNorm.bias": "encoder.layers.*.self_attention.linear_qkv.layer_norm_bias", + "esm.encoder.layer.*.intermediate.dense.weight": "encoder.layers.*.mlp.linear_fc1.weight", + "esm.encoder.layer.*.intermediate.dense.bias": "encoder.layers.*.mlp.linear_fc1.bias", + "esm.encoder.layer.*.output.dense.weight": "encoder.layers.*.mlp.linear_fc2.weight", + "esm.encoder.layer.*.output.dense.bias": "encoder.layers.*.mlp.linear_fc2.bias", + "esm.encoder.layer.*.LayerNorm.weight": "encoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "esm.encoder.layer.*.LayerNorm.bias": "encoder.layers.*.mlp.linear_fc1.layer_norm_bias", + "esm.encoder.emb_layer_norm_after.weight": "encoder.final_layernorm.weight", + "esm.encoder.emb_layer_norm_after.bias": "encoder.final_layernorm.bias", + "lm_head.dense.weight": "lm_head.dense.weight", + "lm_head.dense.bias": "lm_head.dense.bias", + "lm_head.layer_norm.weight": 
"lm_head.layer_norm.weight", + "lm_head.layer_norm.bias": "lm_head.layer_norm.bias", + } + + # lm_head.bias + return io.apply_transforms( + source, + target, + mapping=mapping, + transforms=[_pad_embeddings, _pad_bias, _import_qkv_weight, _import_qkv_bias], + ) + + @property + def tokenizer(self) -> BioNeMoESMTokenizer: + """We just have the one tokenizer for ESM-2.""" + return get_tokenizer() + + @property + def config(self) -> ESM2Config: + """Returns the transformed ESM-2 config given the model tag.""" + source = HFAutoConfig.from_pretrained(str(self), trust_remote_code=True) + output = ESM2Config( + num_layers=source.num_hidden_layers, + hidden_size=source.hidden_size, + ffn_hidden_size=source.intermediate_size, + position_embedding_type="rope", + num_attention_heads=source.num_attention_heads, + seq_length=source.max_position_embeddings, + fp16=(dtype_from_hf(source) == torch.float16), + bf16=(dtype_from_hf(source) == torch.bfloat16), + params_dtype=dtype_from_hf(source), + ) + + return output + + +@io.state_transform( + source_key="esm.embeddings.word_embeddings.weight", + target_key="embedding.word_embeddings.weight", +) +def _pad_embeddings(ctx: io.TransformCTX, source_embed): + """Pad the embedding layer to the new input dimension.""" + nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by + hf_embedding_dimension = source_embed.size(0) + num_padding_rows = nemo_embedding_dimension - hf_embedding_dimension + padding_rows = torch.zeros(num_padding_rows, source_embed.size(1)) + return torch.cat((source_embed, padding_rows), dim=0) + + +@io.state_transform( + source_key="lm_head.bias", + target_key="output_layer.bias", +) +def _pad_bias(ctx: io.TransformCTX, source_bias): + """Pad the embedding layer to the new input dimension.""" + nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by + hf_embedding_dimension = source_bias.size(0) + output_bias = torch.zeros(nemo_embedding_dimension, dtype=source_bias.dtype, device=source_bias.device) + output_bias[:hf_embedding_dimension] = source_bias + return output_bias + + +@io.state_transform( + source_key=( + "esm.encoder.layer.*.attention.self.query.weight", + "esm.encoder.layer.*.attention.self.key.weight", + "esm.encoder.layer.*.attention.self.value.weight", + ), + target_key="encoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_qkv_weight(ctx: io.TransformCTX, query, key, value): + """Pad the embedding layer to the new input dimension.""" + concat_weights = torch.cat((query, key, value), dim=0) + input_shape = concat_weights.size() + np = ctx.target.config.num_attention_heads + # transpose weights + # [sequence length, batch size, num_splits_model_parallel * attention head size * #attention heads] + # --> [sequence length, batch size, attention head size * num_splits_model_parallel * #attention heads] + concat_weights = concat_weights.view(3, np, -1, query.size()[-1]) + concat_weights = concat_weights.transpose(0, 1).contiguous() + concat_weights = concat_weights.view(*input_shape) + return concat_weights + + +@io.state_transform( + source_key=( + "esm.encoder.layer.*.attention.self.query.bias", + "esm.encoder.layer.*.attention.self.key.bias", + "esm.encoder.layer.*.attention.self.value.bias", + ), + target_key="encoder.layers.*.self_attention.linear_qkv.bias", +) +def _import_qkv_bias(ctx: io.TransformCTX, query, key, value): + """Pad the embedding layer to the new input dimension.""" + concat_biases = torch.cat((query, key, value), dim=0) + input_shape = concat_biases.size() + np = 
ctx.target.config.num_attention_heads + # transpose biases + # [num_splits_model_parallel * attention head size * #attention heads] + # --> [attention head size * num_splits_model_parallel * #attention heads] + concat_biases = concat_biases.view(3, np, -1) + concat_biases = concat_biases.transpose(0, 1).contiguous() + concat_biases = concat_biases.view(*input_shape) + return concat_biases diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/__init__.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/__init__.py new file mode 100644 index 0000000000..25e6abfbc5 --- /dev/null +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py new file mode 100644 index 0000000000..e8690c1d04 --- /dev/null +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import gc +from pathlib import Path + +import torch +from megatron.core.transformer.module import Float16Module +from transformers import AutoModelForMaskedLM + +from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype +from bionemo.esm2.data.tokenizer import get_tokenizer +from bionemo.esm2.model.model import ESM2Config + + +def assert_model_equivalence( + ckpt_path: Path | str, + model_tag: str, + precision: PrecisionTypes = "fp32", + rtol: float | None = None, + atol: float | None = None, +) -> None: + """Testing utility to compare the outputs of a NeMo2 checkpoint to the original HuggingFace model weights. + + Compares the cosine similarity of the logit and hidden state outputs of a NeMo2 model checkpoint to the outputs of + the corresponding HuggingFace model. + + Args: + ckpt_path: A path to a NeMo2 checkpoint for an ESM-2 model. + model_tag: The HuggingFace model tag for the model to compare against. + precision: The precision type to use for the comparison. Defaults to "fp32". + rtol: The relative tolerance to use for the comparison. Defaults to None, which chooses the tolerance based on + the precision. 
+ atol: The absolute tolerance to use for the comparison. Defaults to None, which chooses the tolerance based on + the precision. + """ + tokenizer = get_tokenizer() + + test_proteins = [ + "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA", + "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLAGG", + ] + tokens = tokenizer(test_proteins, return_tensors="pt", padding=True, truncation=True).to("cuda") + input_ids = tokens["input_ids"] + attention_mask = tokens["attention_mask"] + + dtype = get_autocast_dtype(precision) + nemo_config = ESM2Config( + initial_ckpt_path=str(ckpt_path), + include_embeddings=True, + include_hiddens=True, + params_dtype=dtype, + pipeline_dtype=dtype, + autocast_dtype=dtype, + bf16=dtype is torch.bfloat16, + fp16=dtype is torch.float16, + ) + + nemo_model = nemo_config.configure_model(tokenizer).to("cuda").eval() + + if dtype is torch.float16 or dtype is torch.bfloat16: + nemo_model = Float16Module(nemo_config, nemo_model) + + nemo_output = nemo_model(input_ids, attention_mask) + nemo_logits = nemo_output["token_logits"].transpose(0, 1).contiguous()[..., : tokenizer.vocab_size] + nemo_hidden_state = nemo_output["hidden_states"] + + del nemo_model + gc.collect() + torch.cuda.empty_cache() + + hf_model = AutoModelForMaskedLM.from_pretrained(model_tag, torch_dtype=get_autocast_dtype(precision)).cuda().eval() + hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True) + hf_hidden_state = hf_output_all.hidden_states[-1] + + # Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. These + # should be essentially 1 if the outputs are equivalent, but is less sensitive to small numerical differences. + # We don't care about the padding tokens, so we only compare the non-padding tokens. + logit_similarity = torch.nn.functional.cosine_similarity(nemo_logits, hf_output_all.logits, dim=2) + logit_similarity = logit_similarity[attention_mask == 1] + + hidden_state_similarity = torch.nn.functional.cosine_similarity(nemo_hidden_state, hf_hidden_state, dim=2) + hidden_state_similarity = hidden_state_similarity[attention_mask == 1] + + torch.testing.assert_close(logit_similarity, torch.ones_like(logit_similarity), rtol=rtol, atol=atol) + torch.testing.assert_close(hidden_state_similarity, torch.ones_like(hidden_state_similarity), rtol=rtol, atol=atol) diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py new file mode 100644 index 0000000000..f3d6d2e691 --- /dev/null +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pytest +from nemo.lightning import io + +from bionemo.esm2.model.convert import HFESM2Importer # noqa: F401 +from bionemo.esm2.model.model import ESM2Config +from bionemo.esm2.testing.compare import assert_model_equivalence +from bionemo.llm.model.biobert.lightning import biobert_lightning_module +from bionemo.testing import megatron_parallel_state_utils + + +# pytestmark = pytest.mark.xfail( +# reason="These tests are failing due to a bug in nemo global state when run in the same process as previous " +# "checkpoint save/load scripts." +# ) + + +def test_nemo2_conversion_equivalent_8m(tmp_path): + model_tag = "facebook/esm2_t6_8M_UR50D" + module = biobert_lightning_module(config=ESM2Config()) + io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint") + with megatron_parallel_state_utils.distributed_model_parallel_state(): + assert_model_equivalence(tmp_path / "nemo_checkpoint", model_tag) + + +def test_nemo2_conversion_equivalent_8m_bf16(tmp_path): + model_tag = "facebook/esm2_t6_8M_UR50D" + module = biobert_lightning_module(config=ESM2Config()) + io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint") + with megatron_parallel_state_utils.distributed_model_parallel_state(precision="bf16"): + assert_model_equivalence(tmp_path / "nemo_checkpoint", model_tag, precision="bf16") + + +@pytest.mark.slow +def test_nemo2_conversion_equivalent_650m(tmp_path): + model_tag = "facebook/esm2_t33_650M_UR50D" + module = biobert_lightning_module(config=ESM2Config()) + io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint") + assert_model_equivalence(tmp_path / "nemo_checkpoint", model_tag) diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py index 7d0d20b46b..8895b3719a 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py @@ -13,18 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import gc import io import tarfile -from copy import deepcopy -from typing import List, Tuple from unittest import mock import pytest import torch -from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from torch import Tensor -from transformers import EsmForMaskedLM +from transformers import AutoModelForMaskedLM from bionemo.core.data.load import load from bionemo.core.utils.dtypes import get_autocast_dtype @@ -33,110 +29,24 @@ from bionemo.esm2.data.datamodule import ESMDataModule from bionemo.esm2.data.tokenizer import get_tokenizer from bionemo.esm2.model.embedding import ESM2Embedding +from bionemo.esm2.testing.compare import assert_model_equivalence from bionemo.llm.model.biobert.model import MegatronBioBertModel from bionemo.llm.utils.weight_utils import nemo1_to_nemo2_biobert_key_mapping from bionemo.testing import megatron_parallel_state_utils -def reduce_hiddens(hiddens: Tensor, attention_mask: Tensor) -> Tensor: - """reduce last layer's hidden values to embeddings - - Args: - hiddens: [b, s, h] tensor of hidden values - attention_mask: [b, s] attention mask tensor - - Returns: - reduced embedding tensor [b, h] - """ - masks = torch.sum(attention_mask, dim=1) - embeddings = torch.zeros( - size=(hiddens.shape[0], hiddens.shape[2]), - dtype=torch.float32, - device=torch.cuda.current_device(), - ) - for i, (hidden, mask) in enumerate(zip(hiddens, masks)): - embeddings[i, :] = torch.mean(hidden[1 : mask - 1], dim=0) - return embeddings - - -@pytest.fixture(scope="module") -def esm2_config() -> ESM2Config: - with megatron_parallel_state_utils.distributed_model_parallel_state(): - yield ESM2Config() - - -@pytest.fixture(scope="module") -def esm2_650M_config_w_ckpt() -> ESM2Config: - with megatron_parallel_state_utils.distributed_model_parallel_state(): - yield ESM2Config(initial_ckpt_path=load("esm2/650m:2.0")) - - -@pytest.fixture(scope="module") -def esm2_model(esm2_config) -> ESM2Model: +def test_esm2_model_initialized(): with megatron_parallel_state_utils.distributed_model_parallel_state(): tokenizer = get_tokenizer() - model = esm2_config.configure_model(tokenizer) - yield model - - -@pytest.fixture(scope="module") -def sample_data() -> List[Tuple[str, str]]: - """Generates sample protein sequences for sanity checks, including mask tokens.""" - max_length = 1022 # The maximum length of the protein sequences to be considered. 
- sample_data = [ - ( - "protein1", - "MNGTEGPNFYVPFSNATGVVRSPFEYPQYYLAEPWQFSMLAAYMFLLIVLGFPINFLTLYVTVQHKKLRTPLNYILLNLAVADLFMVLGGFTSTLYTSLHGYFVFGPTGCNLEGFFATLGGEIALWSLVVLAIERYVVVCKPMSNFRFGENHAIMGVAFTWVMALACAAPPLAGWSRYIPEGLQCSCGIDYYTLKPEVNNESFVIYMFVVHFTIPMIIIFFCYGQLVFTVKEAAAQQQESATTQKAEKEVTRMVIIMVIAFLICWVPYASVAFYIFTHQGSNFGPIFMTIPAFFAKSAAIYNPVIYIMMNKQFRNCMLTTICCGKNPLGDDEASATVSKTETSQVAPA", - ), - ("protein2", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA"), - ( - "protein3", - "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLAGG", - ), - ( - "protein4", - "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLA", - ), - ] - # add another sample protein that uses the maximum length to test this edge case - sample_data.append(("protein5", (sample_data[0][1] * 3)[:max_length])) - yield sample_data + config = ESM2Config() + model = config.configure_model(tokenizer) + assert isinstance(model, MegatronBioBertModel) + assert isinstance(model, ESM2Model) + assert isinstance(model.embedding, ESM2Embedding) -def _compute_loss(model, dataloader, vocab_size=None): - loss = 0 - n = 0 - limit_batches = 10 - for i, batch in enumerate(dataloader): - assert isinstance(batch, dict) - result = model(input_ids=batch["text"].cuda(), attention_mask=batch["attention_mask"].cuda()) - # bionemo ESM2 vocab_size - if vocab_size is not None: - # token_logits is s,b and for simplicity here let's transpose to b,s. In general this reduces performance. - logits = result["token_logits"].transpose(0, 1).contiguous()[..., :vocab_size] - else: - logits = result.logits - - loss_mask = batch["loss_mask"].cuda() - target = batch["labels"].cuda() - - loss += torch.nn.functional.cross_entropy(logits[loss_mask].float(), target[loss_mask], reduction="sum") - n += loss_mask.sum() - - if limit_batches is not None and i + 1 >= limit_batches: - break - mean_loss: Tensor = loss / n - return mean_loss - - -def test_esm2_model_initialized(esm2_model): - assert isinstance(esm2_model, MegatronBioBertModel) - assert isinstance(esm2_model, ESM2Model) - assert isinstance(esm2_model.embedding, ESM2Embedding) - - -def test_esm2_650m_checkpoint(esm2_model): +def test_esm2_nemo1_checkpoint(): with tarfile.open(load("esm2/nv_650m:1.0"), "r") as ckpt, torch.no_grad(): ckpt_file = ckpt.extractfile("./model_weights.ckpt") @@ -145,10 +55,14 @@ def test_esm2_650m_checkpoint(esm2_model): # TODO: update Bionemo checkpoints old_state_dict.pop("model.language_model.rotary_pos_emb.inv_freq") - new_state_dict = esm2_model.state_dict_for_save_checkpoint() + with megatron_parallel_state_utils.distributed_model_parallel_state(): + tokenizer = get_tokenizer() + config = ESM2Config() + model = config.configure_model(tokenizer) + new_state_dict = model.state_dict_for_save_checkpoint() - # Set the new_model_prefix to "" since we are looking at the base megatron model and not the lightning module which stores a copy of - # this model into self.module + # Set the new_model_prefix to "" since we are looking at the base megatron model and not the lightning module + # which stores a copy of this model into self.module old_keys = { nemo1_to_nemo2_biobert_key_mapping(k, new_model_prefix="", te_mapping=True) for k in old_state_dict } @@ -176,56 +90,41 @@ def test_esm2_650m_checkpoint(esm2_model): assert not missing_old_keys, "There are keys in the old checkpoint that are missing from the new model." 
-def test_esm2_golden_values(esm2_650M_config_w_ckpt, sample_data): - tokenizer = AutoTokenizer(pretrained_model_name="facebook/esm2_t33_650M_UR50D") - tokens = tokenizer.tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True).to("cuda") - input_ids = tokens["input_ids"] - attention_mask = tokens["attention_mask"] - - # HF 650M model - hf_model = EsmForMaskedLM.from_pretrained( - "facebook/esm2_t33_650M_UR50D", torch_dtype=get_autocast_dtype(32) - ).cuda() - - with torch.no_grad(): - hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True) - hf_logits = hf_output_all.logits * attention_mask.unsqueeze(-1) - hf_embeddings = reduce_hiddens(hf_output_all.hidden_states[-1], attention_mask) - - # free GPU RAM - del hf_model - gc.collect() - torch.cuda.empty_cache() - - # configure the model to return logits - model = esm2_650M_config_w_ckpt.configure_model(get_tokenizer()).cuda() - model.eval() - result = model(input_ids, attention_mask) - # token_logits is s,b and for simplicity here let's transpose to b,s. In general this reduces performance. - logits = result["token_logits"].transpose(0, 1).contiguous()[..., : tokenizer.vocab_size] - logits = logits * attention_mask.unsqueeze(-1) # incorporate masking logic - - # free GPU RAM - del model - gc.collect() - torch.cuda.empty_cache() - - # configure the model to return hiddens - esm2_650M_config_hiddens = deepcopy(esm2_650M_config_w_ckpt) - esm2_650M_config_hiddens.set_hparam("return_only_hidden_states", True) - model = esm2_650M_config_hiddens.configure_model(get_tokenizer()).cuda() - model.eval() - hiddens = model(input_ids, attention_mask) - embeddings = reduce_hiddens(torch.transpose(hiddens, 0, 1).float(), attention_mask) - - torch.testing.assert_close(logits, hf_logits, atol=0.2, rtol=0.0) - torch.testing.assert_close(embeddings, hf_embeddings, atol=5e-3, rtol=0.0) - - -def test_esm2_loss(esm2_650M_config_w_ckpt, dummy_protein_dataset, dummy_parquet_train_val_inputs): +def _compute_loss(model, dataloader, vocab_size=None): + loss = 0 + n = 0 + limit_batches = 10 + for i, batch in enumerate(dataloader): + assert isinstance(batch, dict) + result = model(input_ids=batch["text"].cuda(), attention_mask=batch["attention_mask"].cuda()) + + # bionemo ESM2 vocab_size + if vocab_size is not None: + # token_logits is s,b and for simplicity here let's transpose to b,s. In general this reduces performance. 
+ logits = result["token_logits"].transpose(0, 1).contiguous()[..., :vocab_size] + else: + logits = result.logits + + loss_mask = batch["loss_mask"].cuda() + target = batch["labels"].cuda() + + loss += torch.nn.functional.cross_entropy(logits[loss_mask].float(), target[loss_mask], reduction="sum") + n += loss_mask.sum() + + if limit_batches is not None and i + 1 >= limit_batches: + break + mean_loss: Tensor = loss / n + return mean_loss + + +def test_esm2_loss(dummy_protein_dataset, dummy_parquet_train_val_inputs): + hf_model_tag = "facebook/esm2_t6_8M_UR50D" + nv_model_tag = "esm2/8m:2.0" + # hf_model_tag = "facebook/esm2_t33_650M_UR50D" + # nv_model_tag = "esm2/650m:2.0" + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs - compute_hf_reference: bool = True seed: int = 42 with ( @@ -235,8 +134,8 @@ def test_esm2_loss(esm2_650M_config_w_ckpt, dummy_protein_dataset, dummy_parquet ): tokenizer = get_tokenizer() - # ESM2 model initialized with 650M params - model = esm2_650M_config_w_ckpt.configure_model(tokenizer).cuda() + # ESM2 model initialized with params + model = ESM2Config(initial_ckpt_path=str(load(nv_model_tag))).configure_model(tokenizer).cuda() # Initialize the data module. data_module = ESMDataModule( @@ -268,14 +167,42 @@ def test_esm2_loss(esm2_650M_config_w_ckpt, dummy_protein_dataset, dummy_parquet mean_loss = _compute_loss(model, train_dataloader, vocab_size=tokenizer.vocab_size) - if compute_hf_reference: - # HF model initialized with 650M params - hf_model = EsmForMaskedLM.from_pretrained( - "facebook/esm2_t33_650M_UR50D", torch_dtype=get_autocast_dtype(32) - ).cuda() - hf_mean_loss = _compute_loss(hf_model, train_dataloader) - print(f"hf_mean_loss: {hf_mean_loss}") - else: - hf_mean_loss = torch.tensor(2.9279041290283203).cuda() + # HF model initialized with params + hf_model = AutoModelForMaskedLM.from_pretrained(hf_model_tag, torch_dtype=get_autocast_dtype(32)).cuda() + hf_mean_loss = _compute_loss(hf_model, train_dataloader) + print(f"hf_mean_loss: {hf_mean_loss}") torch.testing.assert_close(mean_loss, hf_mean_loss, atol=1e-3, rtol=0.0) + + +@pytest.mark.parametrize("precision", ["fp32", "bf16", "fp16", "bf16-mixed"]) +def test_model_equivalence_with_huggingface_8m(precision): + model_tag = "facebook/esm2_t6_8M_UR50D" + ckpt_path = load("esm2/8m:2.0") + with megatron_parallel_state_utils.distributed_model_parallel_state(precision=precision): + assert_model_equivalence(ckpt_path, model_tag, precision=precision) + + +@pytest.mark.slow +def test_model_equivalence_with_huggingface_650m(): + model_tag = "facebook/esm2_t33_650M_UR50D" + ckpt_path = load("esm2/650m:2.0") + with megatron_parallel_state_utils.distributed_model_parallel_state(): + assert_model_equivalence(ckpt_path, model_tag, atol=1e-4, rtol=1e-4) + + +@pytest.mark.slow +def test_model_equivalence_with_huggingface_650m_bf16(): + model_tag = "facebook/esm2_t33_650M_UR50D" + ckpt_path = load("esm2/650m:2.0") + with megatron_parallel_state_utils.distributed_model_parallel_state(precision="bf16"): + assert_model_equivalence(ckpt_path, model_tag, precision="bf16") + + +@pytest.mark.slow +@pytest.mark.skip(reason="This test triggers a large download from huggingface and requires considerable GPU memory.") +def test_model_equivalence_with_huggingface_3b(): + model_tag = "facebook/esm2_t36_3B_UR50D" + ckpt_path = load("esm2/3b:2.0") + with megatron_parallel_state_utils.distributed_model_parallel_state(): + assert_model_equivalence(ckpt_path, model_tag, atol=1e-4, rtol=1e-4) diff --git 
a/sub-packages/bionemo-llm/src/bionemo/llm/model/config.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/config.py index 425726da48..bf48a520c1 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/model/config.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/model/config.py @@ -45,6 +45,8 @@ "initial_ckpt_path_ignore_weights", "initial_ckpt_path", "model_cls", + "bf16", + "fp16", ] OVERRIDE_BIONEMO_CONFIG_DEFAULTS = deepcopy(_OVERRIDE_BIONEMO_CONFIG_DEFAULTS) # copy for export diff --git a/sub-packages/bionemo-testing/src/bionemo/testing/megatron_parallel_state_utils.py b/sub-packages/bionemo-testing/src/bionemo/testing/megatron_parallel_state_utils.py index 1686de309d..8ef0762239 100644 --- a/sub-packages/bionemo-testing/src/bionemo/testing/megatron_parallel_state_utils.py +++ b/sub-packages/bionemo-testing/src/bionemo/testing/megatron_parallel_state_utils.py @@ -42,9 +42,12 @@ def my_test(): import torch.distributed from megatron.core import parallel_state from megatron.core.tensor_parallel import random as tp_random +from nemo import lightning as nl from nemo.utils import logging from torch.testing._internal.distributed.fake_pg import FakeStore +from bionemo.core.utils.dtypes import PrecisionTypes + __all__: Sequence[str] = ( "clean_parallel_state_context", @@ -81,12 +84,24 @@ def _initialize_distributed_parallel_state( pipeline_model_parallel_split_rank: int = 0, context_parallel_size: int = 1, interactive: bool = False, -) -> None: + precision: PrecisionTypes = "fp32", +) -> pl.Trainer | None: + trainer = None # initialize pytorch DDP # if not interactive and not torch.distributed.is_initialized(): if not torch.distributed.is_initialized(): - logging.info("pytorch DDP is not initialized. Initializing with pytorch-lightening...") - trainer = pl.Trainer(devices=devices, strategy="ddp" if not interactive else "auto", num_nodes=1) + logging.info("pytorch DDP is not initialized. Initializing with pytorch-lightning...") + trainer = pl.Trainer( + devices=devices, + strategy="ddp" if not interactive else "auto", + num_nodes=1, + # plugins=nl.MegatronMixedPrecision( + # precision=precision, + # params_dtype=get_autocast_dtype(precision), + # pipeline_dtype=get_autocast_dtype(precision), + # autocast_enabled=False, + # ), + ) if trainer.strategy.launcher is not None: trainer.strategy.launcher.launch(_dummy, trainer=trainer) @@ -101,6 +116,8 @@ def _initialize_distributed_parallel_state( context_parallel_size=context_parallel_size, ) + return trainer + @contextmanager def clean_parallel_state_context() -> Iterator[None]: @@ -124,6 +141,7 @@ def distributed_model_parallel_state( pipeline_model_parallel_split_rank: int = 0, context_parallel_size: int = 1, interactive: bool = False, + precision: PrecisionTypes = "fp32", ) -> Iterator[None]: """Context manager for handling creating and cleaning up distributed model parallel state for tests. Use like: @@ -132,16 +150,18 @@ def distributed_model_parallel_state( # After the block your state is cleaned up. 
""" # noqa: D205 initial_states: Optional[Any] = None + trainer: pl.Trainer | None = None try: _teardown_apex_megatron_cuda() - _initialize_distributed_parallel_state( + trainer = _initialize_distributed_parallel_state( devices=devices, tensor_model_parallel_size=tensor_model_parallel_size, pipeline_model_parallel_size=pipeline_model_parallel_size, pipeline_model_parallel_split_rank=pipeline_model_parallel_split_rank, context_parallel_size=context_parallel_size, interactive=interactive, + precision=precision, ) # Our goal is to set required state on entry, and then restore current state on exit for the RNGs. # there are two possibilities that are handled below: @@ -174,6 +194,8 @@ def distributed_model_parallel_state( # Reset to the unset state tp_random.get_cuda_rng_tracker().reset() _teardown_apex_megatron_cuda() + if trainer is not None: + nl.teardown(trainer) @contextmanager