Skip to content

Commit

Permalink
format fix
Browse files Browse the repository at this point in the history
  • Loading branch information
nvdreidenbach committed Dec 24, 2024
1 parent 088e5d9 commit 90103f4
Show file tree
Hide file tree
Showing 42 changed files with 192 additions and 75 deletions.
4 changes: 2 additions & 2 deletions .secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@
"filename": "pyproject.toml",
"hashed_secret": "79670e9c9d1c7ea5b81a96a2053d81437712c78e",
"is_verified": false,
"line_number": 47
"line_number": 48
}
]
},
"generated_at": "2024-11-01T22:26:03Z"
"generated_at": "2024-12-24T16:57:17Z"
}
2 changes: 1 addition & 1 deletion LICENSE/third_party.txt
Original file line number Diff line number Diff line change
Expand Up @@ -844,4 +844,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
SOFTWARE.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.0
1.0
2 changes: 0 additions & 2 deletions sub-packages/bionemo-moco/documentation.md
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
# Table of Contents


16 changes: 16 additions & 0 deletions sub-packages/bionemo-moco/scripts/clean_documentation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import re


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@

import torch
from jaxtyping import Bool, Float
from torch import Tensor

from bionemo.moco.distributions.prior.continuous.utils import remove_center_of_mass
from bionemo.moco.distributions.prior.distribution import PriorDistribution
from torch import Tensor


class GaussianPrior(PriorDistribution):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,10 @@

from typing import Optional

from jaxtyping import Float
from torch import Tensor


def remove_center_of_mass(
data: Float[Tensor, "* N D"], mask: Optional[Float[Tensor, "* N D"]] = None
) -> Float[Tensor, "* N D"]:
def remove_center_of_mass(data: Tensor, mask: Optional[Tensor] = None) -> Tensor:
"""Calculates the center of mass (CoM) of the given data.
Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
from typing import Optional, Tuple, Union

import torch
from bionemo.moco.distributions.prior.distribution import DiscretePriorDistribution
from torch import Tensor

from bionemo.moco.distributions.prior.distribution import DiscretePriorDistribution


class DiscreteCustomPrior(DiscretePriorDistribution):
"""A subclass representing a discrete custom prior distribution.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
from typing import Optional, Tuple, Union

import torch
from bionemo.moco.distributions.prior.distribution import DiscretePriorDistribution
from torch import Tensor

from bionemo.moco.distributions.prior.distribution import DiscretePriorDistribution


class DiscreteMaskedPrior(DiscretePriorDistribution):
"""A subclass representing a Discrete Masked prior distribution."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
from typing import Optional, Tuple, Union

import torch
from bionemo.moco.distributions.prior.distribution import DiscretePriorDistribution
from torch import Tensor

from bionemo.moco.distributions.prior.distribution import DiscretePriorDistribution


class DiscreteUniformPrior(DiscretePriorDistribution):
"""A subclass representing a discrete uniform prior distribution."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import torch
from jaxtyping import Bool, Float

from bionemo.moco.distributions.time.distribution import TimeDistribution
from bionemo.moco.distributions.time.utils import float_time_to_index

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import torch
from jaxtyping import Bool, Float

from bionemo.moco.distributions.time.distribution import TimeDistribution
from bionemo.moco.distributions.time.utils import float_time_to_index

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import torch
from jaxtyping import Bool, Float

from bionemo.moco.distributions.time.distribution import TimeDistribution


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
from typing import Optional, Type, TypeVar, Union

import torch
from jaxtyping import Bool
from torch import Tensor
from jaxtyping import Bool, Float


from bionemo.moco.distributions.prior.distribution import PriorDistribution
from bionemo.moco.distributions.time.distribution import TimeDistribution
Expand Down Expand Up @@ -219,9 +218,7 @@ def to_device(self, device: str):
setattr(self, attr_name, getattr(self, attr_name).to(device))
return self

def clean_mask_center(
self, data: Float[Tensor, "* N D"], mask: Optional[Float[Tensor, "* N"]] = None, center: Bool = False
) -> Float[Tensor, "* N D"]:
def clean_mask_center(self, data: Tensor, mask: Optional[Tensor] = None, center: Bool = False) -> Tensor:
"""Returns a clean tensor that has been masked and/or centered based on the function arguments.
Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
import torch
import torch.nn as nn
from jaxtyping import Bool, Float
from torch import Tensor

from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
from bionemo.moco.distributions.prior.distribution import PriorDistribution
from bionemo.moco.distributions.time.distribution import TimeDistribution
from bionemo.moco.interpolants.base_interpolant import Interpolant, PredictionType, pad_like, string_to_enum
from bionemo.moco.interpolants.continuous_time.continuous.optimal_transport import OTSampler
from torch import Tensor


class OptimalTransportType(Enum):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Bool, Float
from torch import Tensor

from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
from bionemo.moco.distributions.prior.distribution import PriorDistribution
from bionemo.moco.distributions.time.distribution import TimeDistribution
from bionemo.moco.interpolants.base_interpolant import Interpolant, PredictionType, pad_like, string_to_enum
from bionemo.moco.schedules.noise.continuous_snr_transforms import ContinuousSNRTransform
from torch import Tensor


class VDM(Interpolant):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Bool, Float
from torch import Tensor

from bionemo.moco.distributions.prior.discrete.mask import DiscreteMaskedPrior
from bionemo.moco.distributions.prior.distribution import DiscretePriorDistribution
from bionemo.moco.distributions.time.distribution import TimeDistribution
from bionemo.moco.interpolants.base_interpolant import Interpolant, pad_like
from torch import Tensor


class DiscreteFlowMatcher(Interpolant):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
from typing import Optional

import torch
from torch import Tensor

from bionemo.moco.distributions.prior.discrete.mask import DiscreteMaskedPrior
from bionemo.moco.distributions.time.distribution import TimeDistribution
from bionemo.moco.interpolants.base_interpolant import Interpolant, pad_like
from bionemo.moco.schedules.noise.continuous_noise_transforms import ContinuousExpNoiseTransform
from torch import Tensor


class MDLM(Interpolant):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
from typing import Optional, Union

import torch
from torch import Tensor
import torch.nn as nn
from jaxtyping import Bool, Float
from torch import Tensor

from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
from bionemo.moco.distributions.prior.distribution import PriorDistribution
from bionemo.moco.distributions.time.distribution import TimeDistribution
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,18 @@
from typing import Optional, Tuple

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor

from bionemo.moco.distributions.prior.distribution import DiscretePriorDistribution
from bionemo.moco.distributions.time.distribution import TimeDistribution
from bionemo.moco.interpolants.base_interpolant import Interpolant
from bionemo.moco.interpolants.discrete_time.utils import safe_index
from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteNoiseSchedule



class D3PM(Interpolant):
"""A Discrete Denoising Diffusion Probabilistic Model (D3PM) interpolant."""

Expand Down Expand Up @@ -330,7 +330,7 @@ def loss(
if aggregate:
loss = loss.mean()
if vb_scale > 0:
loss += vb_loss.mean() # type: ignore
loss += vb_loss.mean() # type: ignore
return loss
return loss, vb_loss

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from typing import Optional, Union

import torch
from torch import Tensor
from jaxtyping import Float
from torch import Tensor

from bionemo.moco.interpolants.base_interpolant import string_to_enum
from bionemo.moco.schedules.utils import TimeDirection
Expand Down Expand Up @@ -64,9 +64,7 @@ def generate_schedule(
return schedule

@abstractmethod
def _generate_schedule(
self, nsteps: Optional[int] = None, device: Union[str, torch.device] = "cpu"
) -> Tensor:
def _generate_schedule(self, nsteps: Optional[int] = None, device: Union[str, torch.device] = "cpu") -> Tensor:
"""Generate the time schedule as a list.
Args:
Expand Down Expand Up @@ -115,9 +113,7 @@ def __init__(self, nsteps: int, nu: Float = 1.0, s: Float = 0.008):
self.nu = nu
self.s = s

def _generate_schedule(
self, nsteps: Optional[int] = None, device: Union[str, torch.device] = "cpu"
) -> Tensor:
def _generate_schedule(self, nsteps: Optional[int] = None, device: Union[str, torch.device] = "cpu") -> Tensor:
"""Generate the cosine noise schedule.
Args:
Expand Down
Loading

0 comments on commit 90103f4

Please sign in to comment.