36 commits
db33a6a
add codonfm 5b arch params
balvisio Feb 13, 2026
73d5b55
add script to download data
balvisio Feb 24, 2026
ac2895d
updated data download
balvisio Feb 24, 2026
f35dd90
updated nptebook
balvisio Feb 25, 2026
276866b
added download scripts
balvisio Feb 25, 2026
d7c7bd1
Improve check_codon_frequency
balvisio Feb 25, 2026
edf20bb
fix memmap creator
balvisio Feb 25, 2026
617e82f
added new notebook
balvisio Feb 25, 2026
60c6422
updated nb
balvisio Feb 26, 2026
6ed39b7
updated notebook
balvisio Feb 26, 2026
ef722b6
fix small config bug
balvisio Feb 27, 2026
a3142e2
fix runner missing param
balvisio Feb 28, 2026
a661b61
added bendchmarking target
balvisio Mar 2, 2026
d252fac
fix build
balvisio Mar 3, 2026
b59b203
Fixed dockerfile
balvisio Mar 3, 2026
606a833
typo
balvisio Mar 3, 2026
e4734b6
Update table
balvisio Mar 3, 2026
a43c18a
Fix thd bug
balvisio Mar 7, 2026
6a62791
fix 2nd bug
balvisio Mar 7, 2026
4175073
add batch waste logs
balvisio Mar 7, 2026
fb0b9b5
add warmup to start logging
balvisio Mar 7, 2026
cf4aac5
add log
balvisio Mar 8, 2026
3615b04
Make benchmark build more efficient
balvisio Mar 8, 2026
6032717
new changes to notebook
balvisio Mar 9, 2026
7f1fdf5
fix bug in log on batch end
balvisio Mar 10, 2026
7b8bb9b
Add throughput logger
balvisio Mar 11, 2026
8d8ed00
support variable numbers of samples batches
balvisio Mar 12, 2026
9271280
support THD with stateful set
balvisio Mar 12, 2026
f4d0e64
add fix to setup
balvisio Mar 12, 2026
f5fc8f1
Refactor and fix shuffle edge case
balvisio Mar 13, 2026
0229ab4
add throughput hooks for validation
balvisio Mar 13, 2026
0d085eb
fix small parsing bug
balvisio Mar 13, 2026
3c87562
support thd in inference
balvisio Mar 14, 2026
54b6e08
add support for loading models from HF hub
balvisio Mar 14, 2026
8a2e59a
Add and fix documentation
balvisio Mar 16, 2026
a8bd5b5
Update benchmarks
balvisio Mar 16, 2026
15 changes: 15 additions & 0 deletions bionemo-recipes/recipes/codonfm_ptl_te/Dockerfile
@@ -64,3 +64,18 @@ RUN chown -R ${USERNAME:-vscode}:${USERNAME:-vscode} /workspace/codonfm

# Switch to the non-root user
USER $USERNAME

# ----------------- For benchmarking only -----------------
# Warning: this image could only be built on an instance with 2 TB of memory.
# Otherwise, a segmentation fault occurs during the build process:
# /bin/bash: line 1: 13517 Segmentation fault (core dumped) ptxas -arch=sm_90 -m64 -v --generate-line-info "/tmp/tmpxft_00002f58_00000000-6_flash_fwd_hdim64_256_fp16_paged_split_sm90.ptx" -o "/tmp/tmpxft_00002f58_00000000-8_flash_fwd_hdim64_256_fp16_paged_split_sm90.cubin" > /tmp/tmpxft_00002f58_00000000-10_2eb7d280_stdout 2> /tmp/tmpxft_00002f58_00000000-10_2eb7d280_stderr
#
# The failure may also have been caused by CUDA compatibility issues.

FROM base AS benchmarking

WORKDIR /workspace/codonfm

RUN pip install -v --no-build-isolation -U git+https://github.com/facebookresearch/xformers.git@v0.0.32.post2#egg=xformers --no-deps

COPY . .
33 changes: 29 additions & 4 deletions bionemo-recipes/recipes/codonfm_ptl_te/README.md
@@ -21,6 +21,8 @@ The table below summarizes the set of open source pre-trained weights currently
| EnCodon 600M | MLM (random p=0.15) | 2048 | 12 | 16 | 8192 | `mlm/encodon_600m.sh` | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-600M-v1) | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-TE-600M-v1) |
| EnCodon 1B | MLM (random p=0.15) | 2048 | 18 | 16 | 8192 | `mlm/encodon_1b.sh` | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-1B-v1) | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-TE-1B-v1) |
| EnCodon 1B (CDSWT) | MLM (codon frequency-weighted) | 2048 | 18 | 16 | 8192 | `cdswt/encodon_1b.sh` | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-Cdwt-1B-v1) | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-TE-Cdwt-1B-v1) |
| EnCodon 5B | MLM (codon p=0.15) | 4096 | 24 | 32 | 16384 | `mlm/encodon_5b.sh` | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-Cdwt-5B-v1) | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-TE-Cdwt-5B-v1) |
| EnCodon 5B (CDSWT) | MLM (codon frequency-weighted) | 4096 | 24 | 32 | 16384 | `cdswt/encodon_5b.sh` | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-Cdwt-5B-v1) | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-TE-Cdwt-5B-v1) |

## Repository Structure

@@ -75,12 +77,29 @@ We also present the ability to utilize a simpler model architecture that directly

<br>

The training step speedups for the 80M Encodon model, with both Transformer Engine (TE) and Sequence Packing (THD) applied, relative to the Xformers-based model are shown below. We benchmarked on NVIDIA H100 80GB HBM3 GPUs using a micro batch size of 32; the 1B Encodon model was benchmarked with a micro batch size of 4.
The figure below shows training throughput speedups, derived from `tokens/s/gpu`, for the `80M` and `1B` Encodon models when Transformer Engine (TE) and sequence packing (THD) are applied relative to the Xformers-based baseline.

![xf](assets/images/training_acceleration_plot.png)

For inference, we also demonstrate acceleration when using each model's TE counterpart. Thus, a 1.4X speedup in this chart shows how much faster the TE version of a model is than the original baseline PyTorch SDPA model.
![i](assets/images/inference_plot.png)
All training experiments reported here were run on `8 x NVIDIA H100 80GB HBM3` GPUs in `bfloat16` precision. The absolute throughputs used to compute the speedups above are reported below in `tokens/s/gpu`.

| Model | Xformers (`tokens/s/gpu`) | SDPA (`tokens/s/gpu`) | TE-BSHD (`tokens/s/gpu`) | TE-THD (`tokens/s/gpu`) | Speedup over baseline |
| ----- | ------------------------: | --------------------: | -----------------------: | ----------------------: | ----------------------------- |
| 80M | 117119 | 145357 | 419087 | 1028891 | 1.00x / 1.24x / 3.58x / 8.79x |
| 1B | 8698 | 9899 | 26476 | 69300 | 1.00x / 1.14x / 3.04x / 7.97x |
| 5B | 2320 | 2865 | 5112 | 13973 | 1.00x / 1.23x / 2.20x / 6.02x |
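Each `Speedup over baseline` entry is derived by dividing a configuration's throughput by the Xformers baseline and rounding to two decimals; a minimal sketch of that arithmetic (the helper name is illustrative):

```python
def speedups(throughputs):
    """Derive per-configuration speedups from absolute tokens/s/gpu values.

    The first entry is the Xformers baseline; every value is divided by it
    and rounded to two decimals, matching the table format.
    """
    baseline = throughputs[0]
    return [round(t / baseline, 2) for t in throughputs]


# 1B training row: Xformers, SDPA, TE-BSHD, TE-THD
print(speedups([8698, 9899, 26476, 69300]))  # -> [1.0, 1.14, 3.04, 7.97]
```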

For inference, we report both relative speedup and absolute throughput. The figure below compares inference configurations by relative speedup within each model size.

![Inference speedup across model sizes](assets/images/inference_plot.png)

All inference experiments reported here were run on `8 x NVIDIA H100 80GB HBM3` GPUs in `bfloat16` precision. The absolute throughputs used to compute the speedups above are reported below in `tokens/s/gpu`.

| Model | Xformers (`tokens/s/gpu`) | SDPA (`tokens/s/gpu`) | TE-BSHD (`tokens/s/gpu`) | TE-THD (`tokens/s/gpu`) | Speedup over baseline |
| ----- | ------------------------: | --------------------: | -----------------------: | ----------------------: | ------------------------------ |
| 80M | 156819 | 190380 | 542147 | 1875140 | 1.00x / 1.21x / 3.46x / 11.96x |
| 1B | 18655 | 21715 | 46551 | 221110 | 1.00x / 1.16x / 2.50x / 11.85x |
| 5B | 5316 | 5991 | 9996 | 40373 | 1.00x / 1.13x / 1.88x / 7.59x |

## Quickstart

@@ -185,9 +204,11 @@ Optional path overrides:
```bash
--out_dir <dir>
--checkpoints_dir <dir>
--pretrained_ckpt_path <path>
```

- `--out_dir`: Base output directory for logs, metrics, and other artifacts. Defaults to `results/`.
- `--checkpoints_dir`: Directory where training checkpoints are saved. Defaults to `<out_dir>/checkpoints/`. This directory also enables **automatic resumption**: if the runner finds a `last.ckpt` file inside this directory, it will reload the model weights and full trainer state (optimizer, learning-rate schedule, global step, etc.) so training picks up exactly where it left off. This is essential for long pretraining runs on clusters where jobs may be preempted or interrupted. On a fresh run the directory will be empty, so training starts from scratch as expected.

For multi-node execution consider using `torchrun`.

```bash
@@ -255,6 +276,10 @@ python -m src.runner finetune \

```

- `--pretrained_ckpt_path`: Path to a pretrained checkpoint whose **model weights only** are loaded as the starting point for finetuning. The optimizer state, learning-rate schedule, and global step are not restored — training starts fresh from step 0 with the pretrained weights. Accepts a local `.ckpt` file, a local directory containing a `.safetensors` file and `config.json`, or a HuggingFace Hub repo ID (e.g. `nvidia/codon-fm-base`).
- `--checkpoints_dir`: Directory where finetuning checkpoints are saved. Defaults to `<out_dir>/checkpoints/`. If the runner finds a `last.ckpt` here, it resumes the finetuning run (model weights, optimizer, step count) from that checkpoint instead of starting from the pretrained weights. This enables automatic resumption of interrupted finetuning jobs.
- `--resume_trainer_state`: When set, restores the full trainer state (optimizer, scheduler, step count) from the pretrained checkpoint rather than only loading model weights. Useful when continuing a pretraining run as a finetuning job.
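The three accepted forms of `--pretrained_ckpt_path` could be distinguished with a dispatch along these lines; the function name and return labels are illustrative, not the recipe's actual API:

```python
from pathlib import Path


def classify_pretrained_source(path_or_repo: str) -> str:
    """Classify a --pretrained_ckpt_path value into one of its accepted forms."""
    p = Path(path_or_repo)
    if p.is_file() and p.suffix == ".ckpt":
        return "local_ckpt"          # pickled Lightning checkpoint file
    if p.is_dir() and (p / "config.json").exists():
        return "local_safetensors"   # directory with .safetensors + config.json
    return "hf_hub_repo"             # e.g. "nvidia/codon-fm-base"
```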

#### Evaluation

The publicly available checkpoints can be used to launch scientific evaluation and benchmarking.
@@ -26,14 +26,39 @@

import argparse
import logging
import os

import torch
from safetensors.torch import save_file as safetensors_save_file

from src.utils.load_checkpoint import load_checkpoint


logger = logging.getLogger(__name__)

ALLOWED_HYPERPARAMETER_KEYS = (
"vocab_size",
"hidden_size",
"num_hidden_layers",
"num_attention_heads",
"intermediate_size",
"hidden_act",
"hidden_dropout_prob",
"attention_probs_dropout_prob",
"initializer_range",
"layer_norm_eps",
"pad_token_id",
"position_embedding_type",
"classifier_dropout",
"rotary_theta",
"ignore_index",
"loss_type",
"lora",
"lora_alpha",
"lora_r",
"lora_dropout",
)

# PyTorch -> TE keymap
PYTORCH_TO_TE_KEYMAP = {
"model.layers.*.pre_attn_layer_norm.weight": "model.layers.*.self_attention.layernorm_qkv.layer_norm_weight",
@@ -300,6 +325,11 @@ def convert_state_dict(src: dict, keymap: dict):
return dst_state_dict
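
Entries in `PYTORCH_TO_TE_KEYMAP` use `*` as a wildcard for the layer index. A standalone sketch of how such wildcard patterns might be applied — illustrative only, not the script's actual matching code:

```python
import re


def map_key(key: str, keymap: dict) -> str:
    """Translate a state-dict key using wildcard patterns like 'model.layers.*.x'."""
    for src_pat, dst_pat in keymap.items():
        # Escape the pattern, then let '*' match a numeric layer index.
        regex = "^" + re.escape(src_pat).replace(r"\*", r"(\d+)") + "$"
        m = re.match(regex, key)
        if m:
            # Substitute the captured layer index into the target pattern.
            return dst_pat.replace("*", m.group(1))
    return key  # keys without a mapping pass through unchanged


keymap = {
    "model.layers.*.pre_attn_layer_norm.weight":
        "model.layers.*.self_attention.layernorm_qkv.layer_norm_weight",
}
print(map_key("model.layers.3.pre_attn_layer_norm.weight", keymap))
# -> model.layers.3.self_attention.layernorm_qkv.layer_norm_weight
```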


def filter_hyper_parameters(hyper_parameters: dict) -> dict:
"""Keep only conversion-compatible hyperparameter keys."""
return {key: value for key, value in hyper_parameters.items() if key in ALLOWED_HYPERPARAMETER_KEYS}


def main():
"""Main function."""
logging.basicConfig(level=logging.INFO)
@@ -325,6 +355,7 @@
# Load source checkpoint (automatically detects format)
logger.info(f"Loading checkpoint from {args.src}")
src_checkpoint = load_checkpoint(args.src, map_location="cpu")
src_checkpoint["hyper_parameters"] = filter_hyper_parameters(src_checkpoint["hyper_parameters"])

# Perform conversion based on direction
if args.direction == "pytorch2te":
@@ -341,11 +372,19 @@
dst_state_dict = split_qkv(converted_state_dict, src_checkpoint["hyper_parameters"])

# Prepare final checkpoint
dst_checkpoint = {"state_dict": dst_state_dict, "hyper_parameters": src_checkpoint["hyper_parameters"]}
dst_checkpoint = {
"state_dict": dst_state_dict,
"hyper_parameters": src_checkpoint["hyper_parameters"],
}

# Save the converted checkpoint in pickled format
torch.save(dst_checkpoint, args.dst)
logger.info(f"Successfully converted checkpoint from {args.src} to {args.dst}")
logger.info(f"Successfully converted checkpoint saved to {args.dst}")

# Save the state_dict in safetensors format alongside the .ckpt file
safetensors_path = os.path.splitext(args.dst)[0] + ".safetensors"
safetensors_save_file(dst_state_dict, safetensors_path)
logger.info(f"Successfully saved safetensors checkpoint to {safetensors_path}")


if __name__ == "__main__":
@@ -15,6 +15,7 @@


# %%
import argparse
import json
import sys
from pathlib import Path
@@ -23,41 +24,52 @@
from tqdm import tqdm


sys.path.append("/workspace/codon_fm")
sys.path.append("/workspace/codonfm")
from src.tokenizer import Tokenizer


data_path = Path("/data/ncbi/processed_unfiltered")
tax_ids_to_remove = json.load(open("/data/ncbi/taxids_to_remove.json"))
metadata = json.load(open(data_path / "metadata.json"))
tokenizer = Tokenizer()


groups = set([x["file_name"][:-4] for x in metadata["file_metadata"]]) # noqa: C403
counts = {g: np.zeros(tokenizer.vocab_size) for g in groups}
for fm, cm in tqdm(zip(metadata["file_metadata"], metadata["chunks"]), total=len(metadata["file_metadata"])):
group = fm["file_name"][:-4]
if group in tax_ids_to_remove:
curr_taxids_to_remove = set(tax_ids_to_remove[group])
else:
curr_taxids_to_remove = set()
mmap = np.memmap(
data_path / cm["sequences"]["path"],
dtype=cm["sequences"]["dtype"],
mode="r",
shape=tuple(cm["sequences"]["shape"]),
)
idx_mmap = np.memmap(
data_path / cm["index"]["path"], dtype=cm["index"]["dtype"], mode="r", shape=tuple(cm["index"]["shape"])
)
for start, end, taxid in idx_mmap:
if taxid in curr_taxids_to_remove:
continue
seq = mmap[start:end]
idx, count = np.unique(seq, return_counts=True)
counts[group][idx] += count
def main(pretraining_processed_data_dir: Path, data_dir: Path):
"""Check codon frequency."""
tax_ids_to_remove = json.load(open(data_dir / Path("taxids_to_remove.json")))
metadata = json.load(open(pretraining_processed_data_dir / "metadata.json"))
tokenizer = Tokenizer()

# %%
for g in counts:
counts[g] = counts[g].tolist()
json.dump(counts, open("/data/ncbi/codon_counts_nopathogen.json", "w"))
groups = set([x["file_name"][:-4] for x in metadata["file_metadata"]]) # noqa: C403
counts = {g: np.zeros(tokenizer.vocab_size) for g in groups}
for fm, cm in tqdm(zip(metadata["file_metadata"], metadata["chunks"]), total=len(metadata["file_metadata"])):
group = fm["file_name"][:-4]
if group in tax_ids_to_remove:
curr_taxids_to_remove = set(tax_ids_to_remove[group])
else:
curr_taxids_to_remove = set()
mmap = np.memmap(
pretraining_processed_data_dir / cm["sequences"]["path"],
dtype=cm["sequences"]["dtype"],
mode="r",
shape=tuple(cm["sequences"]["shape"]),
)
idx_mmap = np.memmap(
pretraining_processed_data_dir / cm["index"]["path"],
dtype=cm["index"]["dtype"],
mode="r",
shape=tuple(cm["index"]["shape"]),
)
for start, end, taxid in idx_mmap:
if taxid in curr_taxids_to_remove:
continue
seq = mmap[start:end]
idx, count = np.unique(seq, return_counts=True)
counts[group][idx] += count

# %%
for g in counts:
counts[g] = counts[g].tolist()
json.dump(counts, open(data_dir / "codon_counts_nopathogen.json", "w"))


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Check codon frequency")
parser.add_argument("--pretraining_processed_data_dir", type=str, required=True)
parser.add_argument("--data_dir", type=str, required=True)
args = parser.parse_args()
main(Path(args.pretraining_processed_data_dir), Path(args.data_dir))
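
The counting loop above accumulates per-group token histograms with `np.unique(..., return_counts=True)`; a minimal standalone illustration with a toy vocabulary:

```python
import numpy as np

# Accumulate token counts for one sequence into a vocab-sized histogram,
# the same pattern the script applies to each memmapped sequence.
vocab_size = 8  # toy value; the real size comes from Tokenizer().vocab_size
counts = np.zeros(vocab_size)
seq = np.array([1, 3, 3, 5, 1, 1])
idx, count = np.unique(seq, return_counts=True)
counts[idx] += count
print(counts.tolist())  # -> [0.0, 3.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0]
```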
@@ -17,13 +17,16 @@
import argparse
import json
import os
import sys
from multiprocessing import Pool, cpu_count

import numpy as np
import polars as pl
import pyarrow.parquet as pq
from tqdm import tqdm


sys.path.append("/workspace/codonfm")
from src.tokenizer import Tokenizer

