From 90103f4b2c654ecf8e76e85e0f0f10b984884c24 Mon Sep 17 00:00:00 2001 From: Danny Date: Tue, 24 Dec 2024 08:56:01 -0800 Subject: [PATCH] format fix --- .secrets.baseline | 4 +- LICENSE/third_party.txt | 2 +- VERSION | 2 +- sub-packages/bionemo-moco/documentation.md | 2 - .../scripts/clean_documentation.py | 16 +++ .../prior/continuous/gaussian.py | 3 +- .../distributions/prior/continuous/utils.py | 5 +- .../distributions/prior/discrete/custom.py | 3 +- .../moco/distributions/prior/discrete/mask.py | 3 +- .../distributions/prior/discrete/uniform.py | 3 +- .../bionemo/moco/distributions/time/beta.py | 1 + .../moco/distributions/time/logit_normal.py | 1 + .../moco/distributions/time/uniform.py | 1 + .../moco/interpolants/base_interpolant.py | 7 +- .../continuous/continuous_flow_matching.py | 3 +- .../continuous_time/continuous/vdm.py | 3 +- .../discrete/discrete_flow_matching.py | 3 +- .../continuous_time/discrete/mdlm.py | 3 +- .../discrete_time/continuous/ddpm.py | 3 +- .../discrete_time/discrete/d3pm.py | 6 +- .../schedules/discrete_noise_schedules.py | 10 +- .../schedules/inference_time_schedules.py | 112 +++++++++++++----- .../noise/continuous_noise_transforms.py | 3 +- .../noise/continuous_snr_transforms.py | 3 +- .../noise/discrete_noise_schedules.py | 15 +-- .../prior/continuous/test_gaussian.py | 1 + .../prior/discrete/test_custom.py | 1 + .../distributions/prior/discrete/test_mask.py | 1 + .../prior/discrete/test_uniform.py | 1 + .../time/test_time_distribution.py | 1 + .../test_continuous_flow_matching.py | 1 + .../continuous/test_optimal_transport.py | 2 +- .../continuous_time/continuous/test_vdm.py | 1 + .../discrete/test_discrete_flow_matching.py | 1 + .../continuous_time/discrete/test_mdlm.py | 1 + .../discrete_time/continuous/test_ddpm.py | 1 + .../discrete_time/discrete/test_d3pm.py | 1 + .../noise/test_discrete_noise_schedule.py | 1 + .../test_continuous_noise_transforms.py | 1 + .../test_continuous_snr_transforms.py | 1 + .../test_discrete_noise_schedules.py | 1 + ...chedule.py => test_inference_schedules.py} | 33 +++++- 42 files changed, 192 insertions(+), 75 deletions(-) rename sub-packages/bionemo-moco/tests/bionemo/moco/schedules/{test_infernece_schedule.py => test_inference_schedules.py} (79%) diff --git a/.secrets.baseline b/.secrets.baseline index c0824519f8..2247dd9f8c 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -139,9 +139,9 @@ "filename": "pyproject.toml", "hashed_secret": "79670e9c9d1c7ea5b81a96a2053d81437712c78e", "is_verified": false, - "line_number": 47 + "line_number": 48 } ] }, - "generated_at": "2024-11-01T22:26:03Z" + "generated_at": "2024-12-24T16:57:17Z" } diff --git a/LICENSE/third_party.txt b/LICENSE/third_party.txt index 3270452836..16e5e85511 100644 --- a/LICENSE/third_party.txt +++ b/LICENSE/third_party.txt @@ -844,4 +844,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file +SOFTWARE. diff --git a/VERSION b/VERSION index 9f8e9b69a3..d3827e75a5 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.0 \ No newline at end of file +1.0 diff --git a/sub-packages/bionemo-moco/documentation.md b/sub-packages/bionemo-moco/documentation.md index 6df7651f7a..f886a96f9c 100644 --- a/sub-packages/bionemo-moco/documentation.md +++ b/sub-packages/bionemo-moco/documentation.md @@ -1,3 +1 @@ # Table of Contents - - diff --git a/sub-packages/bionemo-moco/scripts/clean_documentation.py b/sub-packages/bionemo-moco/scripts/clean_documentation.py index 79917557e0..6a4ecc76ec 100644 --- a/sub-packages/bionemo-moco/scripts/clean_documentation.py +++ b/sub-packages/bionemo-moco/scripts/clean_documentation.py @@ -1,3 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import re diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/continuous/gaussian.py b/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/continuous/gaussian.py index 0d02bcdfe7..7b274ec2df 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/continuous/gaussian.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/continuous/gaussian.py @@ -18,9 +18,10 @@ import torch from jaxtyping import Bool, Float +from torch import Tensor + from bionemo.moco.distributions.prior.continuous.utils import remove_center_of_mass from bionemo.moco.distributions.prior.distribution import PriorDistribution -from torch import Tensor class GaussianPrior(PriorDistribution): diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/continuous/utils.py b/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/continuous/utils.py index 3e20eceada..9a9bf9dfe5 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/continuous/utils.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/continuous/utils.py @@ -16,13 +16,10 @@ from typing import Optional -from jaxtyping import Float from torch import Tensor -def remove_center_of_mass( - data: Float[Tensor, "* N D"], mask: Optional[Float[Tensor, "* N D"]] = None -) -> Float[Tensor, "* N D"]: +def remove_center_of_mass(data: Tensor, mask: Optional[Tensor] = None) -> Tensor: """Calculates the center of mass (CoM) of the given data. Args: diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/discrete/custom.py b/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/discrete/custom.py index 9f8f38461e..6b2f040b77 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/discrete/custom.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/discrete/custom.py @@ -18,9 +18,10 @@ from typing import Optional, Tuple, Union import torch -from bionemo.moco.distributions.prior.distribution import DiscretePriorDistribution from torch import Tensor +from bionemo.moco.distributions.prior.distribution import DiscretePriorDistribution + class DiscreteCustomPrior(DiscretePriorDistribution): """A subclass representing a discrete custom prior distribution. diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/discrete/mask.py b/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/discrete/mask.py index 141d94b89d..cd8031cc99 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/discrete/mask.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/discrete/mask.py @@ -17,9 +17,10 @@ from typing import Optional, Tuple, Union import torch -from bionemo.moco.distributions.prior.distribution import DiscretePriorDistribution from torch import Tensor +from bionemo.moco.distributions.prior.distribution import DiscretePriorDistribution + class DiscreteMaskedPrior(DiscretePriorDistribution): """A subclass representing a Discrete Masked prior distribution.""" diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/discrete/uniform.py b/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/discrete/uniform.py index 234f9a328b..53a71b601e 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/discrete/uniform.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/distributions/prior/discrete/uniform.py @@ -17,9 +17,10 @@ from typing import Optional, Tuple, Union import torch -from bionemo.moco.distributions.prior.distribution import DiscretePriorDistribution from torch import Tensor +from bionemo.moco.distributions.prior.distribution import DiscretePriorDistribution + class DiscreteUniformPrior(DiscretePriorDistribution): """A subclass representing a discrete uniform prior distribution.""" diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/distributions/time/beta.py b/sub-packages/bionemo-moco/src/bionemo/moco/distributions/time/beta.py index 36485c4438..e33d334670 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/distributions/time/beta.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/distributions/time/beta.py @@ -18,6 +18,7 @@ import torch from jaxtyping import Bool, Float + from bionemo.moco.distributions.time.distribution import TimeDistribution from bionemo.moco.distributions.time.utils import float_time_to_index diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/distributions/time/logit_normal.py b/sub-packages/bionemo-moco/src/bionemo/moco/distributions/time/logit_normal.py index 0da369cf61..93267aae83 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/distributions/time/logit_normal.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/distributions/time/logit_normal.py @@ -18,6 +18,7 @@ import torch from jaxtyping import Bool, Float + from bionemo.moco.distributions.time.distribution import TimeDistribution from bionemo.moco.distributions.time.utils import float_time_to_index diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/distributions/time/uniform.py b/sub-packages/bionemo-moco/src/bionemo/moco/distributions/time/uniform.py index be944e547a..dcbeae6623 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/distributions/time/uniform.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/distributions/time/uniform.py @@ -18,6 +18,7 @@ import torch from jaxtyping import Bool, Float + from bionemo.moco.distributions.time.distribution import TimeDistribution diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/base_interpolant.py b/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/base_interpolant.py index a3cc053c34..3899e9deda 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/base_interpolant.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/base_interpolant.py @@ -19,9 +19,8 @@ from typing import Optional, Type, TypeVar, Union import torch +from jaxtyping import Bool from torch import Tensor -from jaxtyping import Bool, Float - from bionemo.moco.distributions.prior.distribution import PriorDistribution from bionemo.moco.distributions.time.distribution import TimeDistribution @@ -219,9 +218,7 @@ def to_device(self, device: str): setattr(self, attr_name, getattr(self, attr_name).to(device)) return self - def clean_mask_center( - self, data: Float[Tensor, "* N D"], mask: Optional[Float[Tensor, "* N"]] = None, center: Bool = False - ) -> Float[Tensor, "* N D"]: + def clean_mask_center(self, data: Tensor, mask: Optional[Tensor] = None, center: Bool = False) -> Tensor: """Returns a clean tensor that has been masked and/or centered based on the function arguments. Args: diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/continuous_time/continuous/continuous_flow_matching.py b/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/continuous_time/continuous/continuous_flow_matching.py index 1751749bb9..01c56eb1e9 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/continuous_time/continuous/continuous_flow_matching.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/continuous_time/continuous/continuous_flow_matching.py @@ -20,12 +20,13 @@ import torch import torch.nn as nn from jaxtyping import Bool, Float +from torch import Tensor + from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior from bionemo.moco.distributions.prior.distribution import PriorDistribution from bionemo.moco.distributions.time.distribution import TimeDistribution from bionemo.moco.interpolants.base_interpolant import Interpolant, PredictionType, pad_like, string_to_enum from bionemo.moco.interpolants.continuous_time.continuous.optimal_transport import OTSampler -from torch import Tensor class OptimalTransportType(Enum): diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/continuous_time/continuous/vdm.py b/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/continuous_time/continuous/vdm.py index 350fd07ccc..9badcde21c 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/continuous_time/continuous/vdm.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/continuous_time/continuous/vdm.py @@ -21,12 +21,13 @@ import torch.nn as nn import torch.nn.functional as F from jaxtyping import Bool, Float +from torch import Tensor + from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior from bionemo.moco.distributions.prior.distribution import PriorDistribution from bionemo.moco.distributions.time.distribution import TimeDistribution from bionemo.moco.interpolants.base_interpolant import Interpolant, PredictionType, pad_like, string_to_enum from bionemo.moco.schedules.noise.continuous_snr_transforms import ContinuousSNRTransform -from torch import Tensor class VDM(Interpolant): diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/continuous_time/discrete/discrete_flow_matching.py b/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/continuous_time/discrete/discrete_flow_matching.py index 23ded0fb8f..94d2e7671c 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/continuous_time/discrete/discrete_flow_matching.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/continuous_time/discrete/discrete_flow_matching.py @@ -20,11 +20,12 @@ import torch.nn as nn import torch.nn.functional as F from jaxtyping import Bool, Float +from torch import Tensor + from bionemo.moco.distributions.prior.discrete.mask import DiscreteMaskedPrior from bionemo.moco.distributions.prior.distribution import DiscretePriorDistribution from bionemo.moco.distributions.time.distribution import TimeDistribution from bionemo.moco.interpolants.base_interpolant import Interpolant, pad_like -from torch import Tensor class DiscreteFlowMatcher(Interpolant): diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/continuous_time/discrete/mdlm.py b/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/continuous_time/discrete/mdlm.py index 4aaa998cd2..07d109b3a2 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/continuous_time/discrete/mdlm.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/continuous_time/discrete/mdlm.py @@ -17,11 +17,12 @@ from typing import Optional import torch +from torch import Tensor + from bionemo.moco.distributions.prior.discrete.mask import DiscreteMaskedPrior from bionemo.moco.distributions.time.distribution import TimeDistribution from bionemo.moco.interpolants.base_interpolant import Interpolant, pad_like from bionemo.moco.schedules.noise.continuous_noise_transforms import ContinuousExpNoiseTransform -from torch import Tensor class MDLM(Interpolant): diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/discrete_time/continuous/ddpm.py b/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/discrete_time/continuous/ddpm.py index 8891368126..990c37da65 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/discrete_time/continuous/ddpm.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/discrete_time/continuous/ddpm.py @@ -18,9 +18,10 @@ from typing import Optional, Union import torch -from torch import Tensor import torch.nn as nn from jaxtyping import Bool, Float +from torch import Tensor + from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior from bionemo.moco.distributions.prior.distribution import PriorDistribution from bionemo.moco.distributions.time.distribution import TimeDistribution diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/discrete_time/discrete/d3pm.py b/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/discrete_time/discrete/d3pm.py index b91b23f3a8..f10690de8b 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/discrete_time/discrete/d3pm.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/interpolants/discrete_time/discrete/d3pm.py @@ -17,10 +17,11 @@ from typing import Optional, Tuple import torch -from torch import Tensor import torch.nn as nn import torch.nn.functional as F from jaxtyping import Float +from torch import Tensor + from bionemo.moco.distributions.prior.distribution import DiscretePriorDistribution from bionemo.moco.distributions.time.distribution import TimeDistribution from bionemo.moco.interpolants.base_interpolant import Interpolant @@ -28,7 +29,6 @@ from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteNoiseSchedule - class D3PM(Interpolant): """A Discrete Denoising Diffusion Probabilistic Model (D3PM) interpolant.""" @@ -330,7 +330,7 @@ def loss( if aggregate: loss = loss.mean() if vb_scale > 0: - loss += vb_loss.mean() # type: ignore + loss += vb_loss.mean() # type: ignore return loss return loss, vb_loss diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/schedules/discrete_noise_schedules.py b/sub-packages/bionemo-moco/src/bionemo/moco/schedules/discrete_noise_schedules.py index fd628aec29..6775997cbd 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/schedules/discrete_noise_schedules.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/schedules/discrete_noise_schedules.py @@ -18,8 +18,8 @@ from typing import Optional, Union import torch -from torch import Tensor from jaxtyping import Float +from torch import Tensor from bionemo.moco.interpolants.base_interpolant import string_to_enum from bionemo.moco.schedules.utils import TimeDirection @@ -64,9 +64,7 @@ def generate_schedule( return schedule @abstractmethod - def _generate_schedule( - self, nsteps: Optional[int] = None, device: Union[str, torch.device] = "cpu" - ) -> Tensor: + def _generate_schedule(self, nsteps: Optional[int] = None, device: Union[str, torch.device] = "cpu") -> Tensor: """Generate the time schedule as a list. Args: @@ -115,9 +113,7 @@ def __init__(self, nsteps: int, nu: Float = 1.0, s: Float = 0.008): self.nu = nu self.s = s - def _generate_schedule( - self, nsteps: Optional[int] = None, device: Union[str, torch.device] = "cpu" - ) -> Tensor: + def _generate_schedule(self, nsteps: Optional[int] = None, device: Union[str, torch.device] = "cpu") -> Tensor: """Generate the cosine noise schedule. Args: diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/schedules/inference_time_schedules.py b/sub-packages/bionemo-moco/src/bionemo/moco/schedules/inference_time_schedules.py index 4cb4206d7c..34c647e48c 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/schedules/inference_time_schedules.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/schedules/inference_time_schedules.py @@ -18,8 +18,8 @@ from typing import Optional, Union import torch +from jaxtyping import Float from torch import Tensor -from jaxtyping import Bool, Float from bionemo.moco.interpolants.base_interpolant import string_to_enum from bionemo.moco.schedules.utils import TimeDirection @@ -32,6 +32,8 @@ def __init__( self, nsteps: int, min_t: Float = 0, + padding: Float = 0, + dilation: Float = 0, direction: TimeDirection = TimeDirection.UNIFIED, device: Union[str, torch.device] = "cpu", ): @@ -40,24 +42,27 @@ def __init__( Args: nsteps (int): Number of time steps. min_t (Float): minimum time value defaults to 0. + padding (Float): padding time value defaults to 0. + dilation (Float): dilation time value defaults to 0 ie the number of replicates. direction (Optional[str]): TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None). device (Optional[str]): Device to place the schedule on (default is "cpu"). """ self.nsteps = nsteps self.min_t = min_t + self.padding = padding + self.dilation = dilation self.direction = string_to_enum(direction, TimeDirection) self.device = device @abstractmethod def generate_schedule( - self, nsteps: Optional[int] = None, full: Bool = False, device: Optional[Union[str, torch.device]] = None + self, nsteps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None ) -> Tensor: """Generate the time schedule as a tensor. Args: nsteps (Optioanl[int]): Number of time steps. If None, uses the value from initialization. - full (Bool): Boolean to return entire schedule or just the needed components. device (Optional[str]): Device to place the schedule on (default is "cpu"). """ pass @@ -85,10 +90,12 @@ def discretize( if device is None: device = self.device if schedule is None: - schedule = self.generate_schedule(nsteps, full=True, device=device) + schedule = self.generate_schedule(nsteps, device=device) if self.direction == TimeDirection.UNIFIED: + schedule = torch.cat((schedule, torch.ones((1,), device=schedule.device))) dt = schedule[1:] - schedule[:-1] else: + schedule = torch.cat((schedule, torch.zeros((1,), device=schedule.device))) dt = -1 * (schedule[1:] - schedule[:-1]) return dt @@ -110,6 +117,8 @@ def discretize( Returns: Tensor: A tensor of time deltas. """ + if self.padding > 0 or self.dilation > 0: + raise NotImplementedError("discreteize is not implemented for discrete schedules with padding or dilation") if device is None: device = self.device return torch.full( @@ -126,6 +135,8 @@ def __init__( self, nsteps: int, min_t: Float = 0, + padding: Float = 0, + dilation: Float = 0, direction: TimeDirection = TimeDirection.UNIFIED, device: Union[str, torch.device] = "cpu", ): @@ -134,33 +145,43 @@ def __init__( Args: nsteps (int): Number of time steps. min_t (Float): minimum time value defaults to 0. + padding (Float): padding time value defaults to 0. + dilation (Float): dilation time value defaults to 0 ie the number of replicates. direction (Optional[str]): TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None). device (Optional[str]): Device to place the schedule on (default is "cpu"). """ - super().__init__(nsteps, min_t, direction, device) + super().__init__(nsteps, min_t, padding, dilation, direction, device) def generate_schedule( - self, nsteps: Optional[int] = None, full: Bool = False, device: Optional[Union[str, torch.device]] = None + self, nsteps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None ) -> Tensor: """Generate the linear time schedule as a tensor. Args: nsteps (Optional[int]): Number of time steps. If None uses the value from initialization. - full (Bool): Whether to return the full scheudle defaults to False. device (Optional[str]): Device to place the schedule on (default is "cpu"). Returns: Tensor: A tensor of time steps. + Tensor: A tensor of time steps. """ if device is None: device = self.device if nsteps is None: nsteps = self.nsteps - if full: - nsteps += 1 + nsteps -= self.padding + dilation = self.dilation + 1 + if dilation > 1: + if nsteps % dilation != 0: + raise ValueError(f"nsteps ({nsteps}) is not divisible by dilation + 1 ({dilation})") + nsteps = int(nsteps / self.dilation) schedule = torch.arange(nsteps).to(device=device) + if dilation > 1: + schedule = schedule.repeat_interleave(dilation) if self.direction == TimeDirection.DIFFUSION: schedule = schedule.flip(0) + if self.padding > 0: + schedule = torch.cat((schedule, schedule[-1] * torch.ones(self.padding, device=device))) return schedule @@ -171,6 +192,8 @@ def __init__( self, nsteps: int, min_t: Float = 0, + padding: Float = 0, + dilation: Float = 0, direction: TimeDirection = TimeDirection.UNIFIED, device: Union[str, torch.device] = "cpu", ): @@ -179,39 +202,51 @@ def __init__( Args: nsteps (int): Number of time steps. min_t (Float): minimum time value defaults to 0. + padding (Float): padding time value defaults to 0. + dilation (Float): dilation time value defaults to 0 ie the number of replicates. direction (Optional[str]): TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None). device (Optional[str]): Device to place the schedule on (default is "cpu"). """ - super().__init__(nsteps, min_t, direction, device) + super().__init__(nsteps, min_t, padding, dilation, direction, device) def generate_schedule( self, nsteps: Optional[int] = None, - full: Bool = False, device: Optional[Union[str, torch.device]] = None, ) -> Tensor: """Generate the linear time schedule as a tensor. Args: nsteps (Optional[int]): Number of time steps. If None uses the value from initialization. - full (Bool): Whether to return the full scheudle defaults to False. device (Optional[str]): Device to place the schedule on (default is "cpu"). Returns: Tensor: A tensor of time steps. + Tensor: A tensor of time steps. """ if device is None: device = self.device if nsteps is None: nsteps = self.nsteps + nsteps -= self.padding + dilation = self.dilation + 1 + if dilation > 1: + if nsteps % dilation != 0: + raise ValueError(f"nsteps ({nsteps}) is not divisible by dilation + 1 ({dilation})") + nsteps = int(nsteps / dilation) schedule = torch.linspace(0, 1, nsteps + 1).to(device=device) + if self.min_t > 0: schedule = torch.clamp(schedule, min=self.min_t) + + schedule = schedule[:-1] + if dilation > 1: + schedule = schedule.repeat_interleave(dilation) + if self.padding > 0: + schedule = torch.cat((schedule, torch.ones(self.padding, device=device))) if self.direction == TimeDirection.DIFFUSION: schedule = 1 - schedule # schedule.flip(0) - if full: - return schedule - return schedule[:-1] + return schedule class PowerInferenceSchedule(ContinuousInferenceSchedule): @@ -221,6 +256,8 @@ def __init__( self, nsteps: int, min_t: Float = 0, + padding: Float = 0, + dilation: Float = 0, p1: Float = 1.0, direction: TimeDirection = TimeDirection.UNIFIED, device: Union[str, torch.device] = "cpu", @@ -230,41 +267,51 @@ def __init__( Args: nsteps (int): Number of time steps. min_t (Float): minimum time value defaults to 0. + padding (Float): padding time value defaults to 0. + dilation (Float): dilation time value defaults to 0 ie the number of replicates. p1 (Float): Power parameter defaults to 1.0. direction (Optional[str]): TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None). device (Optional[str]): Device to place the schedule on (default is "cpu"). """ - super().__init__(nsteps, min_t, direction, device) + super().__init__(nsteps, min_t, padding, dilation, direction, device) self.p1 = p1 def generate_schedule( self, nsteps: Optional[int] = None, - full: Bool = False, device: Optional[Union[str, torch.device]] = None, ) -> Tensor: """Generate the power time schedule as a tensor. Args: nsteps (Optional[int]): Number of time steps. If None uses the value from initialization. - full (Bool): Whether to return the full scheudle defaults to False. device (Optional[str]): Device to place the schedule on (default is "cpu"). Returns: Tensor: A tensor of time steps. + Tensor: A tensor of time steps. """ if device is None: device = self.device if nsteps is None: nsteps = self.nsteps + nsteps -= self.padding + dilation = self.dilation + 1 + if dilation > 1: + if nsteps % dilation != 0: + raise ValueError(f"nsteps ({nsteps}) is not divisible by dilation + 1 ({dilation})") + nsteps = int(nsteps / dilation) schedule = torch.linspace(0, 1, nsteps + 1).to(device=device) ** self.p1 if self.min_t > 0: schedule = torch.clamp(schedule, min=self.min_t) + schedule = schedule[:-1] + if dilation > 1: + schedule = schedule.repeat_interleave(dilation) + if self.padding > 0: + schedule = torch.cat((schedule, torch.ones(self.padding, device=device))) if self.direction == TimeDirection.DIFFUSION: schedule = 1 - schedule # schedule.flip(0) - if full: - return schedule - return schedule[:-1] + return schedule class LogInferenceSchedule(ContinuousInferenceSchedule): @@ -274,6 +321,8 @@ def __init__( self, nsteps: int, min_t: Float = 0, + padding: Float = 0, + dilation: Float = 0, p1: Float = 2, direction: TimeDirection = TimeDirection.UNIFIED, device: Union[str, torch.device] = "cpu", @@ -283,11 +332,13 @@ def __init__( Args: nsteps (int): Number of time steps. min_t (Float): minimum time value defaults to 0. + padding (Float): padding time value defaults to 0. + dilation (Float): dilation time value defaults to 0 ie the number of replicates. p1 (Float): log space parameter defaults to 2.0. direction (Optional[str]): TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None). device (Optional[str]): Device to place the schedule on (default is "cpu"). """ - super().__init__(nsteps, min_t, direction, device) + super().__init__(nsteps, min_t, padding, dilation, direction, device) if p1 is None: raise ValueError("p1 cannot be None for the log schedule") if p1 <= 0: @@ -297,27 +348,34 @@ def __init__( def generate_schedule( self, nsteps: Optional[int] = None, - full: Bool = False, device: Optional[Union[str, torch.device]] = None, ) -> Tensor: """Generate the log time schedule as a tensor. Args: nsteps (Optional[int]): Number of time steps. If None uses the value from initialization. - full (Bool): Whether to return the full scheudle defaults to False. device (Optional[str]): Device to place the schedule on (default is "cpu"). """ if device is None: device = self.device if nsteps is None: nsteps = self.nsteps + nsteps -= self.padding + dilation = self.dilation + 1 + if dilation > 1: + if nsteps % dilation != 0: + raise ValueError(f"nsteps ({nsteps}) is not divisible by dilation + 1 ({dilation})") + nsteps = int(nsteps / self.dilation) t = 1.0 - torch.logspace(-self.p1, 0, nsteps + 1).flip(0).to(device=device) t = t - torch.min(t) schedule = t / torch.max(t) if self.min_t > 0: schedule = torch.clamp(schedule, min=self.min_t) + schedule = schedule[:-1] + if dilation > 1: + schedule = schedule.repeat_interleave(dilation) + if self.padding > 0: + schedule = torch.cat((schedule, torch.ones(self.padding, device=device))) if self.direction == TimeDirection.DIFFUSION: schedule = 1 - schedule # schedule.flip(0) - if full: - return schedule - return schedule[:-1] + return schedule diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/schedules/noise/continuous_noise_transforms.py b/sub-packages/bionemo-moco/src/bionemo/moco/schedules/noise/continuous_noise_transforms.py index 5f9b9dd077..d40320e722 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/schedules/noise/continuous_noise_transforms.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/schedules/noise/continuous_noise_transforms.py @@ -18,8 +18,9 @@ from typing import Optional, Union import torch -from torch import Tensor from jaxtyping import Float +from torch import Tensor + from bionemo.moco.interpolants.base_interpolant import string_to_enum from bionemo.moco.schedules.utils import TimeDirection diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/schedules/noise/continuous_snr_transforms.py b/sub-packages/bionemo-moco/src/bionemo/moco/schedules/noise/continuous_snr_transforms.py index ad0b63191f..b3b001b5ec 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/schedules/noise/continuous_snr_transforms.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/schedules/noise/continuous_snr_transforms.py @@ -19,8 +19,9 @@ from typing import Callable, Optional, Tuple, Union import torch -from torch import Tensor from jaxtyping import Float +from torch import Tensor + from bionemo.moco.interpolants.base_interpolant import string_to_enum from bionemo.moco.schedules.utils import TimeDirection diff --git a/sub-packages/bionemo-moco/src/bionemo/moco/schedules/noise/discrete_noise_schedules.py b/sub-packages/bionemo-moco/src/bionemo/moco/schedules/noise/discrete_noise_schedules.py index caa355348d..bc8cfc2cd8 100644 --- a/sub-packages/bionemo-moco/src/bionemo/moco/schedules/noise/discrete_noise_schedules.py +++ b/sub-packages/bionemo-moco/src/bionemo/moco/schedules/noise/discrete_noise_schedules.py @@ -18,8 +18,9 @@ from typing import Optional, Union import torch -from torch import Tensor from jaxtyping import Float +from torch import Tensor + from bionemo.moco.interpolants.base_interpolant import string_to_enum from bionemo.moco.schedules.utils import TimeDirection @@ -58,9 +59,7 @@ def generate_schedule( return schedule @abstractmethod - def _generate_schedule( - self, nsteps: Optional[int] = None, device: Union[str, torch.device] = "cpu" - ) -> Tensor: + def _generate_schedule(self, nsteps: Optional[int] = None, device: Union[str, torch.device] = "cpu") -> Tensor: """Generate the noise schedule tensor. Args: @@ -107,9 +106,7 @@ def __init__(self, nsteps: int, nu: Float = 1.0, s: Float = 0.008): self.nu = nu self.s = s - def _generate_schedule( - self, nsteps: Optional[int] = None, device: Union[str, torch.device] = "cpu" - ) -> Tensor: + def _generate_schedule(self, nsteps: Optional[int] = None, device: Union[str, torch.device] = "cpu") -> Tensor: """Generate the cosine noise schedule. Args: @@ -163,9 +160,7 @@ def __init__(self, nsteps: int, beta_start: Float = 1e-4, beta_end: Float = 0.02 self.beta_start = beta_start self.beta_end = beta_end - def _generate_schedule( - self, nsteps: Optional[int] = None, device: Union[str, torch.device] = "cpu" - ) -> Tensor: + def _generate_schedule(self, nsteps: Optional[int] = None, device: Union[str, torch.device] = "cpu") -> Tensor: """Generate the cosine noise schedule. Args: diff --git a/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/prior/continuous/test_gaussian.py b/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/prior/continuous/test_gaussian.py index 0366ec104c..113aa9e0fb 100644 --- a/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/prior/continuous/test_gaussian.py +++ b/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/prior/continuous/test_gaussian.py @@ -15,6 +15,7 @@ import torch + from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior diff --git a/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/prior/discrete/test_custom.py b/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/prior/discrete/test_custom.py index 5e2e6bfe7f..24c7191291 100644 --- a/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/prior/discrete/test_custom.py +++ b/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/prior/discrete/test_custom.py @@ -16,6 +16,7 @@ import pytest import torch + from bionemo.moco.distributions.prior.discrete.custom import DiscreteCustomPrior diff --git a/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/prior/discrete/test_mask.py b/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/prior/discrete/test_mask.py index 400bbd6305..d182745413 100644 --- a/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/prior/discrete/test_mask.py +++ b/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/prior/discrete/test_mask.py @@ -16,6 +16,7 @@ import pytest import torch + from bionemo.moco.distributions.prior.discrete.mask import DiscreteMaskedPrior diff --git a/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/prior/discrete/test_uniform.py b/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/prior/discrete/test_uniform.py index ab078d2823..fbfb28cd47 100644 --- a/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/prior/discrete/test_uniform.py +++ b/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/prior/discrete/test_uniform.py @@ -16,6 +16,7 @@ import pytest import torch + from bionemo.moco.distributions.prior.discrete.uniform import DiscreteUniformPrior diff --git a/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/time/test_time_distribution.py b/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/time/test_time_distribution.py index 8909d848a1..15404806b3 100644 --- a/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/time/test_time_distribution.py +++ b/sub-packages/bionemo-moco/tests/bionemo/moco/distributions/time/test_time_distribution.py @@ -16,6 +16,7 @@ import pytest import torch + from bionemo.moco.distributions.time.beta import BetaTimeDistribution from bionemo.moco.distributions.time.distribution import MixTimeDistribution from bionemo.moco.distributions.time.logit_normal import LogitNormalTimeDistribution diff --git a/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/continuous/test_continuous_flow_matching.py b/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/continuous/test_continuous_flow_matching.py index 135fc3350e..d3cbb0e16c 100644 --- a/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/continuous/test_continuous_flow_matching.py +++ b/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/continuous/test_continuous_flow_matching.py @@ -16,6 +16,7 @@ import pytest import torch import torch.nn.functional as F + from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior from bionemo.moco.distributions.time.uniform import UniformTimeDistribution from bionemo.moco.interpolants.base_interpolant import PredictionType diff --git a/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/continuous/test_optimal_transport.py b/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/continuous/test_optimal_transport.py index 76973666f4..9c594803ed 100644 --- a/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/continuous/test_optimal_transport.py +++ b/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/continuous/test_optimal_transport.py @@ -15,6 +15,7 @@ import pytest import torch + from bionemo.moco.interpolants.continuous_time.continuous.optimal_transport import OTSampler @@ -80,7 +81,6 @@ def exact_ot_sampler(): return ot_sampler - @pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("sampler", ["exact_ot_sampler"]) @pytest.mark.parametrize("data", ["toy_data", "toy_masked_data"]) diff --git a/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/continuous/test_vdm.py b/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/continuous/test_vdm.py index 967ac0ace1..b76ace8a9a 100644 --- a/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/continuous/test_vdm.py +++ b/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/continuous/test_vdm.py @@ -19,6 +19,7 @@ import pytest import torch import torch.nn.functional as F + from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior from bionemo.moco.distributions.time.uniform import UniformTimeDistribution from bionemo.moco.interpolants.continuous_time.continuous.vdm import VDM diff --git a/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/discrete/test_discrete_flow_matching.py b/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/discrete/test_discrete_flow_matching.py index 025d680137..793b8c3bbd 100644 --- a/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/discrete/test_discrete_flow_matching.py +++ b/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/discrete/test_discrete_flow_matching.py @@ -16,6 +16,7 @@ import pytest import torch + from bionemo.moco.distributions.prior.discrete.mask import DiscreteMaskedPrior from bionemo.moco.distributions.prior.discrete.uniform import DiscreteUniformPrior from bionemo.moco.distributions.time.uniform import UniformTimeDistribution diff --git a/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/discrete/test_mdlm.py b/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/discrete/test_mdlm.py index 1b6014d413..4f9ae56bdd 100644 --- a/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/discrete/test_mdlm.py +++ b/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/continuous_time/discrete/test_mdlm.py @@ -16,6 +16,7 @@ import pytest import torch + from bionemo.moco.distributions.prior.discrete.mask import DiscreteMaskedPrior from bionemo.moco.distributions.time.uniform import UniformTimeDistribution from bionemo.moco.interpolants.continuous_time.discrete.mdlm import MDLM diff --git a/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/discrete_time/continuous/test_ddpm.py b/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/discrete_time/continuous/test_ddpm.py index 63a77d5cd5..15001ca382 100644 --- a/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/discrete_time/continuous/test_ddpm.py +++ b/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/discrete_time/continuous/test_ddpm.py @@ -17,6 +17,7 @@ import pytest import torch import torch.nn.functional as F + from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior from bionemo.moco.distributions.time.uniform import UniformTimeDistribution from bionemo.moco.interpolants.discrete_time.continuous.ddpm import DDPM diff --git a/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/discrete_time/discrete/test_d3pm.py b/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/discrete_time/discrete/test_d3pm.py index 5fb09c86f8..4041553c63 100644 --- a/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/discrete_time/discrete/test_d3pm.py +++ b/sub-packages/bionemo-moco/tests/bionemo/moco/interpolants/discrete_time/discrete/test_d3pm.py @@ -16,6 +16,7 @@ import pytest import torch + from bionemo.moco.distributions.prior.discrete.uniform import DiscreteUniformPrior from bionemo.moco.distributions.time.uniform import UniformTimeDistribution from bionemo.moco.interpolants.discrete_time.discrete.d3pm import D3PM diff --git a/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/noise/test_discrete_noise_schedule.py b/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/noise/test_discrete_noise_schedule.py index 64b03e9394..7ddcc7adfd 100644 --- a/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/noise/test_discrete_noise_schedule.py +++ b/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/noise/test_discrete_noise_schedule.py @@ -16,6 +16,7 @@ import pytest import torch + from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteCosineNoiseSchedule from bionemo.moco.schedules.utils import TimeDirection diff --git a/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_continuous_noise_transforms.py b/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_continuous_noise_transforms.py index 7db5ab88da..f495fe8142 100644 --- a/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_continuous_noise_transforms.py +++ b/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_continuous_noise_transforms.py @@ -16,6 +16,7 @@ import pytest import torch + from bionemo.moco import TimeDirection from bionemo.moco.schedules.noise.continuous_noise_transforms import ( CosineExpNoiseTransform, diff --git a/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_continuous_snr_transforms.py b/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_continuous_snr_transforms.py index 9976654838..d9633a2eee 100644 --- a/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_continuous_snr_transforms.py +++ b/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_continuous_snr_transforms.py @@ -16,6 +16,7 @@ import pytest import torch + from bionemo.moco.schedules.noise.continuous_snr_transforms import ( CosineSNRTransform, LinearLogInterpolatedSNRTransform, diff --git a/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_discrete_noise_schedules.py b/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_discrete_noise_schedules.py index 0edc066e64..9eb048da0f 100644 --- a/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_discrete_noise_schedules.py +++ b/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_discrete_noise_schedules.py @@ -16,6 +16,7 @@ import pytest import torch + from bionemo.moco.schedules.discrete_noise_schedules import DiscreteCosineNoiseSchedule, TimeDirection diff --git a/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_infernece_schedule.py b/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_inference_schedules.py similarity index 79% rename from sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_infernece_schedule.py rename to sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_inference_schedules.py index 9c8365bee6..e13068396d 100644 --- a/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_infernece_schedule.py +++ b/sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_inference_schedules.py @@ -16,6 +16,7 @@ import pytest import torch + from bionemo.moco.schedules.inference_time_schedules import ( DiscreteLinearInferenceSchedule, LinearInferenceSchedule, @@ -122,10 +123,40 @@ def test_discrete_uniform_dt(timesteps, device, direction): assert dt.shape == (timesteps,) # Check if schedule is on the correct device assert schedule.device.type == device - # Additional checks specific to DiscreteUniformInferenceSchedule assert torch.all(dt == torch.full((timesteps,), 1 / timesteps, device=device)) if direction == TimeDirection.UNIFIED: assert schedule[0] < schedule[-1] else: assert schedule[0] > schedule[-1] + + +@pytest.mark.parametrize("timesteps", [10, 20]) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("direction", [TimeDirection.UNIFIED, TimeDirection.DIFFUSION]) +@pytest.mark.parametrize("padding", [0, 2]) +@pytest.mark.parametrize("dilation", [0, 1]) +def test_uniform_dt_padding_dilation(timesteps, device, direction, padding, dilation): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + scheduler = LinearInferenceSchedule(timesteps, padding=padding, dilation=dilation, direction=direction) + dt = scheduler.discretize(device=device) + schedule = scheduler.generate_schedule(device=device) + + # Check if all dt's are equal to 1/timesteps + assert dt.device.type == device + + # Check if schedule has the correct shape + assert schedule.shape == (timesteps,) + # Check if dt has the correct shape + assert dt.shape == (timesteps,) + # Check if schedule is on the correct device + assert schedule.device.type == device + if direction == TimeDirection.UNIFIED: + assert schedule[0] < schedule[-1] + for i in range(padding): + assert schedule[-1 * (i + 1)] == 1.0 + else: + assert schedule[0] > schedule[-1] + for i in range(padding): + assert schedule[-1 * (i + 1)] == 0