Home
The Protein Set Transformer (PST) is a protein-based genome language model for contextualizing protein language model embeddings with genome context and subsequently producing genome embeddings from these protein embeddings.
We plan to create a `pip`-installable package in the future but are having issues with a custom fork dependency.
For now, you can install the software dependencies of PST using a combination of `mamba` and `pip`, which should take no more than 5 minutes.
Note: you will likely need to link your git command line interface with an online github account. Follow this link for help setting up git at the command line.
# CPU-only install: set up torch first -- conda does this so much better than pip
mamba create -n pst -c pytorch -c pyg -c conda-forge 'python<3.12' 'pytorch>=2.0' cpuonly pyg pytorch-scatter
mamba activate pst
# install latest updates from this repository
# best to clone the repo since you may want to run the test demo
git clone https://github.com/cody-mar10/protein_set_transformer.git
cd protein_set_transformer
pip install . #<- notice the [dot]
# GPU install (CUDA 11.8): set up torch first -- conda does this so much better than pip
mamba create -n pst -c pytorch -c nvidia -c pyg -c conda-forge 'python<3.12' 'pytorch>=2.0' pytorch-cuda=11.8 pyg pytorch-scatter
mamba activate pst
# install latest updates from this repository
# best to clone the repo since you may want to run the test demo
git clone https://github.com/cody-mar10/protein_set_transformer.git
cd protein_set_transformer
pip install . #<- notice the [dot]
We implemented a hyperparameter tuning cross validation workflow using Lightning Fabric in a base library called lightning-crossval. Part of our specific implementation for hyperparameter tuning is also implemented in the PST library.
If you want to include the optional dependencies for training a new PST, you can follow the corresponding installation steps above with the following change:
pip install .[tune]
Upon successful installation, you will have the `pst` executable to train, tune, and predict. There are also other modules included as utilities that you can see using `pst -h`.
You will need to first download a trained vPST model:
pst download --trained-models
This will download both vPST models into `./pstdata`, but you can change the download location using `--outdir`.
You can use the test data for a test prediction run:
# the test data file is included in the git repo
pst predict \
    --file test/test_data.graphfmt.h5 \
    --checkpoint pstdata/pst-small_trained_model.ckpt \
    --outdir test_run
The results from the above command are available at `test/test_run/predictions.h5`. This test run takes less than 1 minute using a single CPU.
If you are unfamiliar with `.h5` files, you can use `pytables` (installed with PST as a dependency) to inspect `.h5` files in python, or you can install `hdf5` and use the `h5ls` command to inspect the fields in the output file.
There should be 3 fields in the prediction file:
- `attn`, which contains the per-protein attention values (shape: $N_{prot} \times N_{heads}$)
- `ctx_ptn`, which contains the contextualized PST protein embeddings (shape: $N_{prot} \times D$)
- `genome`, which contains the PST genome embeddings (shape: $N_{genome} \times D$)
  - Prior to version `1.2.0`, this was called `data`.
All data associated with the initial model training can be found here: https://doi.org/10.5061/dryad.d7wm37q8w
We have provided the README to the DRYAD data repository to render here. Additionally, we have provided a programmatic way to access the data from the command line using `pst download`:
usage: pst download [-h] [--all] [--outdir PATH] [--esm-large] [--esm-small] [--vpst-large] [--vpst-small] [--genome] [--genslm]
[--trained-models] [--genome-clusters] [--protein-clusters] [--aai] [--fasta] [--host-prediction] [--no-readme]
[--supplementary-data] [--supplementary-tables]
help:
-h, --help show this help message and exit
DOWNLOAD:
--all download all files from the DRYAD repository (default: False)
--outdir PATH output directory to save files (default: ./pstdata)
EMBEDDINGS:
--esm-large download ESM2 large [t33_150M] PROTEIN embeddings for training and test viruses (esm-large_protein_embeddings.tar.gz)
(default: False)
--esm-small download ESM2 small [t6_8M] PROTEIN embeddings for training and test viruses (esm-small_protein_embeddings.tar.gz)
(default: False)
--vpst-large download vPST large PROTEIN embeddings for training and test viruses (pst-large_protein_embeddings.tar.gz) (default:
False)
--vpst-small download vPST small PROTEIN embeddings for training and test viruses (pst-small_protein_embeddings.tar.gz) (default:
False)
--genome download all genome embeddings for training and test viruses (genome_embeddings.tar.gz) (default: False)
--genslm download GenSLM ORF embeddings (genslm_protein_embeddings.tar.gz) (default: False)
TRAINED_MODELS:
--trained-models download trained vPST models (trained_models.tar.gz) (default: False)
CLUSTERS:
--genome-clusters download genome cluster labels (genome_clusters.tar.gz) (default: False)
--protein-clusters download protein cluster labels (protein_clusters.tar.gz) (default: False)
MANUSCRIPT_DATA:
--aai download intermediate files for AAI calculations in the manuscript (aai.tar.gz) (default: False)
--fasta download protein fasta files for training and test viruses (fasta.tar.gz) (default: False)
--host-prediction download all data associated with the host prediction proof of concept (host_prediction.tar.gz) (default: False)
--no-readme download the DRYAD README (README.md) (default: True)
--supplementary-data download supplementary data directly used to make the figures in the manuscript (supplementary_data.tar.gz) (default:
False)
--supplementary-tables
download supplementary tables (supplementary_tables.zip) (default: False)
For flags relating to the download of specific files, you can add as many flags as you like. For example, if you want the trained models and the raw FASTA files used to train the vPSTs downloaded into a directory called `pst_models`, then you'd run this:
pst download --trained-models --fasta --outdir pst_models
The minimum input to the PST framework is a protein FASTA file, which we prefer to generate for microbes and viruses using pyrodigal. We have provided the `pst embed` command to embed protein sequences using the ESM2 models.
Here are the ESM2 models used for each vPST model:
vPST | ESM2 |
---|---|
`pst-small` | `esm2_t6_8M_UR50D` |
`pst-large` | `esm2_t30_150M_UR50D` |
To embed the protein sequences from a FASTA file, use the following command depending on which vPST model you are using:
### for pst-small
pst embed --input FASTAFILE.faa --esm esm2_t6_8M
### for pst-large
pst embed --input FASTAFILE.faa --esm esm2_t30_150M
`pst embed` has other options to change the output directory (`--outdir`), change the ESM2 model download directory (`--torch-hub`), and set the number of CPU threads or GPU devices (`--devices`).
The output of `pst embed` is a single `.h5` file with the field `data` that stores the protein embeddings.
The protein embeddings from `pst embed` are produced IN THE SAME ORDER as the sequences in the FASTA file. Thus, the following are required of the input FASTA file:
- The file must be sorted to group all proteins from the same genome together.
- For the block of proteins from each genome, the proteins must be in order of their appearance in the genome.
- The FASTA headers must look like this: `scaffold_#`, where `scaffold` is the genome scaffold name and `#` is the protein numerical ID relative to each scaffold. (This is the typical output from `prodigal`/`pyrodigal` -- in fact, the additional information in the `prodigal`-style headers is needed for the next step.)
  - In the event that you have multi-scaffold viruses (vMAGs, etc.), you can manually orient the scaffolds and renumber the proteins to contiguously count from the first scaffold to the last. This is what was done with the test dataset in the manuscript.
    - We provided a utility script `pst graphify` to do this if an input mapping from scaffolds to genomes is provided. See next section.
  - TODO: We are exploring a more native solution for multi-scaffold viruses that does not require an arbitrary arrangement of scaffolds and should not require changes to the model.
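The ordering requirements above can be checked before embedding. The following is a minimal sketch, not part of the PST package: the helper names `split_header` and `check_protein_order` are hypothetical, and it assumes every header starts with a `scaffold_#` name, optionally followed by `prodigal`-style metadata.

```python
def split_header(header: str) -> tuple[str, int]:
    """Split a 'scaffold_#' FASTA header into (scaffold, protein number)."""
    name = header.lstrip(">").split()[0]  # drop '>' and any trailing metadata
    scaffold, _, number = name.rpartition("_")
    return scaffold, int(number)


def check_protein_order(headers: list[str]) -> list[str]:
    """Return a list of problems; an empty list means the order looks valid."""
    problems = []
    finished = set()  # scaffolds whose block of proteins has already ended
    current = None    # scaffold of the block currently being read
    last_number = 0
    for header in headers:
        scaffold, number = split_header(header)
        if scaffold != current:
            if scaffold in finished:
                problems.append(f"{scaffold}: proteins are not grouped together")
            if current is not None:
                finished.add(current)
            current, last_number = scaffold, 0
        if number <= last_number:
            problems.append(f"{scaffold}_{number}: out of genome order")
        last_number = number
    return problems


good = [">phageA_1 # 2 # 301 # 1", ">phageA_2 # 320 # 700 # -1", ">phageB_1 # 5 # 95 # 1"]
bad = [">phageA_2", ">phageA_1", ">phageB_1", ">phageA_3"]
print(check_protein_order(good))
print(check_protein_order(bad))
```

The check only looks at headers, so it can run over a large FASTA file without loading any sequences.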
Use the `pst graphify` command to convert the ESM2 protein embeddings into graph format. You will need the protein FASTA file used to generate the embeddings, since the embeddings should be in the same order as the FASTA file. The FASTA file should be in prodigal format:
>scaffold_ptnid # start # stop # strand ....
If your FASTA headers have the above format, you can use the following command:
pst graphify --file EMBEDDINGSFILE.h5 --fasta-file FASTAFILE.faa
If you did not keep the extra metadata on the headers, you can alternatively provide a simple tab-delimited mapping file (`--strand-file`) that maps each protein name to its strand (-1 or 1 only):
genome1_1 1
genome1_2 1
genome1_3 1
genome1_4 -1
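If the strand information is stored elsewhere, for instance as the `+`/`-` symbols used in GFF files, a file in the layout above can be generated with a few lines of Python. This is only a sketch under that assumption; `write_strand_file` and `STRAND_CODES` are hypothetical names and not part of PST.

```python
import csv
import io

# Assumption: strands are recorded as "+" / "-" (as in GFF files)
STRAND_CODES = {"+": 1, "-": -1}


def write_strand_file(records, handle):
    """Write 'protein<TAB>strand' rows, with the strand encoded as 1 or -1."""
    writer = csv.writer(handle, delimiter="\t", lineterminator="\n")
    for protein, symbol in records:
        writer.writerow([protein, STRAND_CODES[symbol]])


records = [("genome1_1", "+"), ("genome1_2", "+"), ("genome1_3", "+"), ("genome1_4", "-")]
buffer = io.StringIO()  # swap for open("strands.tsv", "w") to write a real file
write_strand_file(records, buffer)
print(buffer.getvalue(), end="")
```

An unknown strand symbol raises a `KeyError`, which keeps values other than -1 and 1 out of the file.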
Further, if you have multi-scaffold viruses, you can provide a tab-delimited file (`--scaffold-map-file`) that maps the scaffold name to the genome name to count all proteins from the entire genome instead of each scaffold:
scaffoldA genome1
scaffoldB genome1
scaffoldC genome2
scaffoldD genome2
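To illustrate what counting proteins per genome rather than per scaffold means, here is a small sketch. It is an illustration of the idea only and does not claim to reproduce what `pst graphify` does internally; `renumber_by_genome` is a hypothetical helper, and the protein names are made up.

```python
def renumber_by_genome(protein_names, scaffold_to_genome):
    """Rename 'scaffold_#' proteins to 'genome_#', counting across scaffolds."""
    counts = {}
    renamed = []
    for name in protein_names:
        scaffold = name.rpartition("_")[0]
        # scaffolds missing from the map are treated as their own genome
        genome = scaffold_to_genome.get(scaffold, scaffold)
        counts[genome] = counts.get(genome, 0) + 1
        renamed.append(f"{genome}_{counts[genome]}")
    return renamed


scaffold_to_genome = {
    "scaffoldA": "genome1",
    "scaffoldB": "genome1",
    "scaffoldC": "genome2",
    "scaffoldD": "genome2",
}
proteins = ["scaffoldA_1", "scaffoldA_2", "scaffoldB_1", "scaffoldC_1", "scaffoldD_1", "scaffoldD_2"]
print(renumber_by_genome(proteins, scaffold_to_genome))
```

The numbering follows the order of the input, so the scaffolds of each genome still need to be arranged in the intended order first.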
When installing the `ptn-set-transformer` library from this repository, the model and datamodule classes are available from the `pst` namespace.
The primary class needed is the `ProteinSetTransformer`, which is a subclass of the PyTorch Lightning `lightning.LightningModule`. Thus, `ProteinSetTransformer` has the following methods common to `LightningModule` that can be overridden:
training_step
predict_step
configure_optimizers
forward
If dramatic changes to the training setup or objective are desired, you will want to subclass the `BaseProteinSetTransformer` to define these changes. See the finetuning section for more information.
`ProteinSetTransformer` is a wrapper around the `SetTransformer` class (defined in `pst.nn.models`), which contains the encoder-decoder layers. `SetTransformer` internally uses PyTorch Geometric graph formatting. Graph-formatted data account for the fact that each genome encodes a different number of proteins, since graphs can contain different numbers of nodes.
The `forward` method of the `SetTransformer` first encodes the protein embeddings using self-attention along edges of adjacent proteins (based on the order of encoding in the genome). Each encoding layer, `MultiheadAttentionConv`, is defined in `pst.nn.layers` and uses the standard scaled dot product attention with residual connections.
Then, these contextualized protein embeddings are decoded using a `MultiheadAttentionPooling` layer. This layer has a learnable d-dimensional seed vector that is used as the query along with the protein embeddings as the key and value for scaled dot product attention. The attention scores from this scaled dot product are softmax normalized and used to weight each protein embedding for a weighted average, producing a genome embedding.
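The pooling step can be sketched in plain Python for a single attention head. This is a simplified illustration and not the `MultiheadAttentionPooling` implementation: it assumes one head, omits any learned query/key/value projections, and uses a fixed `seed` where the real layer would learn it. The function name `attention_pool` is hypothetical.

```python
import math


def attention_pool(seed, proteins):
    """Pool protein embeddings into one genome embedding using a seed query."""
    d = len(seed)
    # scaled dot product between the seed (query) and each protein (key)
    scores = [sum(q * k for q, k in zip(seed, p)) / math.sqrt(d) for p in proteins]
    # softmax over the proteins (subtract the max for numerical stability)
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # weighted average of the protein embeddings (values)
    genome = [sum(w * p[j] for w, p in zip(weights, proteins)) for j in range(d)]
    return genome, weights


seed = [1.0, 0.0]
proteins = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
genome, weights = attention_pool(seed, proteins)
print(weights)  # the protein most aligned with the seed gets the largest weight
print(genome)
```

Because the weights sum to 1 regardless of how many proteins a genome encodes, the genome embedding has the same dimension `d` for every genome.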
To instantiate a `ProteinSetTransformer`, the Pydantic config `ModelConfig` is needed, which has the following schema:
### Note: these are defined in pst.nn.config
from pydantic import BaseModel

# the nested configs are listed first so that ModelConfig can reference them
class OptimizerConfig(BaseModel):
    lr: float
    weight_decay: float
    betas: tuple[float, float]
    warmup_steps: int
    use_scheduler: bool

class AugmentationConfig(BaseModel):
    sample_rate: float

class LossConfig(BaseModel):
    margin: float
    sample_scale: float
    no_negatives_mode: NO_NEGATIVES_MODES

class ModelConfig(BaseModel):
    in_dim: int
    out_dim: int
    num_heads: int
    n_enc_layers: int
    embed_scale: int
    dropout: float
    layer_dropout: float
    proj_cat: bool
    compile: bool
    optimizer: OptimizerConfig
    loss: LossConfig
    augmentation: AugmentationConfig
With a `ModelConfig` instance, you can create a PST model like this:
from pst import ModelConfig
from pst import ProteinSetTransformer as PST
# ideally this is read from some external info like command line values
config = ModelConfig.default()
model = PST(config)
Alternatively, if you are starting from a pretrained model checkpoint, then you can instantiate a model that uses these pretrained weights like this:
from pst import ProteinSetTransformer as PST
checkpoint_file = "" # should be a real file path
model = PST.from_pretrained(checkpoint_file)
The model config is internally created by the `.from_pretrained` class method since the attributes are stored in the trained model's checkpoint.
The `forward` method (and any other data-facing method) of `ProteinSetTransformer` requires graph-formatted data. We model our `pst.GenomeGraph` after the standards set by PyTorch Geometric. In this setup, each genome is viewed as a graph (`pst.GenomeGraph`), and proteins are nodes in this graph. The protein nodes are connected if they are adjacently encoded in the genome (subject to hyperparameters). Each protein node is represented by its corresponding protein embedding (`x`).
The attributes of a `GenomeGraph` object look like this:
class GenomeGraph:
    x: torch.Tensor # shape: [N, d] <- protein embeddings
    edge_index: torch.Tensor # shape: [2, E] <- define protein-protein connections
    num_proteins: int # <- number of proteins encoded by this genome, ie number of nodes
    class_id: int # <- optional class ID for genome/graph level class
    strand: torch.Tensor # shape: [N] <- strand of each protein
    pos: torch.Tensor # shape: [N, 1] <- integer tensor that counts from 0 to N-1
    y: torch.Tensor | None # <- optional label tensor for this genome graph
For a collection of `GenomeGraph`s, such as in a minibatch, the protein embedding tensors from all genomes are stacked. We use an index pointer (`ptr`) to keep track of the start and stop positions for the proteins belonging to each genome. This allows efficient random access.
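The index pointer idea can be sketched with plain Python lists standing in for tensors. The helper names `build_ptr` and `proteins_of` are hypothetical and only illustrate how start and stop offsets are derived from the number of proteins per genome.

```python
from itertools import accumulate


def build_ptr(sizes):
    """Cumulative offsets: genome i occupies rows ptr[i] up to ptr[i + 1]."""
    return [0, *accumulate(sizes)]


def proteins_of(genome_idx, stacked, ptr):
    """Slice out all proteins of one genome from the stacked data."""
    start, stop = ptr[genome_idx], ptr[genome_idx + 1]
    return stacked[start:stop]


sizes = [3, 2, 4]  # number of proteins in each of 3 genomes
ptr = build_ptr(sizes)
stacked = [f"ptn{i}" for i in range(sum(sizes))]  # stand-in for stacked embeddings
print(ptr)                           # [0, 3, 5, 9]
print(proteins_of(1, stacked, ptr))  # ['ptn3', 'ptn4']
```

Looking up any genome takes two reads from `ptr` and one slice, no matter how many genomes are stacked.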
The `pst.GenomeDataset` handles creating batches of `GenomeGraph`s. However, the specific batch object is a `GenomeGraphBatch` with the following fields:
class GenomeGraphBatch:
    x: torch.Tensor # shape: [N, d] <- stacked protein embeddings
    edge_index: torch.Tensor # shape: [2, E] <- define protein-protein connections
    num_proteins: torch.Tensor # <- number of proteins encoded by each genome
    class_id: torch.Tensor # <- optional class ID for genome/graph level class for each genome
    strand: torch.Tensor # shape: [N] <- strand of each protein in all genomes
    pos: torch.Tensor # shape: [N, 1] <- integer tensor that counts from 0 to N-1 for each genome
    y: torch.Tensor | None # <- optional label tensor for each genome
    ### new fields
    ptr: torch.Tensor # shape: [num genomes + 1] <- index pointer to compute offsets
    batch: torch.Tensor # shape: [N] <- assigns each protein node to a unique genome ID
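As a rough illustration of how the two new fields relate, the `batch` assignment and the per-genome `pos` counter can both be derived from `ptr`. The sketch below uses plain lists instead of tensors, and the helper names `batch_from_ptr` and `pos_from_ptr` are hypothetical.

```python
def batch_from_ptr(ptr):
    """Assign every protein node the index of the genome it belongs to."""
    batch = []
    for genome_id, (start, stop) in enumerate(zip(ptr, ptr[1:])):
        batch.extend([genome_id] * (stop - start))
    return batch


def pos_from_ptr(ptr):
    """Count proteins from 0 to N-1 within each genome."""
    return [i for start, stop in zip(ptr, ptr[1:]) for i in range(stop - start)]


ptr = [0, 3, 5, 9]  # 3 genomes with 3, 2, and 4 proteins
print(batch_from_ptr(ptr))  # [0, 0, 0, 1, 1, 2, 2, 2, 2]
print(pos_from_ptr(ptr))    # [0, 1, 2, 0, 1, 0, 1, 2, 3]
```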
The protein-protein edges (`edge_index`) for each genome graph are computed on the fly upon instantiating a `GenomeDataset`. This is because there are 2 hyperparameters that control the connectivity of the genome graphs: `chunk_size` and `threshold`. `chunk_size` determines the maximum number of nodes that can belong to each subgraph, while `threshold` determines the maximum distance, in number of proteins, over which each protein will be connected to its neighbors. A `threshold` value of -1 indicates no maximum distance and leads to each subgraph being fully connected. The value of `chunk_size` therefore determines the number of subgraphs for each genome.
This chunking is used for memory efficiency but may also reflect the real evolutionary pressures of genes that are encoded near each other. There is also support for sparsifying these subgraphs by changing the value of the `--threshold` command line option, which gets sent to the `DataConfig.threshold` parameter. This was not used for pretraining the original vPSTs, but this option is available.
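The following sketch shows one way these two hyperparameters could shape the edges of a single genome. It is an illustration under stated assumptions, not the PST implementation: it assumes consecutive chunks of at most `chunk_size` proteins, no self-loops, and directed edges in both directions. The function name `genome_edges` is hypothetical. Under these assumptions, a genome with `N` proteins is split into `ceil(N / chunk_size)` subgraphs.

```python
import math


def genome_edges(num_proteins, chunk_size, threshold=-1):
    """List (source, target) edges within consecutive chunks of proteins."""
    edges = []
    for start in range(0, num_proteins, chunk_size):
        stop = min(start + chunk_size, num_proteins)
        for i in range(start, stop):
            for j in range(start, stop):
                if i == j:
                    continue  # assumption: no self-loops
                # threshold == -1 means no distance limit inside the chunk
                if threshold == -1 or abs(i - j) <= threshold:
                    edges.append((i, j))
    return edges


dense = genome_edges(5, chunk_size=3)                # chunks {0,1,2} and {3,4}
sparse = genome_edges(5, chunk_size=3, threshold=1)  # only direct neighbors
print(math.ceil(5 / 3))  # 2 subgraphs
print(len(dense))        # 8 directed edges
print(len(sparse))       # 6 directed edges
```

Proteins 2 and 3 are adjacent in the genome but fall in different chunks, so no edge connects them in this sketch.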
To store data that fits the above data models, `.h5` files are required with the following fields:
dataset.h5:
- data: stores the stacked protein embeddings (maps to x)
- ptr: offsets needed to randomly access all proteins from each genome from the data field
- sizes: number of proteins for each genome
- strand: protein encoding strand for each protein
Note that `class_id` is optional. If not provided in the `.h5` file, all genomes will default to the same class (which will probably not be used, depending on the model's training loop).
If you follow the information provided above for generating protein embeddings from FASTA files and converting these embeddings to the required graph format, your file format should be taken care of.
Our model class `ProteinSetTransformer` is a `lightning.LightningModule` subclass from PyTorch Lightning.
We make use of Pydantic schema models as the config for our model, but you can load a pretrained model from a PyTorch checkpoint like this:
from pst.nn.modules import ProteinSetTransformer as PST
ckptfile = "pst-small_trained_model.ckpt"
model = PST.from_pretrained(ckptfile)
Similarly, to load a pretrained `GenomeDatamodule`, you can do the following. Note: the `chunk_size` for each genome graph was tuned in the pretrained vPSTs.
from pst.data.modules import GenomeDatamodule
ckptfile = "pst-small_trained_model.ckpt"
new_data_file = "new_data.graphfmt.h5"
datamodule = GenomeDatamodule.from_pretrained(ckptfile, new_data_file)
You still need to give the file location for the graph-formatted `.h5` file.
TODO: describe the `pst.GenomeDataModule` class. <- this might be better in a pytorch lightning integration section
TODO: We are adding a finetuning command line mode for this since this is basically the same as training a model but starting from a pretrained model.
We have provided a code example with this repository at examples/finetuning.ipynb. This covers the case of a new objective that focuses on either genome- or protein-level tasks.
In brief, you need to subclass either `pst.BaseProteinSetTransformer` for genome tasks (or dual genome/protein tasks) OR `pst.BaseProteinSetTransformerEncoder` for protein-only tasks.
The subclassed models must define the following methods:
- `setup_objective`, which should return a callable that can be used to compute the loss
  - If the loss function requires tunable state, such as the margin and scaling factor of triplet loss, then a custom loss and model config can be defined using `pst.BaseLossConfig` and `pst.BaseModelConfig`, respectively.
  - The custom loss config must be used to override the loss field in the custom model config.
- `forward`, which defines the model's forward pass, including the data handling and loss computation
Optionally, though probably recommended, you can update the custom model's `__init__` method to define new trainable layers needed for the new model objective.
All functionality to train a new PST is embedded in the `pst train` command line mode.
To tune hyperparameters, there is also the `pst tune` command line mode that leverages a custom library we built on Lightning Fabric called Lightning CrossVal. This library enables epoch-synchronized hyperparameter tuning through cross validation. The cross validation strategy can be defined using the `lightning-cv` framework.
TODO: better description of the training and hyper parameter tuning process.