Skip to content

Commit

Permalink
Keep datasets on CPU (#120)
Browse files Browse the repository at this point in the history
Fixes #119 by keeping DXSM datasets stored on CPU, except when batches are used in training.
  • Loading branch information
willdumm authored Feb 22, 2025
1 parent d0f6bed commit 2194ae4
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 63 deletions.
26 changes: 0 additions & 26 deletions netam/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,32 +333,6 @@ def chunked(iterable, n):
yield chunk


def assume_single_sequence_is_heavy_chain(seq_arg_idx=0):
    """Wraps a function that takes a heavy/light sequence pair as its first argument and
    returns a tuple of results.

    The wrapped function will assume that if the first argument is a string, it is a
    heavy chain sequence, and in that case will return only the heavy chain result.
    """

    def decorator(function):
        @wraps(function)
        def wrapper(*args, **kwargs):
            sequence = args[seq_arg_idx]
            # Pair inputs pass straight through untouched.
            if not isinstance(sequence, str):
                return function(*args, **kwargs)
            # A bare string is treated as a heavy chain paired with an empty
            # light chain; only the heavy-chain half of the result is kept.
            patched_args = list(args)
            patched_args[seq_arg_idx] = (sequence, "")
            return function(*patched_args, **kwargs)[0]

        return wrapper

    return decorator


def heavy_chain_shim(paired_evaluator):
"""Returns a function that evaluates only heavy chains given a paired evaluator."""

Expand Down
15 changes: 3 additions & 12 deletions netam/dasm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Defining the deep natural selection model (DNSM)."""

import copy

import torch
import torch.nn.functional as F

Expand Down Expand Up @@ -72,13 +70,6 @@ def update_neutral_probs(self):
self.nt_cspss,
self._branch_lengths,
):
mask = mask.to("cpu")
nt_rates = nt_rates.to("cpu")
nt_csps = nt_csps.to("cpu")
if self.multihit_model is not None:
multihit_model = copy.deepcopy(self.multihit_model).to("cpu")
else:
multihit_model = None
# Note we are replacing all Ns with As, which means that we need to be careful
# with masking out these positions later. We do this below.
parent_idxs = nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
Expand All @@ -93,11 +84,11 @@ def update_neutral_probs(self):
parent_idxs.reshape(-1, 3),
mut_probs.reshape(-1, 3),
nt_csps.reshape(-1, 3, 4),
multihit_model=multihit_model,
multihit_model=self.multihit_model,
)

if not torch.isfinite(neutral_codon_probs).all():
print(f"Found a non-finite neutral_codon_prob")
print("Found a non-finite neutral_codon_prob")
print(f"nt_parent: {nt_parent}")
print(f"mask: {mask}")
print(f"nt_rates: {nt_rates}")
Expand Down Expand Up @@ -137,7 +128,7 @@ def __getitem__(self, idx):
"nt_csps": self.nt_cspss[idx],
}

def to(self, device):
def move_data_to_device(self, device):
self.codon_parents_idxss = self.codon_parents_idxss.to(device)
self.codon_children_idxss = self.codon_children_idxss.to(device)
self.aa_parents_idxss = self.aa_parents_idxss.to(device)
Expand Down
14 changes: 3 additions & 11 deletions netam/ddsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import netam.framework as framework
import netam.molevol as molevol
import netam.sequences as sequences
import copy
from typing import Tuple


Expand All @@ -24,13 +23,6 @@ def update_neutral_probs(self):
self.nt_cspss,
self._branch_lengths,
):
mask = mask.to("cpu")
nt_rates = nt_rates.to("cpu")
nt_csps = nt_csps.to("cpu")
if self.multihit_model is not None:
multihit_model = copy.deepcopy(self.multihit_model).to("cpu")
else:
multihit_model = None
# Note we are replacing all Ns with As, which means that we need to be careful
# with masking out these positions later. We do this below.
parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
Expand All @@ -45,11 +37,11 @@ def update_neutral_probs(self):
parent_idxs.reshape(-1, 3),
mut_probs.reshape(-1, 3),
nt_csps.reshape(-1, 3, 4),
multihit_model=multihit_model,
multihit_model=self.multihit_model,
)

if not torch.isfinite(neutral_aa_probs).all():
print(f"Found a non-finite neutral_aa_probs")
print("Found a non-finite neutral_aa_probs")
print(f"nt_parent: {nt_parent}")
print(f"mask: {mask}")
print(f"nt_rates: {nt_rates}")
Expand Down Expand Up @@ -85,7 +77,7 @@ def __getitem__(self, idx):
"nt_csps": self.nt_cspss[idx],
}

def to(self, device):
def move_data_to_device(self, device):
self.aa_parents_idxss = self.aa_parents_idxss.to(device)
self.aa_children_idxss = self.aa_children_idxss.to(device)
self.aa_subs_indicators = self.aa_subs_indicators.to(device)
Expand Down
15 changes: 3 additions & 12 deletions netam/dnsm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Defining the deep natural selection model (DNSM)."""

import copy

import torch
import torch.nn.functional as F

Expand Down Expand Up @@ -37,13 +35,6 @@ def update_neutral_probs(self):
self.nt_cspss,
self._branch_lengths,
):
mask = mask.to("cpu")
nt_rates = nt_rates.to("cpu")
nt_csps = nt_csps.to("cpu")
if self.multihit_model is not None:
multihit_model = copy.deepcopy(self.multihit_model).to("cpu")
else:
multihit_model = None
# Note we are replacing all Ns with As, which means that we need to be careful
# with masking out these positions later. We do this below.
parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
Expand All @@ -60,11 +51,11 @@ def update_neutral_probs(self):
parent_idxs.reshape(-1, 3),
mut_probs.reshape(-1, 3),
nt_csps.reshape(-1, 3, 4),
multihit_model=multihit_model,
multihit_model=self.multihit_model,
)

if not torch.isfinite(neutral_aa_mut_probs).all():
print(f"Found a non-finite neutral_aa_mut_prob")
print("Found a non-finite neutral_aa_mut_prob")
print(f"nt_parent: {nt_parent}")
print(f"mask: {mask}")
print(f"nt_rates: {nt_rates}")
Expand Down Expand Up @@ -101,7 +92,7 @@ def __getitem__(self, idx):
"nt_csps": self.nt_cspss[idx],
}

def to(self, device):
def move_data_to_device(self, device):
self.aa_parents_idxss = self.aa_parents_idxss.to(device)
self.aa_subs_indicators = self.aa_subs_indicators.to(device)
self.masks = self.masks.to(device)
Expand Down
14 changes: 13 additions & 1 deletion netam/dxsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ def __init__(
self.masks = masks
self.aa_subs_indicators = aa_subs_indicators
self.multihit_model = copy.deepcopy(multihit_model)
if multihit_model is not None:
if self.multihit_model is not None:
# We want these parameters to act like fixed data. This is essential
# for multithreaded branch length optimization to work.
self.multihit_model = self.multihit_model.to("cpu")
self.multihit_model.values.requires_grad_(False)

assert len(self.nt_parents) == len(self.nt_children)
Expand All @@ -84,6 +85,14 @@ def __init__(
self._branch_lengths = branch_lengths
self.update_neutral_probs()

    def __post_init__(self):
        # Keep all dataset tensors on CPU by default; batches are moved to the
        # model's device only when consumed (see move_data_to_device).
        # NOTE(review): __post_init__ is a dataclass hook, but no @dataclass
        # decorator is visible here — confirm this method is actually invoked.
        self.move_data_to_device("cpu")

    @abstractmethod
    def move_data_to_device(self, device):
        """Move all tensors stored by the dataset to the given device.

        Subclasses transfer each tensor attribute they store (index tensors,
        masks, indicators, etc.) to ``device`` in place.
        """
        pass

@classmethod
def of_seriess(
cls,
Expand Down Expand Up @@ -284,6 +293,9 @@ def branch_lengths(self, new_branch_lengths):
self._branch_lengths = new_branch_lengths
self.update_neutral_probs()

    def to(self, device):
        # Record the target device without moving stored tensors; the dataset
        # deliberately stays on CPU, and batches are transferred individually.
        # NOTE(review): presumably self.device is read when batches are built —
        # confirm against the training loop.
        self.device = device

@abstractmethod
def update_neutral_probs(self):
pass
Expand Down
2 changes: 1 addition & 1 deletion netam/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,8 +768,8 @@ def standardize_and_optimize_branch_lengths(self, **optimization_kwargs):
dataset.branch_lengths = self.find_optimal_branch_lengths(
dataset, **optimization_kwargs
)
dataset.to(device)
self.model.to(device)
dataset.to(device)

def standardize_and_use_yun_approx_branch_lengths(self):
"""Yun Song's approximation to the branch lengths.
Expand Down

0 comments on commit 2194ae4

Please sign in to comment.