Changes from all commits
34 commits
660a9b5
Add dataloaders for csv and hugging face
NennoMP Aug 10, 2025
e1d8a6a
Update .gitignore
NennoMP Aug 10, 2025
bd7cdf4
Update .gitignore
NennoMP Aug 10, 2025
c7366bf
Add first version of solvers
NennoMP Aug 11, 2025
a5c3e35
Keep updated with #63
NennoMP Aug 11, 2025
1ab8c11
Add data preprocessing for RNA pretraining
NennoMP Aug 16, 2025
a0ebbd6
Run pre-commit
NennoMP Aug 16, 2025
aa65fd0
Complete data loading for pretraining, extend example notebook
NennoMP Aug 17, 2025
84da30d
Run pre-commit
NennoMP Aug 17, 2025
5c611da
Make all modules private, ready for lightning studio
NennoMP Aug 18, 2025
152a79e
Add pretrained weights loading for AptaTrans; add tests
NennoMP Aug 18, 2025
5c4112a
Add pretrained weights loading for AptaTrans; add tests
NennoMP Aug 18, 2025
26ba9e6
Add pretrained weights loading for AptaTrans; add tests
NennoMP Aug 18, 2025
ac6f7f5
Merge branch 'main' into feature/49-aptatrans-training-schema
NennoMP Aug 24, 2025
cd90141
Merge
NennoMP Aug 24, 2025
fcf104c
Merge branch 'feature/49-aptatrans-training-schema' of https://github…
NennoMP Aug 24, 2025
6722e97
Move utils import from root to submodules
NennoMP Aug 24, 2025
f11e158
Run pre-commit
NennoMP Aug 24, 2025
646f958
Reorganize aptatrans utilities
NennoMP Aug 24, 2025
ed55168
Run pre-commit
NennoMP Aug 24, 2025
690acaa
Update
NennoMP Aug 25, 2025
9f46122
Remove duplicate model.py and pipeline.py
NennoMP Sep 1, 2025
5318060
Merge branch 'main' into feature/49-aptatrans-training-schema
NennoMP Sep 4, 2025
cb21c89
Revert renaming of utils
NennoMP Sep 11, 2025
3aa03df
Merge branch 'main' into feature/49-aptatrans-training-schema
NennoMP Sep 11, 2025
0453cff
Run pre-commit
NennoMP Sep 11, 2025
0bff5a9
Fix bug in tests
NennoMP Sep 11, 2025
7d8353b
Update rna2vec
NennoMP Sep 11, 2025
443568e
Fix bug in docstring
NennoMP Sep 11, 2025
1febecf
Resolve merge conflicts
NennoMP Sep 25, 2025
12e1397
Add training lightning wrapper for AptaTrans' encoders
NennoMP Sep 28, 2025
1bd6978
Fix a few bugs
NennoMP Sep 29, 2025
8b720bd
Update rna2vec to handle short/long sequences (padding/truncation)
NennoMP Sep 29, 2025
4e2a16c
Fix AptaTransPipeline docstrings, rename predict method
NennoMP Sep 29, 2025
8 changes: 8 additions & 0 deletions .gitignore
@@ -192,3 +192,11 @@ cython_debug/
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore

# Dataset files
# Exclude specific (too big) dataset files
pyaptamer/datasets/data/bpRNA-shin2023.csv
pyaptamer/datasets/data/proteins-shin2023.csv

# Model weights files
pyaptamer/aptatrans/weights/pretrained.pt
7 changes: 6 additions & 1 deletion pyaptamer/aptatrans/__init__.py
@@ -3,14 +3,19 @@
(API) and recommending candidate aptamers for a given target protein.
"""

__author__ = ["nennomp"]
__all__ = [
"AptaTrans",
"AptaTransLightning",
"AptaTransEncoderLightning",
"AptaTransPipeline",
"EncoderPredictorConfig",
]

from pyaptamer.aptatrans._model import AptaTrans
from pyaptamer.aptatrans._model_lightning import AptaTransLightning
from pyaptamer.aptatrans._model_lightning import (
AptaTransEncoderLightning,
AptaTransLightning,
)
from pyaptamer.aptatrans._pipeline import AptaTransPipeline
from pyaptamer.aptatrans.layers import EncoderPredictorConfig
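For reference, the public names exported above can be sanity-checked with a plain import; a minimal sketch, assuming the package is installed in the current environment:

from pyaptamer.aptatrans import (
    AptaTrans,
    AptaTransEncoderLightning,
    AptaTransLightning,
    AptaTransPipeline,
    EncoderPredictorConfig,
)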
103 changes: 73 additions & 30 deletions pyaptamer/aptatrans/_model.py
@@ -3,6 +3,7 @@
__author__ = ["nennomp"]
__all__ = ["AptaTrans"]

import os
from collections import OrderedDict
from collections.abc import Callable

@@ -44,6 +44,8 @@ class AptaTrans(nn.Module):
conv_layers : list[int], optional, default=[3, 3, 3]
List specifying the number of convolutional blocks in each convolutional
layer.
pretrained : bool, optional, default=False
If True, load pretrained weights into the model, downloading them from Hugging Face if they are not found locally.

Attributes
----------
@@ -84,7 +87,7 @@ class AptaTrans(nn.Module):
>>> prot_embedding = EncoderPredictorConfig(128, 16, max_len=128)
>>> x_apta = torch.randint(high=16, size=(128, 10))
>>> x_prot = torch.randint(high=16, size=(128, 10))
>>> model = AptaTrans(apta_embedding, prot_embedding)
>>> model = AptaTrans(apta_embedding, prot_embedding, pretrained=False)
>>> imap = model.forward_imap(x_apta, x_prot)
>>> preds = model(x_apta, x_prot)
"""
@@ -98,6 +101,7 @@ def __init__(
n_heads: int = 8,
conv_layers: list[int] | None = None,
dropout: float = 0.1,
pretrained: bool = False,
) -> None:
"""
Raises
Expand Down Expand Up @@ -163,6 +167,9 @@ def __init__(
)
)

if pretrained:
self.load_pretrained_weights()

def _make_encoder(
self,
embedding_config: EncoderPredictorConfig,
@@ -256,45 +263,81 @@ def _make_layer(

return nn.Sequential(*layers)

def forward_encoders(
self,
x_apta: tuple[Tensor, Tensor],
x_prot: tuple[Tensor, Tensor],
):
"""Forward pass through the encoders only.
def load_pretrained_weights(self, store: bool = True) -> None:
"""Load pretrained model weights from hugging face.

This method performs a forward pass through the encoders, including the
token predictors, for pretraining.
If the weights are not found locally, they will be downloaded from Hugging Face.

Parameters
----------
x_apta, x_prot : tuple[Tensor, Tensor]
A tuple of tensors containing the features for masked tokens and secondary
structure prediction, for aptamers and proteins, respectively. Shapes are
(batch_size (b1), seq_len (s1)) and (batch_size (b2), seq_len (s2)),
respectively.
store : bool, optional, default=True
If True, the pretrained weights will be saved locally. If False, the weights
will be downloaded but not saved to disk.
"""
path = os.path.relpath(
os.path.join(os.path.dirname(__file__), ".", "weights", "pretrained.pt")
)

if os.path.exists(path):
print(f"Loading pretrained weights from {path}...")
state_dict = torch.load(path, map_location=torch.device("cpu"))
else:
print("Downloading best weights from hugging face...")
url = (
"https://huggingface.co/gcos/pyaptamer-aptatrans/resolve/main/"
"pretrained.pt"
)
state_dict = torch.hub.load_state_dict_from_url(
url=url,
model_dir=os.path.dirname(path),
map_location=torch.device("cpu"),
)

self.load_state_dict(state_dict, strict=True)

def forward_encoder(
self, x: tuple[Tensor, Tensor], encoder_type: str
) -> tuple[Tensor, Tensor]:
"""Forward pass through the aptamer or protein encoder.

This method performs a forward pass through the aptamer or protein encoder,
including the corresponding token predictor, for pretraining.

Parameters
----------
x : tuple[Tensor, Tensor]
A tuple of tensors containing the features for masked tokens and secondary
structure prediction, for aptamers or proteins. Each has shape (batch_size (b),
seq_len (s)).
encoder_type: str
A string indicating whether to use the aptamer or protein encoder. Options
are 'apta' or 'prot'.

Returns
-------
tuple[Tensor, Tensor], tuple[Tensor, Tensor]
tuple[Tensor, Tensor]
A tuple of tensors containing the predictions for masked tokens and
secondary structure, for aptamers and proteins, respectively. For aptamers,
the shapes are (b1, s1, n_embeddings (n1)) and (b1, s1, target_dim (t1)),
for the masked tokens and secondary structure, respectively. For proteins,
the shapes are (b2, s2, n_embeddings (n2)) and (b2, s2, target_dim (t2)),
secondary structure, for aptamers or proteins, depending on `encoder_type`.
Shapes are (b, s, n_embeddings (n)) and (b, s, target_dim (t)),
respectively.
"""
# pretrain aptamers' encoder
out_apta_mt = self.encoder_apta(x_apta[0])
out_apta_ss = self.encoder_apta(x_apta[1])
y_apta_mt, y_apta_ss = self.token_predictor_apta(out_apta_mt, out_apta_ss)

# pretrain proteins' encoder
out_prot_mt = self.encoder_prot(x_prot[0])
out_prot_ss = self.encoder_prot(x_prot[1])
y_prot_mt, y_prot_ss = self.token_predictor_prot(out_prot_mt, out_prot_ss)

return (y_apta_mt, y_apta_ss), (y_prot_mt, y_prot_ss)
Raises
------
ValueError
If `encoder_type` is not 'apta' or 'prot'.
"""
if encoder_type == "apta": # pretrain aptamers' encoder
out_apta_mt = self.encoder_apta(x[0])
out_apta_ss = self.encoder_apta(x[1])
return self.token_predictor_apta(out_apta_mt, out_apta_ss)
elif encoder_type == "prot": # pretrain proteins' encoder
out_prot_mt = self.encoder_prot(x[0])
out_prot_ss = self.encoder_prot(x[1])
return self.token_predictor_prot(out_prot_mt, out_prot_ss)
else:
raise ValueError(
f"Unknown encoder_type: {encoder_type}. Options are 'apta' or 'prot'."
)

def forward_imap(self, x_apta: Tensor, x_prot: Tensor) -> Tensor:
"""Forward pass to compute the interaction map.
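Taken together, the `_model.py` changes expose two ways to obtain pretrained weights (at construction time via `pretrained=True`, or afterwards via `load_pretrained_weights()`), plus a per-encoder forward pass for pretraining. A minimal usage sketch, assuming the Hugging Face checkpoint is reachable and that the encoder configuration matches the published weights (the configs below are illustrative, copied from the class docstring example):

import torch
from pyaptamer.aptatrans import AptaTrans, EncoderPredictorConfig

apta_embedding = EncoderPredictorConfig(128, 16, max_len=128)
prot_embedding = EncoderPredictorConfig(128, 16, max_len=128)

# construct first, then load the checkpoint explicitly (equivalent to pretrained=True)
model = AptaTrans(apta_embedding, prot_embedding)
model.load_pretrained_weights()

# pretraining-style forward pass through the aptamer encoder:
# the input is a pair of (masked-token batch, secondary-structure batch)
x_mlm = torch.randint(high=16, size=(8, 128))
x_ssp = torch.randint(high=16, size=(8, 128))
y_mlm_hat, y_ssp_hat = model.forward_encoder(x=(x_mlm, x_ssp), encoder_type="apta")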
128 changes: 124 additions & 4 deletions pyaptamer/aptatrans/_model_lightning.py
@@ -1,7 +1,7 @@
"""AptaTrans' deep neural network wrapper fro Lightning."""

__author__ = ["nennomp"]
__all__ = ["AptaTransLightning"]
__all__ = ["AptaTransLightning", "AptaTransEncoderLightning"]


import lightning as L
@@ -12,11 +12,11 @@


class AptaTransLightning(L.LightningModule):
"""LightningModule wrapper for the AptaTrans deep neural network [1]_.
"""LightningModule wrapper for training the AptaTrans deep neural network [1]_.

This class defines a LightningModule which acts as a wrapper for the AptaTrans
model, implemented as a `torch.nn.Module` in `pyaptamer.aptatrans._model.py`.
Specifically, it implementa two methods to make it compatible with lightning
Specifically, it implements two methods to make it compatible with the Lightning
training interface: (i) `training_step`, defining the training loop and (ii)
`configure_optimizers`, defining the optimizer used for training.

@@ -96,7 +96,7 @@ def training_step(
# (input aptamers, input proteins, ground-truth targets)
x_apta, x_prot, y = batch
y_hat = self.model(x_apta, x_prot)
loss = F.binary_cross_entropy(y_hat, y.float())
loss = F.binary_cross_entropy(y_hat.squeeze(0), y.float())
return loss

def configure_optimizers(self) -> torch.optim.Optimizer:
@@ -108,3 +108,123 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
betas=self.betas,
)
return optimizer


class AptaTransEncoderLightning(AptaTransLightning):
"""LightningModule wrapper for training the AptaTrans encoders [1]_.

This class defines a LightningModule which acts as a wrapper for the AptaTrans
encoders, implemented as a `torch.nn.Module` in `pyaptamer.aptatrans._model.py`.
Specifically, it implements two methods to make it compatible with the Lightning
training interface: (i) `training_step`, defining the training loop and (ii)
`configure_optimizers`, defining the optimizer used for training.

Parameters
----------
model: AptaTrans
An instance of the AptaTrans model.
encoder_type: str
A string indicating whether to use the aptamer or protein encoder. Options
are 'apta' or 'prot'.
lr: float, optional, default=1e-4
Learning rate for the optimizer.
weight_decay: float, optional, default=1e-5
Weight decay (L2 regularization) for the optimizer.
betas: tuple[float, float], optional, default=(0.9, 0.999)
Momentum coefficients for the Adam optimizer.
weight_mlm: float, optional, default=2.0
Weight for the masked language modeling (MLM) loss in the weighted total loss.
weight_ssp: float, optional, default=1.0
Weight for the secondary structure prediction (SSP) loss in the weighted
total loss.

References
----------
.. [1] Shin, Incheol, et al. "AptaTrans: a deep neural network for predicting
aptamer-protein interaction using pretrained encoders." BMC bioinformatics 24.1
(2023): 447.

Examples
--------
>>> import lightning as L
>>> import torch
>>> from pyaptamer.aptatrans import (
... AptaTrans,
... AptaTransEncoderLightning,
... EncoderPredictorConfig,
... )
>>> apta_embedding = EncoderPredictorConfig(128, 16, max_len=128)
>>> prot_embedding = EncoderPredictorConfig(128, 16, max_len=128)
>>> model = AptaTrans(apta_embedding, prot_embedding)
>>> # pretrain aptamer encoder
>>> model_lightning = AptaTransEncoderLightning(model, encoder_type="apta")
>>> x_apta_mlm = torch.randint(0, 125, (8, 128))
>>> x_apta_ssp = torch.randint(0, 125, (8, 128))
>>> y_mlm = torch.randint(0, 125, (8, 128))
>>> y_ssp = torch.randint(0, 8, (8, 128))
>>> train_dataloader = torch.utils.data.DataLoader(
... list(zip(x_apta_mlm, x_apta_ssp, y_mlm, y_ssp)),
... batch_size=4,
... shuffle=True,
... )
>>> trainer = L.Trainer(max_epochs=1)
>>> trainer.fit(model_lightning, train_dataloader) # doctest: +SKIP
>>> # pretrain protein encoder
>>> model_lightning = AptaTransEncoderLightning(model, encoder_type="prot")
>>> x_prot_mlm = torch.randint(0, 25, (8, 128))
>>> x_prot_ssp = torch.randint(0, 25, (8, 128))
>>> y_mlm = torch.randint(0, 25, (8, 128))
>>> y_ssp = torch.randint(0, 3, (8, 128))
>>> train_dataloader = torch.utils.data.DataLoader(
... list(zip(x_prot_mlm, x_prot_ssp, y_mlm, y_ssp)),
... batch_size=4,
... shuffle=True,
... )
>>> trainer = L.Trainer(max_epochs=1)
>>> trainer.fit(model_lightning, train_dataloader) # doctest: +SKIP
"""

def __init__(
self,
model: nn.Module,
encoder_type: str,
lr: float = 1e-4,
weight_decay: float = 1e-5,
betas: tuple[float, float] = (0.9, 0.999),
weight_mlm: float = 2.0,
weight_ssp: float = 1.0,
) -> None:
super().__init__(model, lr, weight_decay, betas)
self.encoder_type = encoder_type
self.weight_mlm = weight_mlm
self.weight_ssp = weight_ssp

def training_step(
self, batch: tuple[Tensor, Tensor, Tensor, Tensor], batch_idx: int
) -> Tensor:
"""Defines a single (mini-batch) step in the training loop.

The loss function is a weighted sum of the masked language modeling (MLM)
loss and the secondary structure prediction (SSP) loss.

Parameters
----------
batch: tuple[Tensor, Tensor, Tensor, Tensor]
A batch of data containing masked-token inputs, secondary-structure inputs,
and the corresponding ground-truth targets for MLM and SSP.
batch_idx: int
Index of the batch.

Returns
-------
Tensor
The computed loss for the batch.
"""
# (masked-token inputs, secondary-structure inputs, MLM targets, SSP targets)
x_mlm, x_ssp, y_mlm, y_ssp = batch
y_mlm_hat, y_ssp_hat = self.model.forward_encoder(
x=(x_mlm, x_ssp), encoder_type=self.encoder_type
)

# F.cross_entropy expects the class dimension second; targets are integer class indices
loss_mlm = F.cross_entropy(y_mlm_hat.transpose(1, 2), y_mlm)
loss_ssp = F.cross_entropy(y_ssp_hat.transpose(1, 2), y_ssp)
return self.weight_mlm * loss_mlm + self.weight_ssp * loss_ssp
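For clarity, the weighted pretraining objective computed in `training_step` can be reproduced in isolation. A minimal sketch with made-up shapes (batch 4, sequence length 128, 125 token classes, 8 structure classes), assuming the encoder outputs follow the (batch, seq_len, classes) layout documented for `forward_encoder`:

import torch
import torch.nn.functional as F

b, s, n_tokens, n_ss = 4, 128, 125, 8
weight_mlm, weight_ssp = 2.0, 1.0

# hypothetical encoder outputs and integer class targets
y_mlm_hat = torch.randn(b, s, n_tokens)
y_ssp_hat = torch.randn(b, s, n_ss)
y_mlm = torch.randint(0, n_tokens, (b, s))
y_ssp = torch.randint(0, n_ss, (b, s))

# cross_entropy wants the class dimension in position 1, hence the transpose
loss_mlm = F.cross_entropy(y_mlm_hat.transpose(1, 2), y_mlm)
loss_ssp = F.cross_entropy(y_ssp_hat.transpose(1, 2), y_ssp)
loss = weight_mlm * loss_mlm + weight_ssp * loss_ssp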