- PlantCAD overview
- Quick Start
- Model summary
- Prerequisites and system requirements
- Installation
- Basic Usage
- Advanced Usage
- Development and Training
- Citation
PlantCaduceus (PlantCAD for short) is a plant DNA language model based on the Caduceus architecture, which extends the efficient, linear-time Mamba sequence modeling framework with bi-directionality and reverse complement equivariance, two properties designed specifically for DNA sequences. PlantCAD is pre-trained on a curated dataset of 16 angiosperm genomes. It shows state-of-the-art cross-species performance in predicting translation initiation sites (TIS), translation termination sites (TTS), splice donors, and splice acceptors. Its zero-shot scores enable the identification of genome-wide deleterious mutations and known causal variants in Arabidopsis, sorghum, and maize.
New to PlantCAD? Try our Google Colab demo - no installation required!
For local usage: see the installation instructions below, then use notebooks/examples.ipynb to get started.
Pre-trained PlantCAD models have been uploaded to HuggingFace 🤗. The table below summarizes the four PlantCAD models, which differ in parameter count.
Model | Sequence Length | Model Size | Embedding Size |
---|---|---|---|
PlantCaduceus_l20 | 512bp | 20M | 384 |
PlantCaduceus_l24 | 512bp | 40M | 512 |
PlantCaduceus_l28 | 512bp | 128M | 768 |
PlantCaduceus_l32 | 512bp | 225M | 1024 |
Model Selection Guide:
- PlantCaduceus_l20: Good for testing and quick analysis
- PlantCaduceus_l32: Recommended for research and production (best performance)
For Google Colab: Just a Google account - GPU runtime recommended (free tier available)
For Local Installation: GPU recommended for reasonable performance. Dependencies will be installed automatically during setup.
No installation required! Just open our PlantCAD Google Colab notebook and start analyzing your data.
Setup steps:
- Open the Colab link
- Important: Set runtime to GPU (Runtime → Change runtime type → Hardware accelerator: GPU)
- Run the cells to install dependencies
- Upload your data or use the provided examples
Step 1: Create conda environment
# Clone the repository (if you haven't already)
git clone https://github.com/kuleshov-group/PlantCaduceus.git
cd PlantCaduceus
# Create and activate environment
conda env create -f env/environment.yml
conda activate PlantCAD
Step 2: Install Python packages
pip install -r env/requirements.txt --no-build-isolation
Step 3: Verify installation
# Test core dependencies
import torch
from mamba_ssm import Mamba
from transformers import AutoTokenizer, AutoModelForMaskedLM
# Test PlantCAD model loading
tokenizer = AutoTokenizer.from_pretrained('kuleshov-group/PlantCaduceus_l32')
model = AutoModelForMaskedLM.from_pretrained('kuleshov-group/PlantCaduceus_l32', trust_remote_code=True)
device = 'cuda:0'
model.to(device)
print("✅ Installation successful!")
Alternative: pip-only installation. If you prefer a pip-only installation, see issue #10 for community solutions.
mamba_ssm issues (most common):
# If mamba_ssm import fails, reinstall with:
pip uninstall mamba-ssm
pip install mamba-ssm==2.2.0 --no-build-isolation
CUDA/GPU issues:
- Verify CUDA installation:
nvidia-smi
- Check PyTorch CUDA support:
python -c "import torch; print(torch.cuda.is_available())"
- For CPU-only usage: Models will work but be significantly slower
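The checks above can also be bundled into one script. The following is a minimal sketch, not part of the repository: `check_environment` is a hypothetical helper name, and the package list is simply the set of imports used in the verification step. It only tests whether packages can be found and whether CUDA is visible, so a passing report does not guarantee that `mamba_ssm` was compiled against a matching CUDA version.

```python
import importlib.util
import shutil

# Packages imported by the verification snippet in this README.
REQUIRED_PACKAGES = ("torch", "mamba_ssm", "transformers")


def check_environment(packages=REQUIRED_PACKAGES):
    """Return a dict mapping each check name to True/False without raising."""
    report = {}
    for name in packages:
        # find_spec locates a top-level package without importing it.
        report[name] = importlib.util.find_spec(name) is not None
    # nvidia-smi on PATH is a rough hint that an NVIDIA driver is installed.
    report["nvidia-smi"] = shutil.which("nvidia-smi") is not None
    # Only query CUDA if torch can be found and imported.
    report["cuda_available"] = False
    if report.get("torch"):
        try:
            import torch
            report["cuda_available"] = bool(torch.cuda.is_available())
        except Exception:
            # A broken torch install is reported as "no CUDA" rather than crashing.
            report["cuda_available"] = False
    return report


if __name__ == "__main__":
    for check, ok in check_environment().items():
        print(f"{check:>15}: {'ok' if ok else 'MISSING'}")
```

If `mamba_ssm` is reported missing, the reinstall command above is the first thing to try.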
The easiest way to start is with our example notebook: notebooks/examples.ipynb
Quick example - Get sequence embeddings:
import torch
from mamba_ssm import Mamba
from transformers import AutoTokenizer, AutoModelForMaskedLM
import pandas as pd
device = 'cuda:0'
# Load the PlantCAD tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('kuleshov-group/PlantCaduceus_l32')
model = AutoModelForMaskedLM.from_pretrained('kuleshov-group/PlantCaduceus_l32', trust_remote_code=True)
model.to(device)
# Example plant DNA sequence (512bp max)
sequence = "CTTAATTAATATTGCCTTTGTAATAACGCGCGAAACACAAATCTTCTCTGCCTAATGCAGTAGTCATGTGTTGACTCCTTCAAAATTTCCAAGAAGTTAGTGGCTGGTGTGTCATTGTCTTCATCTTTTTTTTTTTTTTTTTAAAAATTGAATGCGACATGTACTCCTCAACGTATAAGCTCAATGCTTGTTACTGAAACATCTCTTGTCTGATTTTTTCAGGCTAAGTCTTACAGAAAGTGATTGGGCACTTCAATGGCTTTCACAAATGAAAAAGATGGATCTAAGGGATTTGTGAAGAGAGTGGCTTCATCTTTCTCCATGAGGAAGAAGAAGAATGCAACAAGTGAACCCAAGTTGCTTCCAAGATCGAAATCAACAGGTTCTGCTAACTTTGAATCCATGAGGCTACCTGCAACGAAGAAGATTTCAGATGTCACAAACAAAACAAGGATCAAACCATTAGGTGGTGTAGCACCAGCACAACCAAGAAGGGAAAAGATCGATGATCG"
# Get embeddings
encoding = tokenizer.encode_plus(
    sequence,
    return_tensors="pt",
    return_attention_mask=False,
    return_token_type_ids=False
)
input_ids = encoding["input_ids"].to(device)
with torch.inference_mode():
    outputs = model(input_ids=input_ids, output_hidden_states=True)
embeddings = outputs.hidden_states[-1]
print(f"Embedding shape: {embeddings.shape}") # [batch_size, seq_len, embedding_dim]
embeddings = embeddings.to(torch.float32).cpu().numpy()
# PlantCaduceus is bi-directional and reverse complement equivariant: the first half of
# the embedding channels encodes the forward sequence and the second half encodes the
# reverse-complemented sequence. Average the two halves before training a downstream classifier.
hidden_size = embeddings.shape[-1] // 2
forward = embeddings[..., 0:hidden_size]
reverse = embeddings[..., hidden_size:]
reverse = reverse[..., ::-1]
averaged_embeddings = (forward + reverse) / 2
print(averaged_embeddings.shape)
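The example above caps the input at 512bp, so a longer sequence has to be cut down before it is tokenized. One possible approach, sketched below, is to take a fixed-size window around a position of interest. The helper name `extract_window`, the 0-based coordinate convention, and the choice to shift the window at sequence ends (rather than pad) are all assumptions for illustration, and may differ from what the repository's scripts do.

```python
def extract_window(chrom_seq: str, pos: int, window: int = 512):
    """Return (subsequence, index of `pos` inside it) for a 0-based position.

    The window is centered on `pos` where possible and shifted inward near
    the sequence ends so that it always has exactly `window` bases.
    """
    if not 0 <= pos < len(chrom_seq):
        raise ValueError(f"pos {pos} is outside the sequence (length {len(chrom_seq)})")
    if len(chrom_seq) <= window:
        # Short sequences are returned whole.
        return chrom_seq, pos
    start = pos - window // 2
    # Clamp so the window stays inside the sequence.
    start = max(0, min(start, len(chrom_seq) - window))
    return chrom_seq[start:start + window], pos - start


# Toy 2,000bp sequence standing in for a chromosome.
toy_chrom = "ACGT" * 500
subseq, idx = extract_window(toy_chrom, pos=1000)
print(len(subseq), idx, subseq[idx])
```

The returned index tells you which position of the embedding corresponds to the site of interest, assuming one token per base.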
Estimate the functional impact of genetic variants using PlantCAD's log-likelihood scores.
Input format options:
- VCF files (recommended): Standard variant format with reference genome
- TSV files: Pre-processed sequences with variant information
Basic usage with VCF:
# Download example reference genome
wget https://download.maizegdb.org/Zm-B73-REFERENCE-NAM-5.0/Zm-B73-REFERENCE-NAM-5.0.fa.gz
gunzip Zm-B73-REFERENCE-NAM-5.0.fa.gz
# Run zero-shot scoring
python src/zero_shot_score.py \
-input-vcf examples/example_maize_snp.vcf \
-input-fasta Zm-B73-REFERENCE-NAM-5.0.fa \
-output scored_variants.vcf \
-model 'kuleshov-group/PlantCaduceus_l32' \
-device 'cuda:0'
Expected output:
- Scored VCF file with PlantCAD scores in the INFO field
- Scores represent log-likelihood ratios between reference and alternative alleles
- More negative scores indicate mutations that are more likely to be deleterious
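To make the score concrete, here is a small worked sketch of a log-likelihood ratio at a single site. It assumes the score is computed as log p(alt) minus log p(ref), which is consistent with negative scores flagging deleterious alleles, but the exact computation lives in src/zero_shot_score.py and may differ in detail. The probabilities below are made-up numbers, not model output.

```python
import math


def log_likelihood_ratio(probs: dict, ref: str, alt: str) -> float:
    """log p(alt) - log p(ref) from per-nucleotide probabilities at one site."""
    return math.log(probs[alt]) - math.log(probs[ref])


# Hypothetical probabilities at a masked position (illustrative values only).
site_probs = {"A": 0.90, "C": 0.04, "G": 0.05, "T": 0.01}

# The model strongly expects A here, so an A>T variant gets a negative score.
score = log_likelihood_ratio(site_probs, ref="A", alt="T")
print(f"{score:.3f}")  # -4.500
```

Under this convention, a variant whose alternative allele is as likely as the reference scores 0, and variants the model finds surprising score below 0.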
Convert VCF to table format (optional, for easier processing):
bash src/format_VCF.sh \
examples/example_maize_snp.vcf \
Zm-B73-REFERENCE-NAM-5.0.fa \
formatted_variants.tsv
Use table format directly:
python src/zero_shot_score.py \
-input-table formatted_variants.tsv \
-output results.tsv \
-model 'kuleshov-group/PlantCaduceus_l32' \
-device 'cuda:0' \
-outBED # Optional: output in BED format
For large-scale simulation and analysis of genetic variants, we provide a comprehensive in-silico mutagenesis pipeline. See pipelines/in-silico-mutagenesis/README.md for detailed instructions.
Train custom classifiers on top of PlantCAD embeddings for specific annotation tasks (e.g., TIS, TTS, splice sites).
Purpose: Fine-tune prediction performance for specific annotation tasks using supervised learning.
Data format: Training data should follow the format used in our cross-species annotation dataset.
python src/train_XGBoost.py \
-train train.tsv \
-valid valid.tsv \
-test test_rice.tsv \
-model 'kuleshov-group/PlantCaduceus_l20' \
-output ./output \
-device 'cuda:0'
Expected outputs:
- Trained XGBoost classifier (.json file)
- Performance metrics on validation/test sets
- Feature importance analysis
We provide pre-trained XGBoost classifiers for common annotation tasks in the classifiers directory.
Available classifiers:
- TIS (Translation Initiation Sites)
- TTS (Translation Termination Sites)
- Splice donor/acceptor sites
python src/predict_XGBoost.py \
-test test_rice.tsv \
-model 'kuleshov-group/PlantCaduceus_l20' \
-classifier classifiers/PlantCaduceus_l20/TIS_XGBoost.json \
-device 'cuda:0' \
-output ./output
Expected output: Predictions with confidence scores for each sequence in your test data.
For advanced users who want to pre-train PlantCAD models from scratch or fine-tune on custom datasets.
Requirements:
- Large computational resources (multi-GPU recommended)
- WandB account for experiment tracking
- Custom genomic dataset in HuggingFace format
Basic pre-training command:
WANDB_PROJECT=PlantCAD python src/HF_pre_train.py \
--do_train \
--report_to wandb \
--prediction_loss_only True \
--remove_unused_columns False \
--dataset_name 'kuleshov-group/Angiosperm_16_genomes' \
--soft_masked_loss_weight_train 0.1 \
--soft_masked_loss_weight_evaluation 0.0 \
--weight_decay 0.01 \
--optim adamw_torch \
--dataloader_num_workers 16 \
--preprocessing_num_workers 16 \
--seed 32 \
--save_strategy steps \
--save_steps 1000 \
--evaluation_strategy steps \
--eval_steps 1000 \
--logging_steps 10 \
--max_steps 120000 \
--warmup_steps 1000 \
--save_total_limit 20 \
--learning_rate 2E-4 \
--lr_scheduler_type constant_with_warmup \
--run_name test \
--overwrite_output_dir \
--output_dir "PlantCaduceus_train_1" \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--gradient_accumulation_steps 4 \
--tokenizer_name 'kuleshov-group/PlantCaduceus_l20' \
--config_name 'kuleshov-group/PlantCaduceus_l20'
Key parameters:
- dataset_name: Your custom dataset, or use our Angiosperm dataset
- max_steps: Total training steps (adjust based on dataset size)
- learning_rate: 2E-4 works well for most cases
- Batch sizes: Adjust based on your GPU memory
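When adjusting batch sizes for a smaller GPU, it can help to keep the effective batch size constant by raising gradient_accumulation_steps. The sketch below does that arithmetic for the command above. The helper names are hypothetical, the number of GPUs is an assumption (the command does not state it), and the bases-per-step figure assumes 512bp sequences with one token per base.

```python
def effective_batch(per_device_batch: int, grad_accum: int,
                    num_gpus: int = 1, seq_len: int = 512) -> dict:
    """Sequences and bases consumed per optimizer step."""
    sequences = per_device_batch * grad_accum * num_gpus
    return {"sequences_per_step": sequences, "bases_per_step": sequences * seq_len}


def grad_accum_for(target_sequences: int, per_device_batch: int, num_gpus: int = 1) -> int:
    """Smallest accumulation count that reaches at least `target_sequences` per step."""
    per_pass = per_device_batch * num_gpus
    return -(-target_sequences // per_pass)  # ceiling division


# Settings from the command above, assuming a single GPU.
baseline = effective_batch(per_device_batch=32, grad_accum=4)
print(baseline)  # {'sequences_per_step': 128, 'bases_per_step': 65536}

# If only a per-device batch of 8 fits in memory, accumulate more steps instead.
print(grad_accum_for(baseline["sequences_per_step"], per_device_batch=8))  # 16
```

With more GPUs the effective batch grows proportionally, so gradient_accumulation_steps can be lowered by the same factor to keep it unchanged.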
Inference speed depends heavily on model size and GPU type. Time to score 5,000 SNPs:
Model | H100 | A100 | A6000 | 3090 | A5000 | A40 | 2080 |
---|---|---|---|---|---|---|---|
PlantCaduceus_l20 | 16s | 19s | 24s | 25s | 25s | 26s | 44s |
PlantCaduceus_l24 | 21s | 27s | 35s | 37s | 42s | 38s | 71s |
PlantCaduceus_l28 | 31s | 43s | 62s | 69s | 77s | 67s | 137s |
PlantCaduceus_l32 | 47s | 66s | 94s | 116s | 130s | 107s | 232s |
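These timings can be used for a rough runtime estimate on larger variant sets. The sketch below copies the table into a dictionary and scales linearly with the number of SNPs. Linear scaling is an assumption: it ignores model loading, I/O, and batching effects, so treat the result as an order-of-magnitude guide. The function name `estimate_seconds` is hypothetical.

```python
# Seconds to score 5,000 SNPs, copied from the table above.
BENCHMARK_SNPS = 5_000
BENCHMARK_SECONDS = {
    "PlantCaduceus_l20": {"H100": 16, "A100": 19, "A6000": 24, "3090": 25, "A5000": 25, "A40": 26, "2080": 44},
    "PlantCaduceus_l24": {"H100": 21, "A100": 27, "A6000": 35, "3090": 37, "A5000": 42, "A40": 38, "2080": 71},
    "PlantCaduceus_l28": {"H100": 31, "A100": 43, "A6000": 62, "3090": 69, "A5000": 77, "A40": 67, "2080": 137},
    "PlantCaduceus_l32": {"H100": 47, "A100": 66, "A6000": 94, "3090": 116, "A5000": 130, "A40": 107, "2080": 232},
}


def estimate_seconds(model: str, gpu: str, n_snps: int) -> float:
    """Linearly extrapolate the 5,000-SNP benchmark to `n_snps` variants."""
    return BENCHMARK_SECONDS[model][gpu] * n_snps / BENCHMARK_SNPS


seconds = estimate_seconds("PlantCaduceus_l32", "A100", 1_000_000)
print(f"{seconds / 3600:.1f} hours")  # 3.7 hours
```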
Zhai, J., Gokaslan, A., Schiff, Y., Berthel, A., Liu, Z. Y., Lai, W. L., Miller, Z. R., Scheben, A., Stitzer, M. C., Romay, M. C., Buckler, E. S., & Kuleshov, V. (2025). Cross-species modeling of plant genomes at single nucleotide resolution using a pretrained DNA language model. Proceedings of the National Academy of Sciences, 122(24), e2421738122. https://doi.org/10.1073/pnas.2421738122