Add convenient Crepe.represent_sequences method (#117)
This addresses #115, adding `Crepe.represent_sequences` and a number of supporting methods on D*SM model classes.

It also eliminates the option of providing unpaired sequences to D*SM model methods that take amino acid strings (and therefore to all Crepe methods that call them). These methods now require amino acid sequences to be provided as `(heavy_chain, light_chain)` tuples, where a missing chain sequence is represented by the empty string.

`Crepe.represent_sequences` returns a single tensor for each heavy/light pair provided to it, while `Crepe.__call__` returns a pair of tensors (one for the heavy chain, one for the light chain) per pair. This seems to me the correct choice, but there could be justification for splitting the embedding tensors returned by `represent_sequences` at heavy/light boundaries.
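
For illustration, a minimal usage sketch of the new convention (the crepe path and the sequences below are hypothetical placeholders):

```python
from netam.framework import load_crepe

crepe = load_crepe("path/to/a_dxsm_crepe")  # hypothetical model path

pairs = [
    ("EVQLVESGGGLVQ", "DIQMTQSPSSLSA"),  # paired heavy and light chains
    ("EVQLVESGGGLVQ", ""),               # heavy chain only: empty light chain
]

# __call__ yields one (heavy, light) pair of tensors per input pair...
heavy_out, light_out = crepe(pairs)[0]

# ...while represent_sequences yields a single embedding tensor per pair.
embeddings = crepe.represent_sequences(pairs)
```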
willdumm authored Feb 14, 2025
1 parent 457d81f commit 954c28c
Showing 6 changed files with 114 additions and 83 deletions.
22 changes: 22 additions & 0 deletions netam/common.py
@@ -17,6 +17,28 @@
SMALL_PROB = 1e-6


def zap_predictions_along_diagonal(predictions, aa_parents_idxs, fill=-BIG):
"""Set the diagonal (i.e. no amino acid change) of the predictions tensor to -BIG,
except where aa_parents_idxs >= 20, which indicates no update should be done."""

device = predictions.device
batch_size, L, _ = predictions.shape
batch_indices = torch.arange(batch_size, device=device)[:, None].expand(-1, L)
sequence_indices = torch.arange(L, device=device)[None, :].expand(batch_size, -1)

# Create a mask for valid positions (where aa_parents_idxs is less than 20)
valid_mask = aa_parents_idxs < 20

# Only update the predictions for valid positions
predictions[
batch_indices[valid_mask],
sequence_indices[valid_mask],
aa_parents_idxs[valid_mask],
] = fill

return predictions


def combine_and_pad_tensors(first, second, padding_idxs, fill=float("nan")):
res = torch.full(
(first.shape[0] + second.shape[0] + len(padding_idxs),) + first.shape[1:], fill
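
As a toy illustration of the relocated `zap_predictions_along_diagonal` (shapes and values here are arbitrary):

```python
import torch
from netam.common import zap_predictions_along_diagonal

# One sequence of length 3 over a 21-token alphabet; index 20 marks an
# ambiguous/special token that should be left untouched.
predictions = torch.zeros(1, 3, 21)
aa_parents_idxs = torch.tensor([[0, 5, 20]])

zapped = zap_predictions_along_diagonal(predictions, aa_parents_idxs, fill=float("nan"))

# Parent-identity entries at sites 0 and 1 are filled; site 2 (idx >= 20) is not.
assert torch.isnan(zapped[0, 0, 0]) and torch.isnan(zapped[0, 1, 5])
assert zapped[0, 2, 20] == 0.0
```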
24 changes: 1 addition & 23 deletions netam/dxsm.py
@@ -16,7 +16,7 @@

from netam.common import (
stack_heterogeneous,
BIG,
zap_predictions_along_diagonal,
)
import netam.framework as framework
import netam.molevol as molevol
@@ -449,25 +449,3 @@ def worker_optimize_branch_length(burrito_class, model, dataset, optimization_kw
"""The worker used for parallel branch length optimization."""
burrito = burrito_class(None, dataset, copy.deepcopy(model))
return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)


def zap_predictions_along_diagonal(predictions, aa_parents_idxs, fill=-BIG):
"""Set the diagonal (i.e. no amino acid change) of the predictions tensor to -BIG,
except where aa_parents_idxs >= 20, which indicates no update should be done."""

device = predictions.device
batch_size, L, _ = predictions.shape
batch_indices = torch.arange(batch_size, device=device)[:, None].expand(-1, L)
sequence_indices = torch.arange(L, device=device)[None, :].expand(batch_size, -1)

# Create a mask for valid positions (where aa_parents_idxs is less than 20)
valid_mask = aa_parents_idxs < 20

# Only update the predictions for valid positions
predictions[
batch_indices[valid_mask],
sequence_indices[valid_mask],
aa_parents_idxs[valid_mask],
] = fill

return predictions
11 changes: 11 additions & 0 deletions netam/framework.py
@@ -269,6 +269,17 @@ def __call__(self, sequences, **kwargs):
sequences, encoder=self.encoder, **kwargs
)

def represent_sequences(self, sequences):
"""Represent a list of sequences in the model's embedding space.
This is implemented only for D*SM models.
"""
if isinstance(sequences, str):
raise ValueError(
"Expected a list of sequences for call on crepe, but got a single string instead."
)
return list(self.model.represent_aa_str(seq) for seq in sequences)

@property
def device(self):
return next(self.model.parameters()).device
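
Continuing the sketch above with the same hypothetical crepe, the guard against bare strings behaves as follows:

```python
# A list of (heavy, light) tuples is required...
embeddings = crepe.represent_sequences([("EVQLVESGG", "DIQMTQSPS")])

# ...while a single bare string is rejected.
try:
    crepe.represent_sequences("EVQLVESGG")
except ValueError as err:
    print(err)  # "Expected a list of sequences for call on crepe, ..."
```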
92 changes: 54 additions & 38 deletions netam/models.py
@@ -14,15 +14,15 @@
from netam.sequences import MAX_AA_TOKEN_IDX
from netam.common import (
chunk_function,
assume_single_sequence_is_heavy_chain,
zap_predictions_along_diagonal,
)
from netam.sequences import (
generate_kmers,
aa_mask_tensor_of,
encode_sequences,
aa_idx_tensor_of_str_ambig,
PositionalEncoding,
set_wt_to_nan,
split_heavy_light_model_outputs,
)

from typing import Tuple
@@ -87,6 +87,9 @@ def evaluate_sequences(self, sequences, encoder=None, chunk_size=2048):
outputs = self(encoded_parents, masks, wt_base_modifiers)
return tuple(t.detach().cpu() for t in outputs)

def represent_aa_str(self, *args, **kwargs):
raise NotImplementedError("represent_aa_str is implemented on D*SM models only")


class KmerModel(ModelBase):
def __init__(self, kmer_length):
@@ -576,33 +579,25 @@ def predictions_of_sequences(self, sequences, **kwargs):
predictions for wildtype amino acids are unconstrained in training and therefore
meaningless.
"""
res = self.evaluate_sequences(sequences, **kwargs)
if self.hyperparameters["output_dim"] >= 20:
return [set_wt_to_nan(pred, seq) for pred, seq in zip(res, sequences)]
else:
return res
return self.evaluate_sequences(sequences, **kwargs)

def evaluate_sequences(self, sequences: list[str], **kwargs) -> Tensor:
return tuple(self.selection_factors_of_aa_str(seq) for seq in sequences)
return list(self.selection_factors_of_aa_str(seq) for seq in sequences)

@assume_single_sequence_is_heavy_chain(1)
def selection_factors_of_aa_str(self, aa_sequence: Tuple[str, str]) -> Tensor:
"""Do the forward method then exponentiation without gradients from an amino
acid string.
Insertion of model tokens will be done automatically.
Args:
aa_str: A string of amino acids. If a string, we assume this is a light chain sequence.
Otherwise it should be a tuple, with the first element being the heavy chain and the second element being the light chain sequence.
def prepare_aa_str(self, heavy_chain, light_chain):
"""Prepare a pair of amino acid sequences for input to the model.
Returns:
A tuple of numpy arrays of the same length as the input strings representing
the level of selection for each amino acid at each site.
A tuple of two tensors, the first being the index-encoded parent amino acid
sequence and the second being the mask tensor.
Although both represent a single sequence, they include a per-sequence first
dimension for direct ingestion by the model.
"""

aa_str, added_indices = sequences.prepare_heavy_light_pair(
*aa_sequence, self.hyperparameters["known_token_count"], is_nt=False
heavy_chain,
light_chain,
self.hyperparameters["known_token_count"],
is_nt=False,
)
aa_idxs = aa_idx_tensor_of_str_ambig(aa_str)
if torch.any(aa_idxs >= self.hyperparameters["known_token_count"]):
@@ -614,25 +609,46 @@ def selection_factors_of_aa_str(self, aa_sequence: Tuple[str, str]) -> Tensor:
# test_common.py::test_compare_mask_tensors.
mask = aa_mask_tensor_of(aa_str)
mask = mask.to(self.device)
return aa_idxs.unsqueeze(0), mask.unsqueeze(0)

with torch.no_grad():
model_out = (
self(
aa_idxs.unsqueeze(0),
mask.unsqueeze(0),
)
.squeeze(0)
.exp()
def represent_aa_str(self, aa_sequence):
"""Call the forward method of the model on the provided heavy, light pair of AA
sequences."""
if isinstance(aa_sequence, str) or len(aa_sequence) != 2:
raise ValueError(
"aa_sequence must be a pair of strings, with the first element being the heavy chain sequence and the second element being the light chain sequence."
)
inputs = self.prepare_aa_str(*aa_sequence)
with torch.no_grad():
return self.represent(*inputs).squeeze(0)

# Now split into heavy and light chain results:
sequence_mask = torch.ones(len(model_out), dtype=bool)
if len(added_indices) > 0:
sequence_mask[torch.tensor(added_indices)] = False
masked_model_out = model_out[sequence_mask]
heavy_chain = masked_model_out[: len(aa_sequence[0])]
light_chain = masked_model_out[len(aa_sequence[0]) :]
return heavy_chain, light_chain
def selection_factors_of_aa_str(self, aa_sequence: Tuple[str, str]) -> Tensor:
"""Do the forward method then exponentiation without gradients from an amino
acid string.
Insertion of model tokens will be done automatically.
Args:
aa_sequence: A heavy, light chain pair of amino acid sequences.
Returns:
A pair of tensors of the same length as the input strings representing
the level of selection for each amino acid at each site.
"""
if isinstance(aa_sequence, str) or len(aa_sequence) != 2:
raise ValueError(
"aa_sequence must be a pair of strings, with the first element being the heavy chain sequence and the second element being the light chain sequence."
)
idx_seq, mask = self.prepare_aa_str(*aa_sequence)
with torch.no_grad():
result = torch.exp(self.forward(idx_seq, mask))
if self.hyperparameters["output_dim"] >= 20:
result = zap_predictions_along_diagonal(
result, idx_seq, fill=float("nan")
).squeeze(0)
else:
result = result.squeeze(0)
return split_heavy_light_model_outputs(result, idx_seq.squeeze(0))


class TransformerBinarySelectionModelLinAct(AbstractBinarySelectionModel):
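
To make the flow through these methods concrete, a hedged sketch (assuming `model` is a D*SM selection model instance; the sequences are placeholders):

```python
heavy, light = "EVQLVESGG", "DIQMTQSPS"

# prepare_aa_str tokenizes the pair and adds a leading batch dimension.
idx_seq, mask = model.prepare_aa_str(heavy, light)  # each of shape (1, L)

# represent_aa_str runs the embedding forward pass without gradients,
# squeezing away the batch dimension.
embedding = model.represent_aa_str((heavy, light))

# selection_factors_of_aa_str exponentiates the model output and splits it
# back into per-chain tensors at the heavy/light boundary.
heavy_sel, light_sel = model.selection_factors_of_aa_str((heavy, light))
```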
34 changes: 18 additions & 16 deletions netam/sequences.py
@@ -260,6 +260,24 @@ def heavy_light_mask_of_aa_idxs(aa_idxs):
return aa_idxs < AA_AMBIG_IDX


def split_heavy_light_model_outputs(result, aa_idxs):
"""Split a tensor whose first dimension corresponds to amino acid positions into
heavy chain and light chain components.
Args:
result: The tensor to split.
aa_idxs: The amino acid indices corresponding to the tensor, as presented to the model (including any special tokens).
Returns:
Tuple[torch.Tensor, torch.Tensor]: The heavy chain and light chain components of the input tensor.
"""
# Split into heavy and light chain results:
heavy_mask, light_mask = heavy_light_mask_of_aa_idxs(aa_idxs).values()
heavy_chain = result[heavy_mask]
light_chain = result[light_mask]
return heavy_chain, light_chain


def dataset_inputs_of_pcp_df(pcp_df, known_token_count):
parents = []
children = []
@@ -523,19 +541,3 @@ def aa_onehot_tensor_of_str(aa_str):
aa_indices_parent = aa_idx_array_of_str(aa_str)
aa_onehot[torch.arange(len(aa_str)), aa_indices_parent] = 1
return aa_onehot


# TODO is this not the same as zap_wt_predictions?
def set_wt_to_nan(predictions: torch.Tensor, aa_sequence: str) -> torch.Tensor:
"""Set the wild-type predictions to NaN.
Modifies the supplied predictions tensor in-place and returns it. For sites
containing special tokens, all predictions are set to NaN.
"""
wt_idxs = aa_idx_tensor_of_str(aa_sequence)
token_mask = wt_idxs < AA_AMBIG_IDX
predictions[token_mask][torch.arange(token_mask.sum()), wt_idxs[token_mask]] = (
float("nan")
)
predictions[~token_mask] = float("nan")
return predictions
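
The splitting logic in `split_heavy_light_model_outputs` is simple boolean masking; a self-contained illustration with hand-built masks (in netam they come from `heavy_light_mask_of_aa_idxs`, and the separator position here is illustrative):

```python
import torch

# 7 positions (3 heavy sites, a separator token at index 3, 3 light sites)
# with 4 model outputs per position; values are arbitrary.
result = torch.arange(7 * 4, dtype=torch.float).reshape(7, 4)
heavy_mask = torch.tensor([1, 1, 1, 0, 0, 0, 0], dtype=torch.bool)
light_mask = torch.tensor([0, 0, 0, 0, 1, 1, 1], dtype=torch.bool)

heavy_chain, light_chain = result[heavy_mask], result[light_mask]
assert heavy_chain.shape == (3, 4) and light_chain.shape == (3, 4)
```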
14 changes: 8 additions & 6 deletions tests/test_backward_compat.py
@@ -6,7 +6,7 @@
from tqdm import tqdm

from netam.framework import load_crepe
from netam.sequences import set_wt_to_nan
from netam.sequences import aa_idx_tensor_of_str_ambig


@pytest.fixture(scope="module")
@@ -68,14 +68,16 @@ def test_old_crepe_outputs():
dnsm_crepe = load_crepe("tests/old_models/dnsm_13k-v1jaffe+v1tang-joint")

ddsm_vals = torch.nan_to_num(
set_wt_to_nan(
torch.load("tests/old_models/ddsm_output", weights_only=True), example_seq
),
zap_predictions_along_diagonal(
torch.load("tests/old_models/ddsm_output", weights_only=True).unsqueeze(0),
aa_idx_tensor_of_str_ambig(example_seq).unsqueeze(0),
fill=float("nan"),
).squeeze(0),
0.0,
)
dnsm_vals = torch.load("tests/old_models/dnsm_output", weights_only=True)

ddsm_result = torch.nan_to_num(ddsm_crepe([example_seq])[0], 0.0)
dnsm_result = dnsm_crepe([example_seq])[0]
ddsm_result = torch.nan_to_num(ddsm_crepe([(example_seq, "")])[0][0], 0.0)
dnsm_result = dnsm_crepe([(example_seq, "")])[0][0]
assert torch.allclose(ddsm_result, ddsm_vals)
assert torch.allclose(dnsm_result, dnsm_vals)
