PlantCaduceus (PlantCAD for short) is a plant DNA language model based on the Caduceus architecture, which extends the efficient, linear-time Mamba sequence modeling framework with bi-directionality and reverse complement equivariance, both designed specifically for DNA sequences. PlantCAD is pre-trained on a curated dataset of 16 angiosperm genomes. It shows state-of-the-art cross-species performance in predicting translation initiation sites (TIS), translation termination sites (TTS), splice donors, and splice acceptors. PlantCAD's zero-shot scores enable genome-wide identification of deleterious mutations and known causal variants in Arabidopsis, sorghum, and maize.
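To make "reverse complement equivariance" concrete, here is a minimal sketch. It is not the Caduceus architecture: `toy_gc_track` is a hypothetical per-base function, used only to show the property that running on the reverse-complemented strand gives the reversed output of the forward strand.

```python
# Complement table for DNA bases; N (unknown base) maps to itself.
COMPLEMENT = str.maketrans("ACGTN", "TGCAN")


def reverse_complement(seq: str) -> str:
    """Return the reverse complement of a DNA sequence."""
    return seq.upper().translate(COMPLEMENT)[::-1]


def toy_gc_track(seq: str) -> list[int]:
    """Hypothetical per-base 'model': 1 for G or C, 0 otherwise."""
    return [1 if base in "GC" else 0 for base in seq.upper()]


seq = "ATGGCA"
forward = toy_gc_track(seq)
reverse = toy_gc_track(reverse_complement(seq))

# Equivariance for this toy function: the output on the reverse-complement
# strand is the forward output read backwards.
print(reverse == forward[::-1])
```

A model with this property gives consistent predictions regardless of which strand of the double helix it is shown.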
conda env create -f env/environment.yml
conda activate PlantCAD
pip install -r env/requirements.txt --no-build-isolation
import torch
from mamba_ssm import Mamba
- If the import fails, re-install mamba_ssm by running the following commands:
pip uninstall mamba-ssm
pip install mamba-ssm==2.2.0 --no-build-isolation
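The import check above can also be scripted so that it reports which package failed instead of stopping at the first error. This is a small sketch; the helper name `can_import` is ours and not part of the repository.

```python
import importlib


def can_import(module_name: str) -> bool:
    """Return True if the module imports cleanly, False otherwise."""
    try:
        importlib.import_module(module_name)
    except Exception:
        # A compiled extension built against the wrong CUDA or PyTorch
        # version can raise errors other than ImportError, so catch broadly.
        return False
    return True


if __name__ == "__main__":
    for name in ("torch", "mamba_ssm"):
        print(f"{name}: {'OK' if can_import(name) else 'FAILED'}")
```

If `mamba_ssm` is reported as FAILED, follow the re-install commands above.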
An example notebook showing how to use PlantCAD to obtain embeddings and logit scores is available at notebooks/examples.ipynb.
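For orientation, the following is a rough sketch of what obtaining embeddings and logits can look like. It assumes the Hugging Face checkpoints load through the `transformers` Auto classes with `trust_remote_code=True` and return standard masked-LM outputs; the function name `embed_sequence` is ours, and the notebook remains the reference for the exact calls.

```python
MODEL_ID = "kuleshov-group/PlantCaduceus_l20"


def embed_sequence(sequence: str, model_id: str = MODEL_ID, device: str = "cuda:0"):
    """Return (last-layer hidden states, logits) for one DNA sequence.

    Assumption: the checkpoint exposes a masked-LM head via AutoModelForMaskedLM.
    """
    # Imported lazily so this file can be loaded without the heavy dependencies.
    import torch
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=True)
    model = model.to(device).eval()

    # Only input_ids are passed to the model in this sketch.
    input_ids = tokenizer(sequence, return_tensors="pt")["input_ids"].to(device)
    with torch.inference_mode():
        outputs = model(input_ids=input_ids, output_hidden_states=True)
    return outputs.hidden_states[-1], outputs.logits
```

Calling `embed_sequence` requires a GPU and downloads the model weights on first use.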
Pre-trained PlantCAD models have been uploaded to Hugging Face. The available models are:
- PlantCaduceus_l20: kuleshov-group/PlantCaduceus_l20
- Trained on sequences of length 512bp, with a model size of 256 and 20 layers.
- PlantCaduceus_l24: kuleshov-group/PlantCaduceus_l24
- Trained on sequences of length 512bp, with a model size of 256 and 24 layers.
- PlantCaduceus_l28: kuleshov-group/PlantCaduceus_l28
- Trained on sequences of length 512bp, with a model size of 256 and 28 layers.
- PlantCaduceus_l32: kuleshov-group/PlantCaduceus_l32
- Trained on sequences of length 512bp, with a model size of 256 and 32 layers.
An example notebook showing how to run PlantCAD on Google Colab is also available: PlantCAD Google Colab
We trained an XGBoost model on top of the PlantCAD embeddings for each task to evaluate performance. The training script, src/train_XGBoost.py, takes the following arguments:
# -train:  training data (format: https://huggingface.co/datasets/kuleshov-group/cross-species-single-nucleotide-annotation/tree/main/TIS)
# -valid:  validation data, same format as the training data
# -test:   test data (optional), same format as the training data
# -model:  pre-trained model name
# -output: output directory
# -device: GPU device used to dump embeddings
python src/train_XGBoost.py \
    -train train.tsv \
    -valid valid.tsv \
    -test test_rice.tsv \
    -model 'kuleshov-group/PlantCaduceus_l20' \
    -output ./output \
    -device 'cuda:0'
The XGBoost classifiers trained for the paper are available here. Use the following script to make predictions with a trained XGBoost classifier:
# -model:      pre-trained model
# -classifier: the trained XGBoost classifier
# -device:     GPU device used to dump embeddings
# -output:     output directory
python src/predict_XGBoost.py \
    -test test_rice.tsv \
    -model 'kuleshov-group/PlantCaduceus_l20' \
    -classifier classifiers/PlantCaduceus_l20/TIS_XGBoost.json \
    -device 'cuda:0' \
    -output ./output
We use the log-likelihood difference between the reference and alternative alleles to estimate the effect of a mutation. The script, src/zero_shot_score.py, takes the following arguments:
# -model:  pre-trained model name
# -device: GPU device used to dump embeddings
python src/zero_shot_score.py \
    -input examples/example_snp.tsv \
    -output output.tsv \
    -model 'kuleshov-group/PlantCaduceus_l32' \
    -device 'cuda:1'
Note: we highly recommend using the largest model (PlantCaduceus_l32) for zero-shot score estimation.
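The log-likelihood difference can be illustrated with a few lines of arithmetic. This sketch assumes the model yields a probability for each nucleotide at the variant position and uses the sign convention log p(alt) - log p(ref); both the probabilities and the sign convention are assumptions for illustration, so check src/zero_shot_score.py for the convention the script actually uses.

```python
import math


def zero_shot_score(probs: dict[str, float], ref: str, alt: str) -> float:
    """Log-likelihood difference between the alternative and reference alleles.

    Assumed convention: log p(alt) - log p(ref), so a negative score means
    the model finds the alternative allele less likely than the reference.
    """
    return math.log(probs[alt]) - math.log(probs[ref])


# Hypothetical per-nucleotide probabilities at the variant position.
probs = {"A": 0.70, "C": 0.10, "G": 0.15, "T": 0.05}

score = zero_shot_score(probs, ref="A", alt="T")
print(f"A>T score: {score:.3f}")
```

Under this convention, strongly negative scores flag variants that replace a confidently predicted reference base.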
- We also provide a pipeline to generate input files from VCF and genome FASTA files
# prepare bed
inputVCF="input.vcf"
genomeFA="genome.fa"
output="snp_info.tsv" # can be used as the input file for zero_shot_score.py
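# skip VCF header lines, then write a 512-bp window (256 bp on each side) around each SNP
grep -v '^#' ${inputVCF} | awk -v OFS="\t" '{print $1,$2-256,$2+256}' > ${inputVCF}.bed
bedtools getfasta -tab -fi ${genomeFA} -bed ${inputVCF}.bed -fo ${inputVCF}.seq.tsv
# skip header lines here too, so the rows stay aligned with the extracted sequences
grep -v '^#' ${inputVCF} | awk -v OFS="\t" '{print $1,$2-256,$2+256,$2,$4,$5}' | paste - <(cut -f2 ${inputVCF}.seq.tsv) > ${output}_tmp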
# add header
echo -e "chr\tstart\tend\tpos\tref\talt\tsequences" > ${output}
cat ${output}_tmp >> ${output}
rm ${inputVCF}.bed ${inputVCF}.seq.tsv ${output}_tmp
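The coordinate arithmetic in the pipeline above can be checked with a short Python sketch. It mirrors the shell commands for a single VCF line, using an in-memory dictionary as a stand-in for the genome FASTA and `bedtools getfasta`; the function name `vcf_line_to_row` and the toy genome are hypothetical. It shows that, with BED-style coordinates (0-based start, exclusive end), the variant base lands at index 255 of the 512-bp window.

```python
FLANK = 256  # matches the $2-256 / $2+256 arithmetic in the shell pipeline


def vcf_line_to_row(line: str, genome: dict[str, str], flank: int = FLANK):
    """Convert one VCF line to [chr, start, end, pos, ref, alt, sequence].

    Returns None for header lines. `genome` maps chromosome name to sequence.
    """
    if line.startswith("#"):
        return None
    fields = line.rstrip("\n").split("\t")
    chrom, pos, ref, alt = fields[0], int(fields[1]), fields[3], fields[4]
    # BED-style interval: 0-based start, exclusive end (a plain Python slice).
    start, end = pos - flank, pos + flank
    if start < 0 or end > len(genome[chrom]):
        raise ValueError(f"window {chrom}:{start}-{end} runs off the chromosome")
    return [chrom, start, end, pos, ref, alt, genome[chrom][start:end]]


# Toy genome: a single G at 1-based position 300, surrounded by A's.
genome = {"chr1": "A" * 299 + "G" + "A" * 700}
row = vcf_line_to_row("chr1\t300\t.\tG\tT\t.\t.\t.", genome)

print(row[:6])             # ['chr1', 44, 556, 300, 'G', 'T']
print(len(row[6]))         # 512
print(row[6][FLANK - 1])   # G
```

Variants closer than 256 bp to a chromosome end do not have a full window; this sketch raises an error for them rather than guessing how they should be padded.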
Inference speed depends heavily on model size and GPU type. We benchmarked several commonly used GPUs; for 5,000 SNPs, the inference times are as follows:
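Model | GPU | Time |
---|---|---|
PlantCaduceus_l20 | H100 | 16s |
PlantCaduceus_l20 | A100 | 19s |
PlantCaduceus_l20 | A6000 | 24s |
PlantCaduceus_l20 | 3090 | 25s |
PlantCaduceus_l20 | A5000 | 25s |
PlantCaduceus_l20 | A40 | 26s |
PlantCaduceus_l20 | 2080 | 44s |
PlantCaduceus_l24 | H100 | 21s |
PlantCaduceus_l24 | A100 | 27s |
PlantCaduceus_l24 | A6000 | 35s |
PlantCaduceus_l24 | 3090 | 37s |
PlantCaduceus_l24 | A40 | 38s |
PlantCaduceus_l24 | A5000 | 42s |
PlantCaduceus_l24 | 2080 | 71s |
PlantCaduceus_l28 | H100 | 31s |
PlantCaduceus_l28 | A100 | 43s |
PlantCaduceus_l28 | A6000 | 62s |
PlantCaduceus_l28 | A40 | 67s |
PlantCaduceus_l28 | 3090 | 69s |
PlantCaduceus_l28 | A5000 | 77s |
PlantCaduceus_l28 | 2080 | 137s |
PlantCaduceus_l32 | H100 | 47s |
PlantCaduceus_l32 | A100 | 66s |
PlantCaduceus_l32 | A6000 | 94s |
PlantCaduceus_l32 | A40 | 107s |
PlantCaduceus_l32 | 3090 | 116s |
PlantCaduceus_l32 | A5000 | 130s |
PlantCaduceus_l32 | 2080 | 232s |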
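To budget a larger run, the benchmark times can be turned into throughput and a rough runtime estimate. The sketch below uses the H100 times from the table and assumes that runtime scales linearly with the number of SNPs, which ignores fixed costs such as model loading; treat the result as an order-of-magnitude guide only.

```python
N_SNPS = 5000  # number of SNPs used in the benchmark

# H100 inference times in seconds, copied from the table above.
H100_SECONDS = {
    "PlantCaduceus_l20": 16,
    "PlantCaduceus_l24": 21,
    "PlantCaduceus_l28": 31,
    "PlantCaduceus_l32": 47,
}


def snps_per_second(seconds: float, n_snps: int = N_SNPS) -> float:
    """Throughput implied by one benchmark entry."""
    return n_snps / seconds


def estimate_seconds(n_snps: int, benchmark_seconds: float,
                     benchmark_snps: int = N_SNPS) -> float:
    """Linear extrapolation of runtime (an assumption, not a measurement)."""
    return n_snps * benchmark_seconds / benchmark_snps


print(snps_per_second(H100_SECONDS["PlantCaduceus_l20"]))              # 312.5
print(estimate_seconds(1_000_000, H100_SECONDS["PlantCaduceus_l32"]))  # 9400.0
```

By this estimate, scoring one million SNPs with PlantCaduceus_l32 on an H100 would take roughly 2.6 hours.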
WANDB_PROJECT=PlantCAD python src/HF_pre_train.py --do_train \
--report_to wandb --prediction_loss_only True --remove_unused_columns False --dataset_name 'kuleshov-group/Angiosperm_16_genomes' --soft_masked_loss_weight_train 0.1 --soft_masked_loss_weight_evaluation 0.0 \
--weight_decay 0.01 --optim adamw_torch \
--dataloader_num_workers 16 --preprocessing_num_workers 16 --seed 32 \
--save_strategy steps --save_steps 1000 --evaluation_strategy steps --eval_steps 1000 --logging_steps 10 \
--max_steps 120000 --warmup_steps 1000 \
--save_total_limit 20 --learning_rate 2E-4 --lr_scheduler_type constant_with_warmup \
--run_name test --overwrite_output_dir \
--output_dir "PlantCaduceus_train_1" --per_device_train_batch_size 32 --per_device_eval_batch_size 32 --gradient_accumulation_steps 4 --tokenizer_name 'kuleshov-group/PlantCaduceus_l20' --config_name 'kuleshov-group/PlantCaduceus_l20'
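The batch-related flags in the command above combine into an effective batch size, which is worth recomputing if you change the number of GPUs. This sketch assumes one token per base pair and a 512-bp training window (as listed for the released models); the number of GPUs is not stated in the command, so it is left as a parameter.

```python
PER_DEVICE_BATCH = 32   # --per_device_train_batch_size
GRAD_ACCUM_STEPS = 4    # --gradient_accumulation_steps
MAX_STEPS = 120_000     # --max_steps
SEQ_LEN = 512           # assumed: 512-bp windows, one token per base pair


def effective_batch_size(n_gpus: int) -> int:
    """Sequences consumed per optimizer step across all devices."""
    return PER_DEVICE_BATCH * GRAD_ACCUM_STEPS * n_gpus


def total_tokens(n_gpus: int) -> int:
    """Tokens seen over the whole run under the assumptions above."""
    return effective_batch_size(n_gpus) * SEQ_LEN * MAX_STEPS


print(effective_batch_size(1))  # 128
print(total_tokens(1))          # 7864320000
```

To keep the effective batch size constant when adding GPUs, lower `--gradient_accumulation_steps` or `--per_device_train_batch_size` proportionally.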
@article {Zhai2024.06.04.596709,
author = {Zhai, Jingjing and Gokaslan, Aaron and Schiff, Yair and Berthel, Ana and Liu, Zong-Yan and Miller, Zachary R and Scheben, Armin and Stitzer, Michelle C and Romay, Cinta and Buckler, Edward S. and Kuleshov, Volodymyr},
title = {Cross-species plant genomes modeling at single nucleotide resolution using a pre-trained DNA language model},
elocation-id = {2024.06.04.596709},
year = {2024},
doi = {10.1101/2024.06.04.596709},
URL = {https://www.biorxiv.org/content/early/2024/06/05/2024.06.04.596709},
eprint = {https://www.biorxiv.org/content/early/2024/06/05/2024.06.04.596709.full.pdf},
journal = {bioRxiv}
}