Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
8682051
Updating dashboard and script logic
jwilber Mar 15, 2026
87921aa
fix vocab build to match logits.shape[1]
jwilber Mar 15, 2026
4753cb3
update 650m script
jwilber Mar 15, 2026
401ec4f
dataloader fix for multirank
jwilber Mar 16, 2026
bb5c792
update dashboard
jwilber Mar 16, 2026
f5da779
better f1 values and logging
jwilber Mar 16, 2026
abee3c1
add codonfm dir
jwilber Mar 16, 2026
01435bd
don't hardcode swissprot annotation score
jwilber Mar 17, 2026
02ab8f5
save checkpoint after save
jwilber Mar 17, 2026
dad85b3
add go enrichment eval for codonfm
jwilber Mar 17, 2026
9f77569
add esm2 sweep
jwilber Mar 17, 2026
74d2674
fix streaming on-complete bug
jwilber Mar 17, 2026
ef1e187
count parquet count
jwilber Mar 17, 2026
0ce258e
remove alphafold molstar logic from codonfm script
jwilber Mar 18, 2026
d9ce6d8
add go eval
jwilber Mar 18, 2026
81f670e
update training cmnd
jwilber Mar 18, 2026
6bdbffa
add gradient accumulation
jwilber Mar 18, 2026
a4f25fc
remove steering
jwilber Mar 18, 2026
0030a4f
remove steering modules
jwilber Mar 18, 2026
f2e255d
add grad acc args
jwilber Mar 18, 2026
584a600
update codonfm dashboard
jwilber Mar 18, 2026
349905e
revert change
jwilber Mar 18, 2026
7c55e96
update esm2 dashboard
jwilber Mar 18, 2026
ba5df03
remove steering mention for now
jwilber Mar 18, 2026
e384631
add obo-dir fix
jwilber Mar 19, 2026
97a41e9
lint
jwilber Mar 19, 2026
ef4fa1a
Merge branch 'main' into jwilber/esm2-sae-dashboard-update
jwilber Mar 19, 2026
d0847df
lint dashboard
jwilber Mar 19, 2026
d541c25
lint + precommit
jwilber Mar 19, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
.cache/
outputs/
wandb/

# Generated dashboard data
**/public/
**/dist/
**/feature_analysis.json
**/feature_labels.json
**/vocab_logits.json
**/node_modules/
dash/
dash2/
dash108k/
dash_438k/
dash_438k_auto/
dash_438k_clinvar/
dash_ef64/
gtc_dash/
nndash/
nndash2/
olddash/
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#!/bin/bash
#
# CodonFM Encodon-1B SAE Pipeline
#
# Runs the full SAE recipe against a local Encodon-1B checkpoint:
#   1. Extract layer activations from the model (multi-GPU via torchrun)
#   2. Train a TopK SAE on the cached activations
#   3. Analyze features (vocab logits + codon annotations)
#   4. Build the interactive dashboard data
#
# Edit the constants below to point at your checkpoint/CSV before running.

set -euo pipefail

# Pipeline constants — adjust paths and sizes for your environment.
readonly MODEL_PATH=checkpoints/NV-CodonFM-Encodon-TE-Cdwt-1B-v1/model.safetensors
readonly CSV_PATH=/data/jwilber/codonfm/data/sample_108k.csv
readonly LAYER=16
readonly NUM_SEQUENCES=10000
readonly OUTPUT_DIR=./outputs/1b_layer16

# Activation cache written by step 1 and consumed by step 2.
readonly CACHE_DIR=".cache/activations/primates_${NUM_SEQUENCES}_1b_layer${LAYER}"

echo "============================================================"
echo "STEP 1: Extract activations from Encodon-1B"
echo "============================================================"

torchrun --nproc_per_node=4 scripts/extract.py \
  --csv-path "$CSV_PATH" \
  --model-path "$MODEL_PATH" \
  --layer "$LAYER" \
  --num-sequences "$NUM_SEQUENCES" \
  --batch-size 8 \
  --context-length 2048 \
  --shard-size 100000 \
  --output "$CACHE_DIR"

echo ""
echo "============================================================"
echo "STEP 2: Train SAE on cached activations"
echo "============================================================"

torchrun --nproc_per_node=4 scripts/train.py \
  --cache-dir "$CACHE_DIR" \
  --model-path "$MODEL_PATH" \
  --layer "$LAYER" \
  --model-type topk \
  --expansion-factor 16 \
  --top-k 32 \
  --auxk 512 \
  --auxk-coef 0.03125 \
  --dead-tokens-threshold 500000 \
  --n-epochs 40 \
  --batch-size 4096 \
  --lr 3e-4 \
  --log-interval 50 \
  --dp-size 4 \
  --seed 42 \
  --wandb \
  --wandb-project sae_codonfm_recipe \
  --wandb-run-name "1b_layer${LAYER}_ef16_k32" \
  --output-dir "${OUTPUT_DIR}" \
  --checkpoint-dir "${OUTPUT_DIR}/checkpoints"

echo ""
echo "============================================================"
echo "STEP 3: Analyze features (vocab logits + codon annotations)"
echo "============================================================"

python scripts/analyze.py \
  --checkpoint "${OUTPUT_DIR}/checkpoints/checkpoint_final.pt" \
  --model-path "$MODEL_PATH" \
  --csv-path "$CSV_PATH" \
  --layer "$LAYER" \
  --num-sequences "$NUM_SEQUENCES" \
  --batch-size 8 \
  --output-dir "${OUTPUT_DIR}/analysis" \
  --dashboard-dir "${OUTPUT_DIR}/dashboard"

echo ""
echo "============================================================"
echo "STEP 4: Build dashboard"
echo "============================================================"

python scripts/dashboard.py \
  --checkpoint "${OUTPUT_DIR}/checkpoints/checkpoint_final.pt" \
  --model-path "$MODEL_PATH" \
  --csv-path "$CSV_PATH" \
  --layer "$LAYER" \
  --num-sequences "$NUM_SEQUENCES" \
  --batch-size 8 \
  --n-examples 6 \
  --umap-n-neighbors 15 \
  --umap-min-dist 0.1 \
  --hdbscan-min-cluster-size 20 \
  --output-dir "${OUTPUT_DIR}/dashboard"

echo ""
echo "============================================================"
echo "DONE — Dashboard output: ${OUTPUT_DIR}/dashboard"
echo "============================================================"
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#!/bin/bash
#
# CodonFM Encodon-1B SwissProt F1 Evaluation Pipeline
#
# Evaluates whether CodonFM SAE features align with protein-level SwissProt
# annotations:
#   1. Download SwissProt proteins with CDS sequences
#   2. Compute per-feature F1 scores against SwissProt annotations
#
# Expects a trained SAE checkpoint at ${OUTPUT_DIR}/checkpoints/checkpoint_final.pt.

set -euo pipefail

# Pipeline constants — adjust paths for your environment.
readonly MODEL_PATH=checkpoints/NV-CodonFM-Encodon-TE-Cdwt-1B-v1/model.safetensors
readonly LAYER=16
readonly OUTPUT_DIR=./outputs/1b_layer16

# SwissProt download target, written by step 1 and read by step 2.
readonly SWISSPROT_DIR=./data/codonfm_swissprot

echo "============================================================"
echo "STEP 1: Download SwissProt proteins with CDS sequences"
echo "============================================================"

python scripts/download_codonfm_swissprot.py \
  --output-dir "$SWISSPROT_DIR" \
  --max-proteins 8000 \
  --max-length 512 \
  --annotation-score 5 \
  --workers 8

echo ""
echo "============================================================"
echo "STEP 2: F1 evaluation against SwissProt annotations"
echo "============================================================"

python scripts/eval_swissprot_f1.py \
  --checkpoint "${OUTPUT_DIR}/checkpoints/checkpoint_final.pt" \
  --model-path "$MODEL_PATH" \
  --layer "$LAYER" \
  --batch-size 8 \
  --context-length 2048 \
  --swissprot-tsv "${SWISSPROT_DIR}/codonfm_swissprot.tsv.gz" \
  --f1-max-proteins 8000 \
  --f1-min-positives 10 \
  --f1-threshold 0.3 \
  --normalization-n-proteins 2000 \
  --output-dir "${OUTPUT_DIR}/swissprot_eval"

echo ""
echo "============================================================"
echo "DONE — SwissProt F1 results: ${OUTPUT_DIR}/swissprot_eval"
echo "============================================================"
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# CodonFM SAE Recipe

Train and analyze sparse autoencoders on [CodonFM](https://huggingface.co/nvidia/NV-CodonFM-Encodon-1B-v1) Encodon codon language models. The pipeline extracts residual stream activations, trains a TopK SAE, evaluates reconstruction quality, and optionally generates an interactive feature dashboard.

## Pipeline

```
Extract activations -> Train SAE -> Evaluate -> Analyze (optional) -> Dashboard (optional)
```

**Extract** runs the Encodon model over DNA coding sequences, saving per-codon hidden states from a target layer to sharded Parquet files. **Train** fits a TopK SAE (8x expansion, top-32 sparsity by default) on those activations. **Evaluate** measures loss recovered by comparing model logits with and without the SAE bottleneck. **Analyze** computes per-feature interpretability annotations (codon usage bias, amino acid identity, wobble position, CpG content) and optionally generates LLM-based feature labels. **Dashboard** builds UMAP embeddings and exports data for a React-based interactive feature explorer.

## Prerequisites

1. Encodon checkpoint (`.safetensors` or `.ckpt` with accompanying `config.json`):

```bash
huggingface-cli download nvidia/NV-CodonFM-Encodon-1B-v1 --local-dir ./checkpoints/encodon_1b
```

2. DNA sequence data as a CSV with a coding sequence column (`cds`, `seq`, or `sequence` -- auto-detected).

3. Install dependencies:

```bash
# From repo root (UV workspace)
uv sync
```

## Quick Start

```bash
# Full pipeline: extract -> train -> eval
python run.py model=1b csv_path=path/to/Primates.csv

# Skip extraction if activations are already cached
python run.py model=1b csv_path=path/to/data.csv steps.extract=false

# Smoke test
python run.py model=1b csv_path=path/to/data.csv num_sequences=100 train.n_epochs=1 nproc=1 dp_size=1
```

## Step-by-Step

### 1. Extract Activations

```bash
# Single GPU
python scripts/extract.py \
--csv-path path/to/Primates.csv \
--model-path path/to/encodon_1b/NV-CodonFM-Encodon-1B-v1.safetensors \
--layer -2 \
--num-sequences 50000 \
--output .cache/activations/encodon_1b_layer-2

# Multi-GPU
torchrun --nproc_per_node=4 scripts/extract.py \
--csv-path path/to/Primates.csv \
--model-path path/to/encodon_1b/NV-CodonFM-Encodon-1B-v1.safetensors \
--layer -2 \
--output .cache/activations/encodon_1b_layer-2
```

Outputs sharded Parquet files + `metadata.json` to the cache directory. CLS and SEP tokens are stripped; only codon-position activations are saved.

### 2. Train SAE

```bash
python scripts/train.py \
--cache-dir .cache/activations/encodon_1b_layer-2 \
--model-path path/to/encodon_1b/NV-CodonFM-Encodon-1B-v1.safetensors \
--layer -2 \
--expansion-factor 8 --top-k 32 \
--batch-size 4096 --n-epochs 3 \
--output-dir ./outputs/encodon_1b

# Multi-GPU
torchrun --nproc_per_node=4 scripts/train.py \
--cache-dir .cache/activations/encodon_1b_layer-2 \
--model-path path/to/encodon_1b/NV-CodonFM-Encodon-1B-v1.safetensors \
--layer -2 --dp-size 4 \
--expansion-factor 8 --top-k 32 \
--batch-size 4096 --n-epochs 3 \
--output-dir ./outputs/encodon_1b
```

Saves checkpoint to `./outputs/encodon_1b/checkpoints/checkpoint_final.pt`.

### 3. Evaluate

```bash
python scripts/eval.py \
--checkpoint ./outputs/encodon_1b/checkpoints/checkpoint_final.pt \
--model-path path/to/encodon_1b/NV-CodonFM-Encodon-1B-v1.safetensors \
--layer -2 --top-k 32 \
--csv-path path/to/data.csv \
--output-dir ./outputs/encodon_1b/eval
```

### 4. Analyze Features (optional)

```bash
python scripts/analyze.py \
--checkpoint ./outputs/encodon_1b/checkpoints/checkpoint_final.pt \
--model-path path/to/encodon_1b/NV-CodonFM-Encodon-1B-v1.safetensors \
--layer -2 --top-k 32 \
--csv-path path/to/Primates.csv \
--output-dir ./outputs/encodon_1b/analysis \
--auto-interp --max-auto-interp-features 500
```

Produces `vocab_logits.json`, `feature_analysis.json`, and `feature_labels.json`.

### 5. Dashboard (optional)

```bash
# Generate dashboard data
python scripts/dashboard.py \
--checkpoint ./outputs/encodon_1b/checkpoints/checkpoint_final.pt \
--model-path path/to/encodon_1b/NV-CodonFM-Encodon-1B-v1.safetensors \
--layer -2 --top-k 32 \
--csv-path path/to/Primates.csv \
--output-dir ./outputs/encodon_1b/dashboard

# Launch web UI
python scripts/launch_dashboard.py --data-dir ./outputs/encodon_1b/dashboard
```

## Model Sizes

| Model | Params | Layers | Hidden Dim | Batch Size | Config |
| ------------ | ------ | ------ | ---------- | ---------- | ------------ |
| Encodon 80M | 80M | 6 | 1024 | 32 | `model=80m` |
| Encodon 600M | 600M | 12 | 2048 | 16 | `model=600m` |
| Encodon 1B | 1B | 18 | 2048 | 8 | `model=1b` |
| Encodon 5B | 5B | 24 | 4096 | 2 | `model=5b` |

## Configuration

Hydra configs live in `run_configs/`. The base config (`config.yaml`) sets defaults for all steps. Model-specific configs in `run_configs/model/` override `model_path`, `run_name`, `num_sequences`, and `batch_size`.

Override any parameter on the command line:

```bash
python run.py model=1b csv_path=data.csv train.n_epochs=5 train.lr=1e-4 nproc=8
```

Key training defaults: `expansion_factor=8`, `top_k=32`, `lr=3e-4`, `n_epochs=3`, `batch_size=4096`, `layer=-2`.

## Project Structure

```
recipes/codonfm/
run.py Hydra pipeline orchestrator
run_configs/ Hydra configs (config.yaml, model/*.yaml)
scripts/
extract.py Extract layer activations (multi-GPU)
train.py Train TopK SAE (multi-GPU)
eval.py Loss recovered evaluation
analyze.py Feature interpretability annotations
dashboard.py UMAP + dashboard data export
launch_dashboard.py Serve interactive web UI
mutation_features.py Mutation-site feature analysis
src/codonfm_sae/ Recipe-specific code (CSV loader, eval)
codon-fm/ CodonFM model code (tokenizer, inference, models)
codon_dashboard/ React/Vite interactive dashboard
notebooks/ Jupyter notebooks (UMAP exploration)
```

## Data Format

CSV with a DNA coding sequence column. The loader auto-detects columns named `cds`, `seq`, or `sequence`. Each sequence should be a string of nucleotides whose length is divisible by 3 (codons). The tokenizer splits into 3-mer codons from a 69-token vocabulary (5 special + 64 DNA codons).
Loading
Loading