diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/model/model.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/model.py similarity index 100% rename from sub-packages/bionemo-amplify/src/bionemo/amplify/model/model.py rename to sub-packages/bionemo-amplify/src/bionemo/amplify/model.py diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_TODO_bionemo_amplify.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/tokenizer.py similarity index 62% rename from sub-packages/bionemo-amplify/tests/bionemo/amplify/test_TODO_bionemo_amplify.py rename to sub-packages/bionemo-amplify/src/bionemo/amplify/tokenizer.py index b81a5c378e..95cb691c51 100644 --- a/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_TODO_bionemo_amplify.py +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/tokenizer.py @@ -14,11 +14,12 @@ # limitations under the License. -from pytest import fixture, mark, raises +import transformers +from nemo.lightning.io import IOMixin -def test_todo() -> None: - raise ValueError( - f"Implement tests! 
Make use of {fixture} for data, {raises} to check for " - f"exceptional cases, and {mark} as needed" - ) +class BioNeMoAMPLIFYTokenizer(transformers.PreTrainedTokenizerFast, IOMixin): # noqa D101 + def __init__(self): + """A wrapper to make AutoTokenizer serializable for the AMPLIFY tokenizer.""" + other = transformers.AutoTokenizer.from_pretrained("chandar-lab/AMPLIFY_350M", use_fast=True) + self.__dict__.update(dict(other.__dict__)) diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/model/__init__.py b/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_tokenizer.py similarity index 51% rename from sub-packages/bionemo-amplify/src/bionemo/amplify/model/__init__.py rename to sub-packages/bionemo-amplify/tests/bionemo/amplify/test_tokenizer.py index 25e6abfbc5..6815aa6a36 100644 --- a/sub-packages/bionemo-amplify/src/bionemo/amplify/model/__init__.py +++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_tokenizer.py @@ -12,3 +12,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + +import pytest +import torch +from nemo.lightning import io + +from bionemo.amplify.tokenizer import BioNeMoAMPLIFYTokenizer + + +@pytest.fixture +def tokenizer(): + return BioNeMoAMPLIFYTokenizer() + + +def test_tokenizer_serialization(tokenizer, tmp_path): + tokenizer.io_dump(tmp_path / "tokenizer", yaml_attrs=[]) # BioNeMoAMPLIFYTokenizer takes no __init__ arguments + deserialized_tokenizer = io.load(tmp_path / "tokenizer", tokenizer.__class__) + + our_tokens = deserialized_tokenizer.encode("K A I S Q", add_special_tokens=False) + esm_tokens = torch.tensor([15, 5, 32, 12, 8, 16]) + torch.testing.assert_close(torch.tensor(our_tokens), esm_tokens)