Commit a571cb9
Add bidirectional model (#118)
Implement #116
willdumm authored Feb 25, 2025
1 parent 2194ae4 commit a571cb9
Showing 4 changed files with 200 additions and 3 deletions.
4 changes: 2 additions & 2 deletions netam/dxsm.py
@@ -32,7 +32,7 @@
    token_mask_of_aa_idxs,
    MAX_AA_TOKEN_IDX,
    RESERVED_TOKEN_REGEX,
    AA_AMBIG_IDX,
    AA_PADDING_TOKEN,
)


@@ -134,7 +134,7 @@ def of_seriess(
        # We have sequences of varying length, so we start with all tensors set
        # to the ambiguous amino acid, and then will fill in the actual values
        # below.
        aa_parents_idxss = torch.full((pcp_count, max_aa_seq_len), AA_AMBIG_IDX)
        aa_parents_idxss = torch.full((pcp_count, max_aa_seq_len), AA_PADDING_TOKEN)
        aa_children_idxss = aa_parents_idxss.clone()
        aa_subs_indicators = torch.zeros((pcp_count, max_aa_seq_len))

160 changes: 160 additions & 0 deletions netam/models.py
@@ -23,6 +23,7 @@
    aa_idx_tensor_of_str_ambig,
    PositionalEncoding,
    split_heavy_light_model_outputs,
    AA_PADDING_TOKEN,
)

from typing import Tuple
@@ -795,6 +796,165 @@ def predict(self, representation: Tensor):
        return wiggle(super().predict(representation), beta)


def reverse_padded_tensors(padded_tensors, padding_mask, padding_value, reversed_dim=1):
"""Reverse the valid values in provided padded_tensors along the specified
dimension, keeping padding in the same place. For example, if the input is left-
aligned amino acid sequences and masks, move the padding to the right of the
reversed sequence. Equivalent to right-aligning the sequences then reversing them. A
sequence `123456XXXXX` becomes `654321XXXXX`.
The original padding mask remains valid for the returned tensor.
Args:
padded_tensors: (B, L) tensor of amino acid indices
padding_mask: (B, L) tensor of masks, with True indicating valid values, and False indicating padding values.
padding_value: The value to fill returned tensor where padding_mask is False.
reversed_dim: The dimension along which to reverse the tensor. When input is a batch of sequences to be reversed, the default value of 1 is the correct choice.
Returns:
The reversed tensor, with the same shape as padded_tensors, and with padding still specified by padding_mask.
"""
    reversed_indices = torch.full_like(padded_tensors, padding_value)
    reversed_indices[padding_mask] = padded_tensors.flip(reversed_dim)[
        padding_mask.flip(reversed_dim)
    ]
    return reversed_indices
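
# A minimal usage sketch (editorial addition, not part of this commit), shown
# doctest-style with illustrative values and 0 as the padding value:
#
#   >>> batch = torch.tensor([[3, 7, 2, 0, 0], [5, 1, 0, 0, 0]])
#   >>> mask = batch > 0  # True marks valid positions, False marks padding
#   >>> reverse_padded_tensors(batch, mask, 0)
#   tensor([[2, 7, 3, 0, 0],
#           [1, 5, 0, 0, 0]])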


class BidirectionalTransformerBinarySelectionModel(AbstractBinarySelectionModel):
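    """Binary selection model that encodes each sequence with two transformer
    encoders, one reading the sequence forward and one reading it reversed, and
    combines the per-site representations with a linear layer."""
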
    def __init__(
        self,
        nhead: int,
        d_model_per_head: int,
        dim_feedforward: int,
        layer_count: int,
        dropout_prob: float = 0.5,
        output_dim: int = 1,
        known_token_count: int = MAX_AA_TOKEN_IDX + 1,
    ):
        super().__init__()
        self.known_token_count = known_token_count
        self.d_model_per_head = d_model_per_head
        self.d_model = d_model_per_head * nhead
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward
        # Forward direction components
        self.forward_pos_encoder = PositionalEncoding(self.d_model, dropout_prob)
        self.forward_amino_acid_embedding = nn.Embedding(
            self.known_token_count, self.d_model
        )
        self.forward_encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )
        self.forward_encoder = nn.TransformerEncoder(
            self.forward_encoder_layer, layer_count
        )

        # Reverse direction components
        self.reverse_pos_encoder = PositionalEncoding(self.d_model, dropout_prob)
        self.reverse_amino_acid_embedding = nn.Embedding(
            self.known_token_count, self.d_model
        )
        self.reverse_encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )
        self.reverse_encoder = nn.TransformerEncoder(
            self.reverse_encoder_layer, layer_count
        )

        # Output layers
        self.combine_features = nn.Linear(2 * self.d_model, self.d_model)
        self.output = nn.Linear(self.d_model, output_dim)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.combine_features.bias.data.zero_()
        self.combine_features.weight.data.uniform_(-initrange, initrange)
        self.output.bias.data.zero_()
        self.output.weight.data.uniform_(-initrange, initrange)

    def single_direction_represent_sequence(
        self,
        indices: Tensor,
        mask: Tensor,
        embedding: nn.Embedding,
        pos_encoder: PositionalEncoding,
        encoder: nn.TransformerEncoder,
    ) -> Tensor:
        """Process sequence through one direction of the model."""
        embedded = embedding(indices) * math.sqrt(self.d_model)
        embedded = pos_encoder(embedded.permute(1, 0, 2)).permute(1, 0, 2)
        return encoder(embedded, src_key_padding_mask=~mask)

    def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
        # This is okay as long as there are no masked ambiguities in the
        # interior of the sequence. Other than that, it should also work for
        # paired sequences.

        # Forward direction - normal processing
        forward_repr = self.single_direction_represent_sequence(
            amino_acid_indices,
            mask,
            self.forward_amino_acid_embedding,
            self.forward_pos_encoder,
            self.forward_encoder,
        )

        # Reverse direction - flip the sequences (the padding mask remains valid)
        reversed_indices = reverse_padded_tensors(
            amino_acid_indices, mask, AA_PADDING_TOKEN
        )

        reverse_repr = self.single_direction_represent_sequence(
            reversed_indices,
            mask,
            self.reverse_amino_acid_embedding,
            self.reverse_pos_encoder,
            self.reverse_encoder,
        )

        # Un-reverse to align with the forward representation
        aligned_reverse_repr = reverse_padded_tensors(reverse_repr, mask, 0.0)

        # Combine features from both directions
        combined = torch.cat([forward_repr, aligned_reverse_repr], dim=-1)
        return self.combine_features(combined)

    def predict(self, representation: Tensor) -> Tensor:
        # Output layer
        return self.output(representation).squeeze(-1)

    def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
        return self.predict(self.represent(amino_acid_indices, mask))

    @property
    def hyperparameters(self):
        return {
            "nhead": self.nhead,
            "d_model_per_head": self.d_model_per_head,
            "dim_feedforward": self.dim_feedforward,
            "layer_count": self.forward_encoder.num_layers,
            "dropout_prob": self.forward_pos_encoder.dropout.p,
            "output_dim": self.output.out_features,
            "known_token_count": self.known_token_count,
        }
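
# A minimal usage sketch (editorial addition, not part of this commit): the
# hyperparameter values and token indices below are illustrative only, and the
# snippet assumes the names already available in this file (torch,
# AA_PADDING_TOKEN) are in scope.
#
#   >>> model = BidirectionalTransformerBinarySelectionModel(
#   ...     nhead=2, d_model_per_head=4, dim_feedforward=16, layer_count=1
#   ... )
#   >>> indices = torch.full((2, 6), AA_PADDING_TOKEN)
#   >>> indices[0, :4] = torch.tensor([3, 1, 4, 1])  # length-4 sequence
#   >>> indices[1, :2] = torch.tensor([2, 7])        # length-2 sequence
#   >>> mask = indices != AA_PADDING_TOKEN
#   >>> model(indices, mask).shape                   # one value per site
#   torch.Size([2, 6])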


class BidirectionalTransformerBinarySelectionModelWiggleAct(
    BidirectionalTransformerBinarySelectionModel
):
    """Here the beta parameter is fixed at 0.3."""

    def predict(self, representation: Tensor):
        return wiggle(super().predict(representation), 0.3)


class SingleValueBinarySelectionModel(AbstractBinarySelectionModel):
"""A one parameter selection model as a baseline."""

3 changes: 3 additions & 0 deletions netam/sequences.py
@@ -22,6 +22,9 @@
NT_STR_SORTED = "".join(BASES)
BASES_AND_N_TO_INDEX = {base: idx for idx, base in enumerate(NT_STR_SORTED + "N")}
AA_AMBIG_IDX = len(AA_STR_SORTED)
# Used for padding amino acid sequences to the same length. Given its own name in
# case we later add a padding token other than AA_AMBIG_IDX.
AA_PADDING_TOKEN = AA_AMBIG_IDX

CODONS = ["".join(codon_list) for codon_list in itertools.product(BASES, repeat=3)]
STOP_CODONS = ["TAA", "TAG", "TGA"]
36 changes: 35 additions & 1 deletion tests/test_netam.py
@@ -8,7 +8,12 @@
import netam.framework as framework
from netam.common import BIG
from netam.framework import SHMoofDataset, SHMBurrito, RSSHMBurrito
from netam.models import SHMoofModel, RSSHMoofModel, IndepRSCNNModel
from netam.models import (
    SHMoofModel,
    RSSHMoofModel,
    IndepRSCNNModel,
    reverse_padded_tensors,
)


@pytest.fixture
@@ -114,3 +119,32 @@ def test_standardize_model_rates(mini_rsburrito):
    mini_rsburrito.standardize_model_rates()
    vrc01_rate_14 = mini_rsburrito.vrc01_site_14_model_rate()
    assert np.isclose(vrc01_rate_14, 1.0)


def test_reverse_padded_tensors():
    # Check the reversal against a hand-computed result, then verify that applying
    # the function twice returns the original input.
    test_tensor = torch.tensor(
        [
            [1, 2, 3, 4, 0, 0],
            [1, 2, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0],
            [1, 2, 3, 4, 5, 0],
            [1, 2, 3, 4, 5, 6],
            [1, 2, 0, 0, 0, 0],
        ]
    )
    true_reversed = torch.tensor(
        [
            [4, 3, 2, 1, 0, 0],
            [2, 1, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0],
            [5, 4, 3, 2, 1, 0],
            [6, 5, 4, 3, 2, 1],
            [2, 1, 0, 0, 0, 0],
        ]
    )
    mask = test_tensor > 0
    reversed_tensor = reverse_padded_tensors(test_tensor, mask, 0)
    assert torch.equal(true_reversed, reversed_tensor)
    assert torch.equal(test_tensor, reverse_padded_tensors(reversed_tensor, mask, 0))
