From 0d71fa0612c3f394ce751c35748bae6a8e0d74d7 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 23 Jul 2024 16:27:22 +0000 Subject: [PATCH 01/70] BugFix: parser of docker version string the previous parser loops in the build number and os version, which aren't needed and cause bugs --- launch.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/launch.sh b/launch.sh index 69a1aff00e..34d9e5855d 100755 --- a/launch.sh +++ b/launch.sh @@ -103,7 +103,7 @@ build() { } # Check Docker version - docker_version=$(docker --version | grep -oE '[0-9]+\.[0-9]+\.[0-9]+') + docker_version=$(docker --version | awk -F'[, ]' '{print $3}') required_docker_version="23.0.1" if ! version_ge "$docker_version" "$required_docker_version"; then @@ -112,7 +112,7 @@ build() { fi # Check Buildx version - buildx_version=$(docker buildx version | grep -oE '[0-9]+\.[0-9]+\.[0-9]+') + buildx_version=$(docker buildx version | awk '{print $2}') required_buildx_version="0.10.2" if ! version_ge "$buildx_version" "$required_buildx_version"; then From 1d30c4388de2fb23813379e14bc10a46f527df99 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Mon, 29 Jul 2024 22:07:15 +0000 Subject: [PATCH 02/70] Enhancement: migrate diffdock utils modules from v1 these are copied as it is --- .../molecule/diffdock/utils/diffusion.py | 146 +++++++++++++++ .../model/molecule/diffdock/utils/geometry.py | 125 +++++++++++++ .../model/molecule/diffdock/utils/so3.py | 173 ++++++++++++++++++ .../model/molecule/diffdock/utils/torsion.py | 109 +++++++++++ .../model/molecule/diffdock/utils/torus.py | 103 +++++++++++ 5 files changed, 656 insertions(+) create mode 100644 sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py create mode 100644 sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/geometry.py create mode 100644 sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/so3.py create mode 100644 sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/torsion.py create mode 100644 sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/torus.py diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py new file mode 100644 index 0000000000..135c5d64fb --- /dev/null +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py @@ -0,0 +1,146 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import math + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from bionemo.contrib.model.molecule.diffdock.utils.geometry import axis_angle_to_matrix, rigid_transform_Kabsch_3D_torch +from bionemo.contrib.model.molecule.diffdock.utils.torsion import modify_conformer_torsion_angles + + +def t_to_sigma(t_tr, t_rot, t_tor, cfg): + tr_sigma = cfg.diffusion.tr_sigma_min ** (1 - t_tr) * cfg.diffusion.tr_sigma_max**t_tr + rot_sigma = cfg.diffusion.rot_sigma_min ** (1 - t_rot) * cfg.diffusion.rot_sigma_max**t_rot + tor_sigma = cfg.diffusion.tor_sigma_min ** (1 - t_tor) * cfg.diffusion.tor_sigma_max**t_tor + return tr_sigma, rot_sigma, tor_sigma + + +def modify_conformer(data, tr_update, rot_update, torsion_updates): + lig_center = torch.mean(data["ligand"].pos, dim=0, keepdim=True) + rot_mat = axis_angle_to_matrix(rot_update.squeeze()) + rigid_new_pos = (data["ligand"].pos - lig_center) @ rot_mat.T + tr_update + lig_center + + if torsion_updates is not None: + flexible_new_pos = modify_conformer_torsion_angles( + rigid_new_pos, + data["ligand", "ligand"].edge_index.T[data["ligand"].edge_mask], + ( + data["ligand"].mask_rotate + if isinstance(data["ligand"].mask_rotate, np.ndarray) + else data["ligand"].mask_rotate[0] + ), + torsion_updates, + ).to(rigid_new_pos.device) + R, t = rigid_transform_Kabsch_3D_torch(flexible_new_pos.T, rigid_new_pos.T) + aligned_flexible_pos = flexible_new_pos @ R.T + t.T + data["ligand"].pos = aligned_flexible_pos + else: + data["ligand"].pos = rigid_new_pos + return data + + +def sinusoidal_embedding(timesteps, embedding_dim, max_positions=10000): + """from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py""" + assert len(timesteps.shape) == 1 + half_dim = embedding_dim // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = F.pad(emb, (0, 1), mode="constant") + assert emb.shape == (timesteps.shape[0], embedding_dim) + return emb + + +class GaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels. + from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/models/layerspp.py#L32 + """ + + def __init__(self, embedding_size=256, scale=1.0): + super().__init__() + self.W = nn.Parameter(torch.randn(embedding_size // 2) * scale, requires_grad=False) + + def forward(self, x): + x_proj = x[:, None] * self.W[None, :] * 2 * np.pi + emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + return emb + + +def get_timestep_embedding(embedding_type, embedding_dim, embedding_scale=10000): + if embedding_type == "sinusoidal": + + def emb_func(x): + return sinusoidal_embedding(embedding_scale * x, embedding_dim) + + elif embedding_type == "fourier": + emb_func = GaussianFourierProjection(embedding_size=embedding_dim, scale=embedding_scale) + else: + raise NotImplementedError + return emb_func + + +class timestep_embedding(nn.Module): + def __init__(self, embedding_type, embedding_dim, embedding_scale=10000): + super(timestep_embedding, self).__init__() + self.embedding_type = embedding_type + self.embedding_dim = embedding_dim + self.embedding_scale = embedding_scale + self.emb_func = get_timestep_embedding(embedding_type, embedding_dim, embedding_scale) + + def forward(self, *args, **kwargs): + return self.emb_func(*args, **kwargs) + + def __getstate__(self): + return { + "embedding_type": self.embedding_type, + "embedding_dim": self.embedding_dim, + "embedding_scale": self.embedding_scale, + } + + def __setstate__(self, d): + super(timestep_embedding, self).__init__() + self.embedding_type = d["embedding_type"] + self.embedding_dim = d["embedding_dim"] + self.embedding_scale = d["embedding_scale"] + self.emb_func = get_timestep_embedding(**d) + + +def get_t_schedule(denoising_inference_steps): + return np.linspace(1, 0, denoising_inference_steps + 1)[:-1] + + +def set_time(complex_graphs, t_tr, t_rot, t_tor, batchsize, all_atoms, device): + complex_graphs["ligand"].node_t = { + "tr": t_tr * torch.ones(complex_graphs["ligand"].num_nodes).to(device), + "rot": t_rot * torch.ones(complex_graphs["ligand"].num_nodes).to(device), + "tor": t_tor * torch.ones(complex_graphs["ligand"].num_nodes).to(device), + } + complex_graphs["receptor"].node_t = { + "tr": t_tr * torch.ones(complex_graphs["receptor"].num_nodes).to(device), + "rot": t_rot * torch.ones(complex_graphs["receptor"].num_nodes).to(device), + "tor": t_tor * torch.ones(complex_graphs["receptor"].num_nodes).to(device), + } + complex_graphs.complex_t = { + "tr": t_tr * torch.ones(batchsize).to(device), + "rot": t_rot * torch.ones(batchsize).to(device), + "tor": t_tor * torch.ones(batchsize).to(device), + } + if all_atoms: + complex_graphs["atom"].node_t = { + "tr": t_tr * torch.ones(complex_graphs["atom"].num_nodes).to(device), + "rot": t_rot * torch.ones(complex_graphs["atom"].num_nodes).to(device), + "tor": t_tor * torch.ones(complex_graphs["atom"].num_nodes).to(device), + } diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/geometry.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/geometry.py new file mode 100644 index 0000000000..648045d261 --- /dev/null +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/geometry.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import math + +import torch + + +def quaternion_to_matrix(quaternions): + """ + From https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def axis_angle_to_quaternion(axis_angle): + """ + From https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = torch.sin(half_angles[~small_angles]) / angles[~small_angles] + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + quaternions = torch.cat([torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1) + return quaternions + + +def axis_angle_to_matrix(axis_angle): + """ + From https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def rigid_transform_Kabsch_3D_torch(A, B): + # R = 3x3 rotation matrix, t = 3x1 column vector + # This already takes residue identity into account. + + assert A.shape[1] == B.shape[1] + num_rows, num_cols = A.shape + if num_rows != 3: + raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}") + num_rows, num_cols = B.shape + if num_rows != 3: + raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}") + + # find mean column wise: 3 x 1 + centroid_A = torch.mean(A, axis=1, keepdims=True) + centroid_B = torch.mean(B, axis=1, keepdims=True) + + # subtract mean + Am = A - centroid_A + Bm = B - centroid_B + + H = Am @ Bm.T + + # find rotation + U, S, Vt = torch.linalg.svd(H) + + R = Vt.T @ U.T + # special reflection case + if torch.linalg.det(R) < 0: + SS = torch.diag(torch.tensor([1.0, 1.0, -1.0], device=A.device)) + R = (Vt.T @ SS) @ U.T + assert math.fabs(torch.linalg.det(R) - 1) < 3e-3 # note I had to change this error bound to be higher + + t = -R @ centroid_A + centroid_B + return R, t diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/so3.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/so3.py new file mode 100644 index 0000000000..7fa0da0a3e --- /dev/null +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/so3.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + + +import os + +import numpy as np +import torch +from scipy.spatial.transform import Rotation + + +package_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) + +MIN_EPS, MAX_EPS, N_EPS = 0.01, 2, 1000 +X_N = 2000 + +omegas = np.linspace(0, np.pi, X_N + 1)[1:] + +# TODO generating these arrays is super slow, we should vectorize this + + +def _compose(r1, r2): # R1 @ R2 but for Euler vecs + return Rotation.from_matrix( + Rotation.from_rotvec(r1).as_matrix() @ Rotation.from_rotvec(r2).as_matrix() + ).as_rotvec() + + +def _expansion(omega, eps, L=2000): # the summation term only + p = 0 + for l in range(L): + p += (2 * l + 1) * np.exp(-l * (l + 1) * eps**2) * np.sin(omega * (l + 1 / 2)) / np.sin(omega / 2) + return p + + +def _expansion_vectorized(omega, eps, L=2000): + l = np.arange(L).reshape((-1, 1)) + omega = omega.reshape((1, -1)) + eps = eps.reshape((1, -1)) + + p1 = (2 * l + 1) * np.exp(-l * (l + 1) * eps**2) + p2 = np.sin(omega * (l + 1 / 2)) / np.sin(omega / 2) + p = np.matmul(p2.T, p1).T + return p + + +def _density(expansion, omega, marginal=True): # if marginal, density over [0, pi], else over SO(3) + if marginal: + return expansion * (1 - np.cos(omega)) / np.pi + else: + return expansion / 8 / np.pi**2 # the constant factor doesn't affect any actual calculations though + + +def _score(exp, omega, eps, L=2000): # score of density over SO(3) + dSigma = 0 + for l in range(L): + hi = np.sin(omega * (l + 1 / 2)) + dhi = (l + 1 / 2) * np.cos(omega * (l + 1 / 2)) + lo = np.sin(omega / 2) + dlo = 1 / 2 * np.cos(omega / 2) + dSigma += (2 * l + 1) * np.exp(-l * (l + 1) * eps**2) * (lo * dhi - hi * dlo) / lo**2 + return dSigma / exp + + +def _score_vectorized(exp, omega, eps, L=2000): # score of density over SO(3) + dSigma = 0 + l = np.arange(L).reshape((-1, 1)) + omega = omega.reshape((1, -1)) + eps = eps.reshape((1, -1)) + + hi = np.sin(omega * (l + 1 / 2)) + dhi = (l + 1 / 2) * np.cos(omega * (l + 1 / 2)) + lo = np.sin(omega / 2) + dlo = 1 / 2 * np.cos(omega / 2) + dSigma1 = (2 * l + 1) * np.exp(-l * (l + 1) * eps**2) + dSigma2 = (lo * dhi - hi * dlo) / lo**2 + dSigma = np.matmul(dSigma2.T, dSigma1).T + return dSigma / exp + + +def _score_small_eps(omega, eps): + # formula for f(omega, eps) in eq (5) https://openreview.net/pdf?id=jHA-yCyBGb + # score = d(log(f(omega, eps^2)) / d omega + # for our range of omegas, this approximation works well for eps up to ~0.7 + # note that for numerical stability it is important to combine + # exp(pi*omega/eps) * exp(-pi**2/eps) into exp(pi*(omega-pi)/eps) + + x = omega.reshape((1, -1)) + a = eps.reshape((-1, 1)) ** 2 + + return ( + -0.5 * x / a + + ( + 1 + + -np.exp(np.pi * (x - np.pi) / a) + + -np.exp(-np.pi * (x + np.pi) / a) + + -(np.pi * (x - 2 * np.pi) / a) * np.exp(np.pi * (x - np.pi) / a) + + np.pi * (x + 2 * np.pi) / a * np.exp(-np.pi * (x + np.pi) / a) + ) + / (x + -(x - 2 * np.pi) * np.exp(np.pi * (x - np.pi) / a) + (x + 2 * np.pi) * np.exp(-np.pi * (x + np.pi) / a)) + - 0.5 * np.cos(x / 2) / np.sin(x / 2) + ) + + +if os.path.exists(os.path.join(package_path, ".so3.npz")): + so3 = np.load(os.path.join(package_path, ".so3.npz")) + _omegas_array = so3["_omegas_array"] + _cdf_vals = so3["_cdf_vals"] + _score_norms = so3["_score_norms"] + _exp_score_norms = so3["_exp_score_norms"] +else: + _eps_array = (10 ** np.linspace(np.log10(MIN_EPS), np.log10(MAX_EPS), N_EPS)).astype(np.float128) + _omegas_array = np.linspace(0, np.pi, X_N + 1)[1:].astype(np.float128) + + _exp_vals = _expansion_vectorized(_omegas_array, _eps_array) + _pdf_vals = _density(_exp_vals, _omegas_array, marginal=True) + _cdf_vals = _pdf_vals.cumsum(1) / X_N * np.pi + _score_norms = np.zeros((N_EPS, X_N)) + _small_eps_idx = _eps_array < 0.5 + _score_norms[_small_eps_idx] = _score_small_eps(_omegas_array, _eps_array[_small_eps_idx]) + _score_norms[~_small_eps_idx] = _score_vectorized( + _exp_vals[~_small_eps_idx], _omegas_array, _eps_array[~_small_eps_idx] + ) + + _exp_score_norms = np.sqrt(np.sum(_score_norms**2 * _pdf_vals, axis=1) / np.sum(_pdf_vals, axis=1) / np.pi) + + _omegas_array = _omegas_array.astype(np.float64) + _cdf_vals = _cdf_vals.astype(np.float64) + _score_norms = _score_norms.astype(np.float64) + _exp_score_norms = _exp_score_norms.astype(np.float64) + + np.savez( + os.path.join(package_path, ".so3.npz"), + _omegas_array=_omegas_array, + _cdf_vals=_cdf_vals, + _score_norms=_score_norms, + _exp_score_norms=_exp_score_norms, + ) + + +def sample(eps): + eps_idx = (np.log10(eps) - np.log10(MIN_EPS)) / (np.log10(MAX_EPS) - np.log10(MIN_EPS)) * N_EPS + eps_idx = np.clip(np.around(eps_idx).astype(int), a_min=0, a_max=N_EPS - 1) + x = np.random.rand() + return np.interp(x, _cdf_vals[eps_idx], _omegas_array) + + +def sample_vec(eps): + x = np.random.randn(3) + x /= np.linalg.norm(x) + return x * sample(eps) + + +def score_vec(eps, vec): + eps_idx = (np.log10(eps) - np.log10(MIN_EPS)) / (np.log10(MAX_EPS) - np.log10(MIN_EPS)) * N_EPS + eps_idx = np.clip(np.around(eps_idx).astype(int), a_min=0, a_max=N_EPS - 1) + + om = np.linalg.norm(vec) + return np.interp(om, _omegas_array, _score_norms[eps_idx]) * vec / om + + +def score_norm(eps): + device = eps.device + eps = eps.cpu().numpy() + eps_idx = (np.log10(eps) - np.log10(MIN_EPS)) / (np.log10(MAX_EPS) - np.log10(MIN_EPS)) * N_EPS + eps_idx = np.clip(np.around(eps_idx).astype(int), a_min=0, a_max=N_EPS - 1) + return torch.from_numpy(_exp_score_norms[eps_idx]).to(device=device, dtype=torch.float) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/torsion.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/torsion.py new file mode 100644 index 0000000000..b4c7e3692f --- /dev/null +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/torsion.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import copy + +import networkx as nx +import numpy as np +import torch +from scipy.spatial.transform import Rotation as R +from torch_geometric.data import Data +from torch_geometric.utils import to_networkx + + +""" + Preprocessing and computation for torsional updates to conformers +""" + + +def get_transformation_mask(pyg_data): + G = to_networkx(pyg_data.to_homogeneous(), to_undirected=False) + to_rotate = [] + edges = pyg_data["ligand", "ligand"].edge_index.T.numpy() + for i in range(0, edges.shape[0], 2): + assert edges[i, 0] == edges[i + 1, 1] + + G2 = G.to_undirected() + G2.remove_edge(*edges[i]) + if not nx.is_connected(G2): + l = list(sorted(nx.connected_components(G2), key=len)[0]) + if len(l) > 1: + if edges[i, 0] in l: + to_rotate.append([]) + to_rotate.append(l) + else: + to_rotate.append(l) + to_rotate.append([]) + continue + to_rotate.append([]) + to_rotate.append([]) + + mask_edges = np.asarray([0 if len(l) == 0 else 1 for l in to_rotate], dtype=bool) + mask_rotate = np.zeros((np.sum(mask_edges), len(G.nodes())), dtype=bool) + idx = 0 + for i in range(len(G.edges())): + if mask_edges[i]: + mask_rotate[idx][np.asarray(to_rotate[i], dtype=int)] = True + idx += 1 + + return mask_edges, mask_rotate + + +def modify_conformer_torsion_angles(pos, edge_index, mask_rotate, torsion_updates, as_numpy=False): + pos = copy.deepcopy(pos) + if type(pos) != np.ndarray: + pos = pos.cpu().numpy() + + for idx_edge, e in enumerate(edge_index.cpu().numpy()): + if torsion_updates[idx_edge] == 0: + continue + u, v = e[0], e[1] + + # check if need to reverse the edge, v should be connected to the part that gets rotated + assert not mask_rotate[idx_edge, u] + assert mask_rotate[idx_edge, v] + + rot_vec = pos[u] - pos[v] # convention: positive rotation if pointing inwards + rot_vec = rot_vec * torsion_updates[idx_edge] / np.linalg.norm(rot_vec) # idx_edge! + rot_mat = R.from_rotvec(rot_vec).as_matrix() + + pos[mask_rotate[idx_edge]] = (pos[mask_rotate[idx_edge]] - pos[v]) @ rot_mat.T + pos[v] + + if not as_numpy: + pos = torch.from_numpy(pos.astype(np.float32)) + return pos + + +def perturb_batch(data, torsion_updates, split=False, return_updates=False): + if type(data) is Data: + return modify_conformer_torsion_angles( + data.pos, data.edge_index.T[data.edge_mask], data.mask_rotate, torsion_updates + ) + pos_new = [] if split else copy.deepcopy(data.pos) + edges_of_interest = data.edge_index.T[data.edge_mask] + idx_node = 0 + idx_edges = 0 + torsion_update_list = [] + for i, mask_rotate in enumerate(data.mask_rotate): + pos = data.pos[idx_node : idx_node + mask_rotate.shape[1]] + edges = edges_of_interest[idx_edges : idx_edges + mask_rotate.shape[0]] - idx_node + torsion_update = torsion_updates[idx_edges : idx_edges + mask_rotate.shape[0]] + torsion_update_list.append(torsion_update) + pos_new_ = modify_conformer_torsion_angles(pos, edges, mask_rotate, torsion_update) + if split: + pos_new.append(pos_new_) + else: + pos_new[idx_node : idx_node + mask_rotate.shape[1]] = pos_new_ + + idx_node += mask_rotate.shape[1] + idx_edges += mask_rotate.shape[0] + if return_updates: + return pos_new, torsion_update_list + return pos_new diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/torus.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/torus.py new file mode 100644 index 0000000000..977cff937d --- /dev/null +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/torus.py @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os + +import numpy as np +import tqdm + + +""" + Preprocessing for the SO(2)/torus sampling and score computations, truncated infinite series are computed and then + cached to memory, therefore the precomputation is only run the first time the repository is run on a machine +""" + +package_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) + + +def p(x, sigma, N=10): + p_ = 0 + for i in tqdm.trange(-N, N + 1): + p_ += np.exp(-((x + 2 * np.pi * i) ** 2) / 2 / sigma**2) + return p_ + + +def grad(x, sigma, N=10): + p_ = 0 + for i in tqdm.trange(-N, N + 1): + p_ += (x + 2 * np.pi * i) / sigma**2 * np.exp(-((x + 2 * np.pi * i) ** 2) / 2 / sigma**2) + return p_ + + +X_MIN, X_N = 1e-5, 5000 # relative to pi +SIGMA_MIN, SIGMA_MAX, SIGMA_N = 3e-3, 2, 5000 # relative to pi + +x = 10 ** np.linspace(np.log10(X_MIN), 0, X_N + 1) * np.pi +sigma = 10 ** np.linspace(np.log10(SIGMA_MIN), np.log10(SIGMA_MAX), SIGMA_N + 1) * np.pi + +if os.path.exists(os.path.join(package_path, ".torus.npz")): + torus = np.load(os.path.join(package_path, ".torus.npz")) + p_ = torus["p_"] + score_ = torus["score_"] +else: + p_ = p(x, sigma[:, None], N=100) + score_ = grad(x, sigma[:, None], N=100) / p_ + + np.savez(os.path.join(package_path, ".torus.npz"), p_=p_, score_=score_) + + +def score(x, sigma): + x = (x + np.pi) % (2 * np.pi) - np.pi + sign = np.sign(x) + x = np.log(np.abs(x) / np.pi) + x = (x - np.log(X_MIN)) / (0 - np.log(X_MIN)) * X_N + x = np.round(np.clip(x, 0, X_N)).astype(int) + sigma = np.log(sigma / np.pi) + sigma = (sigma - np.log(SIGMA_MIN)) / (np.log(SIGMA_MAX) - np.log(SIGMA_MIN)) * SIGMA_N + sigma = np.round(np.clip(sigma, 0, SIGMA_N)).astype(int) + return -sign * score_[sigma, x] + + +def p(x, sigma): + x = (x + np.pi) % (2 * np.pi) - np.pi + x = np.log(np.abs(x) / np.pi) + x = (x - np.log(X_MIN)) / (0 - np.log(X_MIN)) * X_N + x = np.round(np.clip(x, 0, X_N)).astype(int) + sigma = np.log(sigma / np.pi) + sigma = (sigma - np.log(SIGMA_MIN)) / (np.log(SIGMA_MAX) - np.log(SIGMA_MIN)) * SIGMA_N + sigma = np.round(np.clip(sigma, 0, SIGMA_N)).astype(int) + return p_[sigma, x] + + +def sample(sigma, seed=None): + if seed is None: + out = sigma * np.random.randn(*sigma.shape) + else: + rng = np.random.default_rng(seed) + out = sigma * rng.normal(size=sigma.shape) + out = (out + np.pi) % (2 * np.pi) - np.pi + return out + + +class TorusScoreNorm: + _score_norm = None + + def __init__(self, seed=None): + if TorusScoreNorm._score_norm is None: + _score_norm = score( + sample(sigma[None].repeat(10000, 0).flatten(), seed=seed), sigma[None].repeat(10000, 0).flatten() + ).reshape(10000, -1) + TorusScoreNorm._score_norm = (_score_norm**2).mean(0) + + def __call__(self, sigma): + sigma = np.log(sigma / np.pi) + sigma = (sigma - np.log(SIGMA_MIN)) / (np.log(SIGMA_MAX) - np.log(SIGMA_MIN)) * SIGMA_N + sigma = np.round(np.clip(sigma, 0, SIGMA_N)).astype(int) + return TorusScoreNorm._score_norm[sigma] From 9a7f82ea4f341f53cb4f466b7d4d0204875f41e8 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 30 Jul 2024 22:08:23 +0000 Subject: [PATCH 03/70] Enhancement: explicitly pass tr/rot/tor_sigma's to t_to_sigma --- .../contrib/model/molecule/diffdock/utils/diffusion.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py index 135c5d64fb..9b080e6b33 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py @@ -19,10 +19,11 @@ from bionemo.contrib.model.molecule.diffdock.utils.torsion import modify_conformer_torsion_angles -def t_to_sigma(t_tr, t_rot, t_tor, cfg): - tr_sigma = cfg.diffusion.tr_sigma_min ** (1 - t_tr) * cfg.diffusion.tr_sigma_max**t_tr - rot_sigma = cfg.diffusion.rot_sigma_min ** (1 - t_rot) * cfg.diffusion.rot_sigma_max**t_rot - tor_sigma = cfg.diffusion.tor_sigma_min ** (1 - t_tor) * cfg.diffusion.tor_sigma_max**t_tor +def t_to_sigma(t_tr, t_rot, t_tor, tr_sigma_min, tr_sigma_max, rot_sigma_min, + rot_sigma_max, tor_sigma_min, tor_sigma_max): + tr_sigma = tr_sigma_min ** (1 - t_tr) * tr_sigma_max**t_tr + rot_sigma = rot_sigma_min ** (1 - t_rot) * rot_sigma_max**t_rot + tor_sigma = tor_sigma_min ** (1 - t_tor) * tor_sigma_max**t_tor return tr_sigma, rot_sigma, tor_sigma From 756a2e347b2dfccb2b8a8470fb12f6fb8b7f9b31 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Thu, 1 Aug 2024 03:03:39 +0000 Subject: [PATCH 04/70] Enhancement: add diffdock score model data module and tests --- .../data/molecule/diffdock/__init__.py | 14 + .../data/molecule/diffdock/datamodule.py | 326 ++++++++++++++++++ .../contrib/data/molecule/diffdock/utils.py | 241 +++++++++++++ .../tests/bionemo/contrib/data/conftest.py | 67 ++++ .../contrib/data/test_diffdock_datamodule.py | 100 ++++++ 5 files changed, 748 insertions(+) create mode 100644 sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/__init__.py create mode 100644 sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py create mode 100644 sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py create mode 100644 sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py create mode 100644 sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/__init__.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/__init__.py new file mode 100644 index 0000000000..25e6abfbc5 --- /dev/null +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py new file mode 100644 index 0000000000..2890c34831 --- /dev/null +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -0,0 +1,326 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum, auto +from functools import partial +import glob +import pickle +import random +from typing import Set, Optional, Tuple +import lightning as L +import torch +from torch_geometric.data.hetero_data import HeteroData +from torch_geometric.loader.dataloader import Collater +import webdataset as wds + +from bionemo.contrib.data.molecule.diffdock.utils import pickles_to_tars, NoiseTransform, SizeAwareBatching, estimate_size +from bionemo.contrib.model.molecule.diffdock.utils.diffusion import t_to_sigma + + +class Split(Enum): + train = auto() + val = auto() + test = auto() + + +class ScoreModelWDS(L.LightningDataModule): + + """lightning APIs to process score model data and setup dataset and + dataloader""" + + def __init__(self, dir_heterodata : str, suffix_heterodata : str, + prefix_dir_tars_wds : str, names_subset_train : Set[str], + names_subset_val : Set[str], local_batch_size : int, + global_batch_size : int, n_workers_dataloader : int, + tr_sigma_minmax : Tuple[float, float] + = (0.1, 19), rot_sigma_minmax : Tuple[float, float] = (0.03, + 1.55), + tor_sigma_minmax : Optional[Tuple[float, float]] = (0.0314, + 3.14), + is_all_atom : bool = False, apply_size_control : Tuple[bool, + bool, + bool] = + (True, False, False), pin_memory_dataloader : bool = True, + prefix_tars_wds : str = "heterographs", + n_tars_wds : Optional[int] = None, names_subset_test : + Optional[Set[str]] = None, seed_rng_shfl : int = 0): + """constructor + + Args: + dir_heterodata (str): input directory of PyG HeteroData pickled + files + suffix_heterodata (str): filename suffix of the input data in + dir_heterodata. This is also used as the key mapped to the + tarballed HeteroData object in the webdataset + prefix_dir_tars_wds (str): directory name prefix to store the output + webdataset tar files. The actual directories storing the train, val + and test sets will be suffixed with "train", "val" and "test" + respectively. + names_subset_train (Set[str]): list of complex names to be included + in the training data + names_subset_val (Set[str]): list of complex names to be included + in the validation data + local_batch_size (int): size of batch for each node + global_batch_size (int): size of batch summing across nodes in Data + Distributed Parallel, i.e., local_batch_size * n_nodes + n_workers_dataloader (int): number of data loading workers (passed + to pytorch dataloader) + seed_rng_shfl (int): seed to the random number generators used in + data loading time for shuffling + + Kwargs: + tr_sigma_minmax (Tuple[float, float]): min and max sigma for the + translational component during diffusion + rot_sigma_minmax (Tuple[float, float]): min and max sigma for the + rotational component during diffusion + tor_sigma_minmax (Optional[Tuple[float, float]]): min and max sigma + for the torsional component during diffusion + is_all_atom (bool): whether to treat the data as all-atom system + during noise transformation + apply_size_control(Tuple[bool, bool, bool]): whether to use + SizeAwareBatching for the respective train, val and test data + pin_memory_dataloader (bool): whether to use pin memory in pytorch + dataloader + prefix_tars_wds (str): name prefix to output webdataset tar files + n_tars_wds (int): attempt to create at least this number of webdataset shards + names_subset_test (Optional[Set[str]]): list of complex names to be included + in the test data + + + """ + super().__init__() + + self._dir_heterodata = dir_heterodata + self._suffix_heterodata = suffix_heterodata + self._n_tars_wds = n_tars_wds + self._prefix_dir_tars_wds = prefix_dir_tars_wds + self._prefix_tars_wds = prefix_tars_wds + self._names_subset_train = names_subset_train + self._names_subset_val = names_subset_val + self._names_subset_test = names_subset_test + + self._sizes = { + Split.train : len(self._names_subset_train), + Split.val : len(self._names_subset_val), + Split.test : len(self._names_subset_test) if + self._names_subset_test is not None else None, + } + + self._dirs_tars_wds = { + Split.train : f"{self._prefix_dir_tars_wds}train", + Split.val : f"{self._prefix_dir_tars_wds}val", + Split.test : f"{self._prefix_dir_tars_wds}test", + } + + self._tr_sigma_min, self._tr_sigma_max = tr_sigma_minmax + self._rot_sigma_min, self._rot_sigma_max = rot_sigma_minmax + self._tor_sigma_min, self._tor_sigma_max = (None, None) + self._no_torsion = True + if tor_sigma_minmax is not None: + self._tor_sigma_min, self._tor_sigma_max = tor_sigma_minmax + self._no_torsion = False + # TODO: the all-atom arg to set_time should be inferred from the + # complex_graph arg so we don't have to pass it all-the-way down + self._is_all_atom = is_all_atom + + self._local_batch_size = local_batch_size + self._global_batch_size = global_batch_size + self._use_dynamic_batch_size = { + Split.train : apply_size_control[0], + Split.val : apply_size_control[1], + Split.test : apply_size_control[2], + } + self._n_workers_dataloader = n_workers_dataloader + self._pin_memory_dataloader = pin_memory_dataloader + self._seed_rng_shfl = seed_rng_shfl + + + def _complex_graph_to_tar(self, complex_graph : HeteroData): + """map input complex graph to webdataset tar file conforming to its + format requirement + + Args: + complex_graph (HeteroData): input complex graph + + Returns: webdataset tar file segment (dict) + + """ + return { + "__key__": complex_graph.name.replace(".", "-"), + self._suffix_heterodata: pickle.dumps(complex_graph) + } + + + def prepare_data(self) -> None: + """This is called only by the main process by the Lightning workflow. Do + not rely on this data module object's state update here as there is no + way to communicate the state update to other subprocesses + + Returns: None + """ + # create wds shards (tar files) for train set + pickles_to_tars(self._dir_heterodata, + self._suffix_heterodata, + self._names_subset_train, + self._dirs_tars_wds[Split.train], + self._prefix_tars_wds, + self._complex_graph_to_tar, + min_num_shards=self._n_tars_wds) + + # create wds shards (tar files) for val set + pickles_to_tars(self._dir_heterodata, + self._suffix_heterodata, + self._names_subset_val, + self._dirs_tars_wds[Split.val], + self._prefix_tars_wds, + self._complex_graph_to_tar, + min_num_shards=self._n_tars_wds) + + if self._names_subset_test is not None: + # create wds shards (tar files) for test set + pickles_to_tars(self._dir_heterodata, + self._suffix_heterodata, + self._names_subset_test, + self._dirs_tars_wds[Split.test], + self._prefix_tars_wds, + self._complex_graph_to_tar, + min_num_shards=self._n_tars_wds) + + + + def _setup_wds(self, split : Split) -> wds.WebDataset: + """setup webdataset and webloader. This is called by setup() + + Args: + split (Split): train, val or test split + + Returns: WebDataset + + """ + random.seed(self._seed_rng_shfl) + is_train = split == Split.train + urls = glob.glob( + f"{self._dirs_tars_wds[split]}/{self._prefix_tars_wds}-*.tar") + dataset = ( + wds.WebDataset(urls, shardshuffle=is_train, + nodesplitter=wds.split_by_node) + .decode() + .extract_keys(f"*.{self._suffix_heterodata}") + ) + if is_train: + self._xform_train = ( + NoiseTransform(partial(t_to_sigma, + self._tr_sigma_min, self._tr_sigma_max, + self._rot_sigma_min, + self._rot_sigma_max, + self._tor_sigma_min, + self._tor_sigma_max), + self._no_torsion, self._is_all_atom)) + dataset = (dataset + .compose(partial(self._xform_train.apply_noise_iter, + keep_pos=(split == Split.val))) + ) + # sandwiched here to mirror the original DiffDock FW implementation + size = self._sizes[split] + # FIXME: remove this with_length since it's overriden later anyway + dataset = dataset.with_length(size) + if is_train: + dataset = dataset.shuffle(size=5000, + rng=random.Random(self._seed_rng_shfl)) + n_batches = ((size + self._global_batch_size - 1) + // self._global_batch_size) + if not self._use_dynamic_batch_size[split]: + dataset = ( + dataset.batched(self._local_batch_size, + collation_fn=Collater(dataset=[], + follow_batch=None, + exclude_keys=None)) + .with_epoch(n_batches) + .with_length(n_batches) + ) + else: + f_batching = SizeAwareBatching( + max_total_size=0.85 * torch.cuda.get_device_properties("cuda:0").total_memory / 2**20, + size_fn=estimate_size, + ) + dataset = dataset.compose(f_batching).with_length(n_batches) + if is_train: + dataset = dataset.select(lambda x: len(x) > 1) + + return dataset + + def setup(self, stage: str) -> None: + """This is called on all Lightning-managed nodes in a multi-node + training session + + + Args: + stage (str): "fit", "test" or "predict" + Returns: None + """ + if stage == "fit": + self._dataset_train = self._setup_wds(Split.train) + self._dataset_val = self._setup_wds(Split.val) + elif stage == "test": + self._dataset_test = self._setup_wds(Split.test) + else: + raise NotImplementedError("Data setup with stage = {stage}\ + is not implmented") + + def _setup_dataloader(self, dataset : wds.WebDataset) -> wds.WebLoader: + """wrap the input dataset into a WebLoader + + Args: + dataset (wds.WebDataset): input dataset object + + Returns: WebLoader object + + """ + if not hasattr(dataset, "__len__"): + raise RuntimeError("Input dataset object doesn't have length") + n_batches = len(dataset) + loader = wds.WebLoader(dataset, + num_workers=self._n_workers_dataloader, + pin_memory=self._pin_memory_dataloader, + collate_fn=lambda x: x[0], + ).with_length(n_batches).with_epoch(n_batches) + + # strange features required by nemo optimizer lr_scheduler + loader.dataset = dataset # seems like only length is used, webloader doesn't have this attr + loader.batch_size = self._local_batch_size + loader.drop_last = False + return loader + + + def train_dataloader(self) -> wds.WebLoader: + assert self._dataset_train is not None,\ + f"dataset for train has not been setup" + return self._setup_dataloader(self._dataset_train) + + + def val_dataloader(self) -> wds.WebLoader: + assert self._dataset_val is not None,\ + f"dataset for val has not been setup" + return self._setup_dataloader(self._dataset_val) + + + def test_dataloader(self) -> wds.WebLoader: + assert self._dataset_test is not None,\ + f"dataset for test has not been setup" + return self._setup_dataloader(self._dataset_test) + + + def predict_dataloader(self) -> wds.WebLoader: + raise NotImplementedError("predict dataloader not implemented") diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py new file mode 100644 index 0000000000..c0857ee17c --- /dev/null +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pickle +import random +from typing import Any, Callable, Generator, List, Optional +from copy import deepcopy + +from nemo.utils import logging +import torch +from torch_geometric.data import HeteroData +from torch_geometric.data.batch import Batch +from torch_geometric.loader.dataloader import Collater +from torch_geometric.transforms import BaseTransform +import numpy as np + +import webdataset as wds + +from bionemo.contrib.model.molecule.diffdock.utils.diffusion import modify_conformer, set_time +from bionemo.contrib.model.molecule.diffdock.utils import so3, torus + + +def pickles_to_tars( + dir_input: str, + input_suffix: str, + input_prefix_subset: set, + dir_output: str, + output_prefix: str, + func_output_data: Callable = lambda data: {"data": pickle.dumps(data)}, + min_num_shards: Optional[int] = None, +) -> None: + """Convert a subset of pickle files from a directory to Webdataset tar files + Input path and name pattern: + f"{dir_input}/{input_prefix_subset}.{input_suffix}" + Output path and name pattern: + f"{dir_output}/{output_prefix}-%06d.tar" + + Args: + dir_input (str): Input directory + input_suffix (str): Input pickle file name suffix + input_prefix_subset (set): Input subset of pickle files' prefix + dir_output (str): Output directory + output_prefix (str): Output tar file name prefix + func_output_data (Callable) : function that maps data to a dictionary + to be output in the tar files + min_num_shards (int) : create at least this number of tar files. + WebDataset has bugs when reading small number of tar files in a + multi-node lightening + DDP setting so this option can be used to + guarantee the tar file counts + + Returns: None + + """ + os.makedirs(dir_output, exist_ok=True) + wd_subset_pattern = os.path.join(dir_output, f"{output_prefix}-%06d.tar") + maxsize = 1e8 + # Due to a Webdataset bug, number of shards should be >= number of workers + # (num. of gpus * num. of workers per gpu) + # TODO: this algorithm is not accurate enough because it doesn't take into + # account the block structure so I have to multiply the total_size with a + # small prefactor to purposely underestimate the size so that it ends up + # creating more tar files than min_num_shards + if min_num_shards is not None and min_num_shards > 1: + total_size = 0 + for name in input_prefix_subset: + try: + total_size += os.stat(os.path.join(dir_input, f"{name}.{input_suffix}")).st_size + except Exception: + continue + maxsize = min(total_size * 0.7 // min_num_shards, maxsize) + with wds.ShardWriter(wd_subset_pattern, encoder=False, maxsize=maxsize, compress=False, mode=0o777) as sink: + for name in input_prefix_subset: + try: + data = pickle.load(open(os.path.join(dir_input, f"{name}.{input_suffix}"), "rb")) + sample = func_output_data(data) + except ModuleNotFoundError as e: + logging.error(f"Dependency for parsing input pickle data not "\ + f"found: {e}") + raise e + except Exception as e: + logging.error(f"Failed to write {name} into tar files due to error {e}") + continue + + sink.write(sample) + + +def num_cross_edge_upper_bound_estimate(n1, n2, n3, n4): + terms = [[4.92, 'ligand_ligand'], + [0.0118, 'receptor_receptor'], + [0.0401, 'ligand', 'receptor_receptor']] + scale = 1.03 + tmpdict = {"ligand": n1, "ligand_ligand": n2, "receptor": n3, "receptor_receptor": n4} + num_edges = 0.0 + for term in terms: + tmp = term[0] + for k in term[1:]: + tmp *= tmpdict[k] + num_edges += tmp + num_edges *= scale + return num_edges + + +def estimate_memory_usage(data, num_cross_edges, use_bias=True): + # bias is from the memory of model, so when estimate the upper bound for size aware batch sampler, we don't need this + coeff_ligand_num_nodes = 2.9 + coeff_ligand_num_edges = 0.0 + coeff_receptor_num_nodes = 0.0 + coeff_receptor_num_edges = 0.11 + coeff_num_cross_edges = 0.25 + total_memory = ( + coeff_ligand_num_nodes * data["ligand"].num_nodes + + coeff_ligand_num_edges * data["ligand", "ligand"].num_edges + + coeff_receptor_num_nodes * data["receptor"].num_nodes + + coeff_receptor_num_edges * data["receptor", "receptor"].num_edges + + coeff_num_cross_edges * num_cross_edges + ) + if use_bias: + bias = 430.5 + return total_memory + bias + else: + return total_memory + +def estimate_size(g): + n1, n2, n3, n4 = ( + g["ligand"].num_nodes, + g["ligand", "ligand"].num_edges, + g["receptor"].num_nodes, + g["receptor", "receptor"].num_edges, + ) + # estimate the upper bound of the number of cross edges + # the number of cross edges roughly increases w.r.t. the diffusion step t (sampled from uniform(0,1)) + # the empirical formula here is from the polynomial fitting + # the scaling constant is to help remove the outliers above the upper bound estimation. + n5 = num_cross_edge_upper_bound_estimate(n1, n2, n3, n4) + total_memory = estimate_memory_usage(g, n5, + use_bias=False) + return total_memory + + +class SizeAwareBatching: + """A WebDataset composable to do batching based on sample size""" + + def __init__( + self, + max_total_size: int, + size_fn: Callable[[HeteroData], int], + collate_fn: Callable[[List[Any]], Any] = Collater(dataset=None, follow_batch=None, exclude_keys=None), + ): + self.max_total_size = max_total_size + self.size_fn = size_fn + self.collate_fn = collate_fn + self.cached_sizes = {} + + def __call__(self, data: Batch) -> Generator[Batch, None, None]: + batch_size = 0 + batch = [] + + for sample in data: + if sample.name not in self.cached_sizes: + self.cached_sizes[sample.name] = self.size_fn(sample) + sample_size = self.cached_sizes[sample.name] + if sample_size > self.max_total_size: + logging.warning(f"sample {sample.name} has size larger than max size {self.max_total_size}, skipping") + continue + if (batch_size + sample_size) <= self.max_total_size: + batch.append(sample) + batch_size += sample_size + else: + if self.collate_fn is not None: + batch = self.collate_fn(batch) + yield batch + + batch = [sample] + batch_size = sample_size + + +class NoiseTransform(BaseTransform): + """Apply forward diffusion on the ligand + + Args: + t_to_sigma (Callable): Callable to embed time + no_torsion (bool): if not to perturb ligand torsion degrees + all_atom (bool): # all atom or coarse grained/residue for protein + """ + + def __init__(self, t_to_sigma: Callable, no_torsion: bool, all_atom: bool): + self.t_to_sigma = t_to_sigma + self.no_torsion = no_torsion + self.all_atom = all_atom + + def __call__(self, data): + t = np.random.uniform() + t_tr, t_rot, t_tor = t, t, t + return self.apply_noise(data, t_tr, t_rot, t_tor) + + def apply_noise(self, data, t_tr, t_rot, t_tor, tr_update=None, rot_update=None, torsion_updates=None): + if not torch.is_tensor(data["ligand"].pos): + data["ligand"].pos = random.choice(data["ligand"].pos) + + tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(t_tr, t_rot, t_tor) + set_time(data, t_tr, t_rot, t_tor, 1, self.all_atom, device=None) + + tr_update = torch.normal(mean=0, std=tr_sigma, size=(1, 3)) if tr_update is None else tr_update + rot_update = so3.sample_vec(eps=rot_sigma) if rot_update is None else rot_update + torsion_updates = ( + np.random.normal(loc=0.0, scale=tor_sigma, size=data["ligand"].edge_mask.sum()) + if torsion_updates is None + else torsion_updates + ) + torsion_updates = None if self.no_torsion else torsion_updates + modify_conformer( + data, + tr_update, + torch.from_numpy(rot_update).float(), + None if data["ligand"].edge_mask.sum() == 0 else torsion_updates, + ) + + data.tr_score = -tr_update / tr_sigma**2 + data.rot_score = torch.from_numpy(so3.score_vec(vec=rot_update, eps=rot_sigma)).float().unsqueeze(0) + data.tor_score = None if self.no_torsion else torch.from_numpy(torus.score(torsion_updates, tor_sigma)).float() + data.tor_sigma_edge = None if self.no_torsion else np.ones(data["ligand"].edge_mask.sum()) * tor_sigma + return data + + def apply_noise_iter(self, source, keep_pos=False): + for (data,) in source: + if keep_pos: + data["ligand"].aligned_pos = deepcopy(data["ligand"].pos) + yield self.__call__(data) diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py new file mode 100644 index 0000000000..eb2716af63 --- /dev/null +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pytest + +from bionemo.contrib.data.molecule.diffdock.datamodule import ScoreModelWDS + + +@pytest.fixture(scope="module") +def get_path(request): + dir_test = os.path.dirname(request.module.__file__) + dir_data = f"{dir_test}/test_data" + return dir_test, dir_data + + +@pytest.fixture(scope="module") +def get_diffdock_score_model_heterodata(get_path, tmp_path_factory): + _, dir_data = get_path + dir_heterodata = f"{dir_data}/molecule/diffdock/heterodata" + suffix_heterodata = "heterodata.pyd" + prefix_dir_tars_wds = tmp_path_factory.mktemp( + "diffdock_score_model_tars_wds").as_posix() + names_subset_train = set(["6t88", "6vs3", "6wtn", "6yqv", "7amc", "7bmi", + "7cuo", "7d5c", "7din", "7fha", "7jnb", "7k0v", + "7kb1", "7km8", "7l7c", "7lcu", "7msr", "7my1", + "7n6f", "7np6"]) + names_subset_val = set(["7nr6", "7oeo", "7oli", "7oso", "7p5t", "7q5i", + "7qhl", "7rh3", "7rzl", "7sgv"]) + names_subset_test = set(["7sne", "7t2i", "7tbu", "7tsf", "7umv", "7up3", + "7uq3", "7wpw", "7xek", "7xij"]) + return (dir_heterodata, suffix_heterodata, prefix_dir_tars_wds, + names_subset_train, names_subset_val, names_subset_test) + + +@pytest.fixture(scope="module") +def create_ScoreModelWDS(get_diffdock_score_model_heterodata): + (dir_heterodata, suffix_heterodata, prefix_dir_tars_wds, + names_subset_train, names_subset_val, names_subset_test) =\ + get_diffdock_score_model_heterodata + local_batch_size = 2 + global_batch_size = 2 + n_workers_dataloader = 2 + n_tars_wds = 4 + seed_rng_shfl = 822782392 + data_module = ScoreModelWDS(dir_heterodata, suffix_heterodata, + prefix_dir_tars_wds, names_subset_train, + names_subset_val, local_batch_size, + global_batch_size, n_workers_dataloader, + n_tars_wds=n_tars_wds, + names_subset_test=names_subset_test, + seed_rng_shfl=seed_rng_shfl) + return data_module + +create_another_ScoreModelWDS = create_ScoreModelWDS diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py new file mode 100644 index 0000000000..11d48d3852 --- /dev/null +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import multiprocessing as mp +import sys + +import lightning + +from bionemo.contrib.data.molecule.diffdock.datamodule import Split + + +def test_ScoreModelWDS_init(get_diffdock_score_model_heterodata, + create_ScoreModelWDS): + (_, _, prefix_dir_tars_wds, names_subset_train, names_subset_val, + names_subset_test) = get_diffdock_score_model_heterodata + data_module = create_ScoreModelWDS + assert data_module._sizes[Split.train] == len(names_subset_train),\ + f"Wrong train set size: expected {len(names_subset_train)}"\ + f"but got {data_module._sizes[Split.train]}" + assert data_module._sizes[Split.val] == len(names_subset_val),\ + f"Wrong val set size: expected {len(names_subset_val)}"\ + f"but got {data_module._sizes[Split.val]}" + assert data_module._sizes[Split.test] == len(names_subset_test),\ + f"Wrong test set size: expected {len(names_subset_test)} "\ + f"but got {data_module._sizes[Split.test]}" + assert data_module._dirs_tars_wds[Split.train] ==\ + f"{prefix_dir_tars_wds}train",\ + f"Wrong tar files directory: expected {prefix_dir_tars_wds}train "\ + f"but got {data_module._dirs_tars_wds[Split.train]}" + assert data_module._dirs_tars_wds[Split.val] ==\ + f"{prefix_dir_tars_wds}val",\ + f"Wrong tar files directory: expected {prefix_dir_tars_wds}val "\ + f"but got {data_module._dirs_tars_wds[Split.val]}" + assert data_module._dirs_tars_wds[Split.test] ==\ + f"{prefix_dir_tars_wds}test",\ + f"Wrong tar files directory: expected {prefix_dir_tars_wds}test "\ + f"but got {data_module._dirs_tars_wds[Split.test]}" + + +def test_ScoreModelWDS_prepare_data(get_diffdock_score_model_heterodata, + create_ScoreModelWDS): + (_, _, prefix_dir_tars_wds, _, _, _) =\ + get_diffdock_score_model_heterodata + data_module = create_ScoreModelWDS + # LightningDataModule.prepare_data() is supposed to be called from the main + # process in a Lightning-managed multi-process context so we can call it in + # a single process + data_module.prepare_data() + files_tars_train = glob.glob( + f"{data_module._dirs_tars_wds[Split.train]}/"\ + f"{data_module._prefix_tars_wds}-*.tar") + assert len(files_tars_train) >= data_module._n_tars_wds,\ + f"Wrong num of train tar files: expected {data_module._n_tars_wds}"\ + f"got {len(files_tars_train)}" + files_tars_val = glob.glob( + f"{data_module._dirs_tars_wds[Split.val]}/"\ + f"{data_module._prefix_tars_wds}-*.tar") + assert len(files_tars_val) >= data_module._n_tars_wds,\ + f"Wrong num of val tar files: expected {data_module._n_tars_wds}"\ + f"got {len(files_tars_val)}" + files_tars_test = glob.glob( + f"{data_module._dirs_tars_wds[Split.test]}/"\ + f"{data_module._prefix_tars_wds}-*.tar") + assert len(files_tars_test) >= data_module._n_tars_wds,\ + f"Wrong num of test tar files: expected {data_module._n_tars_wds}"\ + f"got {len(files_tars_test)}" + + + +def test_ScoreModelWDS_setup(create_ScoreModelWDS, create_another_ScoreModelWDS): + data_modules= [create_ScoreModelWDS, create_another_ScoreModelWDS] + lists_complex_name = [] + stage = "fit" + for m in data_modules: + m.prepare_data() + m.setup(stage) + lightning.seed_everything(2823828) + names = [] + for sample in m._dataset_train: + names.append(sample.name) + lists_complex_name.append(names) + + assert len(lists_complex_name[0]) > 0, "Empty dataset" + # assert lists_complex_name[0] == lists_complex_name[1],\ + # f"Inconsistent data samples from data module instances: "\ + # f"{lists_complex_name} \n\nvs.\n\n"\ + # f"{lists_complex_name}" From 0b8411a71ffbb915f991b3e3bd2ff84171644f43 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Fri, 2 Aug 2024 22:33:25 +0000 Subject: [PATCH 05/70] BugFix: control wds shardshuffle randomness This requires upgrading webdataset to the github main branch --- .../bionemo/contrib/data/molecule/diffdock/datamodule.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index 2890c34831..333217b475 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -211,11 +211,13 @@ def _setup_wds(self, split : Split) -> wds.WebDataset: """ random.seed(self._seed_rng_shfl) is_train = split == Split.train - urls = glob.glob( + urls = sorted(glob.glob( f"{self._dirs_tars_wds[split]}/{self._prefix_tars_wds}-*.tar") + ) dataset = ( wds.WebDataset(urls, shardshuffle=is_train, - nodesplitter=wds.split_by_node) + nodesplitter=wds.split_by_node, + seed=self._seed_rng_shfl) .decode() .extract_keys(f"*.{self._suffix_heterodata}") ) From 76ec1c71e7bd79c18bfac82ecc4e432e35fa2bc1 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Fri, 2 Aug 2024 22:34:37 +0000 Subject: [PATCH 06/70] Test: refactor ScoreModelWDS fixture to allow comparing different instances of ScoreModelWDS --- .../tests/bionemo/contrib/data/conftest.py | 30 ++++++++++++------- .../contrib/data/test_diffdock_datamodule.py | 21 ++++++------- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py index eb2716af63..6254416062 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py @@ -27,12 +27,10 @@ def get_path(request): @pytest.fixture(scope="module") -def get_diffdock_score_model_heterodata(get_path, tmp_path_factory): +def get_diffdock_score_model_heterodata(get_path): _, dir_data = get_path dir_heterodata = f"{dir_data}/molecule/diffdock/heterodata" suffix_heterodata = "heterodata.pyd" - prefix_dir_tars_wds = tmp_path_factory.mktemp( - "diffdock_score_model_tars_wds").as_posix() names_subset_train = set(["6t88", "6vs3", "6wtn", "6yqv", "7amc", "7bmi", "7cuo", "7d5c", "7din", "7fha", "7jnb", "7k0v", "7kb1", "7km8", "7l7c", "7lcu", "7msr", "7my1", @@ -41,15 +39,17 @@ def get_diffdock_score_model_heterodata(get_path, tmp_path_factory): "7qhl", "7rh3", "7rzl", "7sgv"]) names_subset_test = set(["7sne", "7t2i", "7tbu", "7tsf", "7umv", "7up3", "7uq3", "7wpw", "7xek", "7xij"]) - return (dir_heterodata, suffix_heterodata, prefix_dir_tars_wds, + return (dir_heterodata, suffix_heterodata, names_subset_train, names_subset_val, names_subset_test) -@pytest.fixture(scope="module") -def create_ScoreModelWDS(get_diffdock_score_model_heterodata): - (dir_heterodata, suffix_heterodata, prefix_dir_tars_wds, - names_subset_train, names_subset_val, names_subset_test) =\ +def _create_ScoreModelWDS_impl(tmp_path_factory, + get_diffdock_score_model_heterodata): + (dir_heterodata, suffix_heterodata, + names_subset_train, names_subset_val, names_subset_test) =\ get_diffdock_score_model_heterodata + prefix_dir_tars_wds = tmp_path_factory.mktemp( + "diffdock_score_model_tars_wds").as_posix() local_batch_size = 2 global_batch_size = 2 n_workers_dataloader = 2 @@ -62,6 +62,16 @@ def create_ScoreModelWDS(get_diffdock_score_model_heterodata): n_tars_wds=n_tars_wds, names_subset_test=names_subset_test, seed_rng_shfl=seed_rng_shfl) - return data_module + return data_module, prefix_dir_tars_wds -create_another_ScoreModelWDS = create_ScoreModelWDS + +@pytest.fixture(scope="module") +def create_ScoreModelWDS(tmp_path_factory, get_diffdock_score_model_heterodata): + return _create_ScoreModelWDS_impl(tmp_path_factory, + get_diffdock_score_model_heterodata) + + +@pytest.fixture(scope="module") +def create_another_ScoreModelWDS(tmp_path_factory, get_diffdock_score_model_heterodata): + return _create_ScoreModelWDS_impl(tmp_path_factory, + get_diffdock_score_model_heterodata) diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py index 11d48d3852..9386d111af 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py @@ -24,9 +24,9 @@ def test_ScoreModelWDS_init(get_diffdock_score_model_heterodata, create_ScoreModelWDS): - (_, _, prefix_dir_tars_wds, names_subset_train, names_subset_val, + (_, _, names_subset_train, names_subset_val, names_subset_test) = get_diffdock_score_model_heterodata - data_module = create_ScoreModelWDS + data_module, prefix_dir_tars_wds = create_ScoreModelWDS assert data_module._sizes[Split.train] == len(names_subset_train),\ f"Wrong train set size: expected {len(names_subset_train)}"\ f"but got {data_module._sizes[Split.train]}" @@ -50,11 +50,8 @@ def test_ScoreModelWDS_init(get_diffdock_score_model_heterodata, f"but got {data_module._dirs_tars_wds[Split.test]}" -def test_ScoreModelWDS_prepare_data(get_diffdock_score_model_heterodata, - create_ScoreModelWDS): - (_, _, prefix_dir_tars_wds, _, _, _) =\ - get_diffdock_score_model_heterodata - data_module = create_ScoreModelWDS +def test_ScoreModelWDS_prepare_data(create_ScoreModelWDS): + data_module, _ = create_ScoreModelWDS # LightningDataModule.prepare_data() is supposed to be called from the main # process in a Lightning-managed multi-process context so we can call it in # a single process @@ -81,7 +78,7 @@ def test_ScoreModelWDS_prepare_data(get_diffdock_score_model_heterodata, def test_ScoreModelWDS_setup(create_ScoreModelWDS, create_another_ScoreModelWDS): - data_modules= [create_ScoreModelWDS, create_another_ScoreModelWDS] + data_modules= [create_ScoreModelWDS[0], create_another_ScoreModelWDS[0]] lists_complex_name = [] stage = "fit" for m in data_modules: @@ -94,7 +91,7 @@ def test_ScoreModelWDS_setup(create_ScoreModelWDS, create_another_ScoreModelWDS) lists_complex_name.append(names) assert len(lists_complex_name[0]) > 0, "Empty dataset" - # assert lists_complex_name[0] == lists_complex_name[1],\ - # f"Inconsistent data samples from data module instances: "\ - # f"{lists_complex_name} \n\nvs.\n\n"\ - # f"{lists_complex_name}" + assert lists_complex_name[0] == lists_complex_name[1],\ + f"Inconsistent data samples from data module instances: "\ + f"{lists_complex_name} \n\nvs.\n\n"\ + f"{lists_complex_name}" From ca0cc65cc6ac758a4dc32a84d9f213a8c476c142 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Fri, 2 Aug 2024 22:50:39 +0000 Subject: [PATCH 07/70] Test: consistent ligand position after NoiseTransform --- .../contrib/data/test_diffdock_datamodule.py | 29 +++++++++++++++---- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py index 9386d111af..0b6669d0d4 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py @@ -16,6 +16,7 @@ import glob import multiprocessing as mp import sys +import torch import lightning @@ -77,21 +78,39 @@ def test_ScoreModelWDS_prepare_data(create_ScoreModelWDS): -def test_ScoreModelWDS_setup(create_ScoreModelWDS, create_another_ScoreModelWDS): +def test_ScoreModelWDS_setup_dataset(create_ScoreModelWDS, create_another_ScoreModelWDS): data_modules= [create_ScoreModelWDS[0], create_another_ScoreModelWDS[0]] lists_complex_name = [] + lists_pos_ligand = [] stage = "fit" for m in data_modules: m.prepare_data() m.setup(stage) lightning.seed_everything(2823828) names = [] + pos_ligand = [] for sample in m._dataset_train: names.append(sample.name) + pos_ligand.append(sample["ligand"].pos) lists_complex_name.append(names) + lists_pos_ligand.append(pos_ligand) - assert len(lists_complex_name[0]) > 0, "Empty dataset" + assert len(lists_complex_name[0]) > 0, "No names in dataset" assert lists_complex_name[0] == lists_complex_name[1],\ - f"Inconsistent data samples from data module instances: "\ - f"{lists_complex_name} \n\nvs.\n\n"\ - f"{lists_complex_name}" + f"Inconsistent sample name from data module instances: "\ + f"{lists_complex_name[0]} \n\nvs.\n\n"\ + f"{lists_complex_name[1]}" + + assert len(lists_pos_ligand[0]) > 0, "No ligand position found in dataset" + assert len(lists_pos_ligand[0]) == len(lists_pos_ligand[1]),\ + "Inconsistent number of ligand position from data module instances: "\ + f"{len(lists_pos_ligand[0])} \n\nvs.\n\n"\ + f"{len(lists_pos_ligand[1])}" + for i in range(len(lists_pos_ligand[0])): + pos_0 = lists_pos_ligand[0][i] + pos_1 = lists_pos_ligand[1][i] + torch.testing.assert_close(pos_0, pos_1, + msg=lambda m : + f"Inconsistent ligand position in the " + f"{i}'th sample/batch between two data " + f"module instances:\n\n{m}") From 7d95a67031db0702d56498bcdaa27d02904fcfef Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Fri, 2 Aug 2024 23:14:56 +0000 Subject: [PATCH 08/70] BugFix: reduce total_size factor to ensure enough number of shards --- .../src/bionemo/contrib/data/molecule/diffdock/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py index c0857ee17c..9f0e933b6c 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py @@ -80,7 +80,7 @@ def pickles_to_tars( total_size += os.stat(os.path.join(dir_input, f"{name}.{input_suffix}")).st_size except Exception: continue - maxsize = min(total_size * 0.7 // min_num_shards, maxsize) + maxsize = min(total_size * 0.6 // min_num_shards, maxsize) with wds.ShardWriter(wd_subset_pattern, encoder=False, maxsize=maxsize, compress=False, mode=0o777) as sink: for name in input_prefix_subset: try: From b82ee03391ff1238b8613e95d7a726204d5ca707 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Fri, 2 Aug 2024 23:28:37 +0000 Subject: [PATCH 09/70] Test: parametrize by split type --- .../data/molecule/diffdock/datamodule.py | 21 ++-- .../tests/bionemo/contrib/data/conftest.py | 32 +++---- .../contrib/data/test_diffdock_datamodule.py | 95 ++++++++----------- 3 files changed, 67 insertions(+), 81 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index 333217b475..1fac6625be 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -146,6 +146,9 @@ def __init__(self, dir_heterodata : str, suffix_heterodata : str, self._pin_memory_dataloader = pin_memory_dataloader self._seed_rng_shfl = seed_rng_shfl + # to be created later in setup + self._dataset = dict() + def _complex_graph_to_tar(self, complex_graph : HeteroData): """map input complex graph to webdataset tar file conforming to its @@ -273,10 +276,10 @@ def setup(self, stage: str) -> None: Returns: None """ if stage == "fit": - self._dataset_train = self._setup_wds(Split.train) - self._dataset_val = self._setup_wds(Split.val) + self._dataset[Split.train] = self._setup_wds(Split.train) + self._dataset[Split.val] = self._setup_wds(Split.val) elif stage == "test": - self._dataset_test = self._setup_wds(Split.test) + self._dataset[Split.test] = self._setup_wds(Split.test) else: raise NotImplementedError("Data setup with stage = {stage}\ is not implmented") @@ -307,21 +310,21 @@ def _setup_dataloader(self, dataset : wds.WebDataset) -> wds.WebLoader: def train_dataloader(self) -> wds.WebLoader: - assert self._dataset_train is not None,\ + assert self._dataset[Split.train] is not None,\ f"dataset for train has not been setup" - return self._setup_dataloader(self._dataset_train) + return self._setup_dataloader(self._dataset[Split.train]) def val_dataloader(self) -> wds.WebLoader: - assert self._dataset_val is not None,\ + assert self._dataset[Split.val] is not None,\ f"dataset for val has not been setup" - return self._setup_dataloader(self._dataset_val) + return self._setup_dataloader(self._dataset[Split.val]) def test_dataloader(self) -> wds.WebLoader: - assert self._dataset_test is not None,\ + assert self._dataset[Split.test] is not None,\ f"dataset for test has not been setup" - return self._setup_dataloader(self._dataset_test) + return self._setup_dataloader(self._dataset[Split.test]) def predict_dataloader(self) -> wds.WebLoader: diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py index 6254416062..fc7a4daf43 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py @@ -16,7 +16,7 @@ import os import pytest -from bionemo.contrib.data.molecule.diffdock.datamodule import ScoreModelWDS +from bionemo.contrib.data.molecule.diffdock.datamodule import Split, ScoreModelWDS @pytest.fixture(scope="module") @@ -31,22 +31,22 @@ def get_diffdock_score_model_heterodata(get_path): _, dir_data = get_path dir_heterodata = f"{dir_data}/molecule/diffdock/heterodata" suffix_heterodata = "heterodata.pyd" - names_subset_train = set(["6t88", "6vs3", "6wtn", "6yqv", "7amc", "7bmi", - "7cuo", "7d5c", "7din", "7fha", "7jnb", "7k0v", - "7kb1", "7km8", "7l7c", "7lcu", "7msr", "7my1", - "7n6f", "7np6"]) - names_subset_val = set(["7nr6", "7oeo", "7oli", "7oso", "7p5t", "7q5i", - "7qhl", "7rh3", "7rzl", "7sgv"]) - names_subset_test = set(["7sne", "7t2i", "7tbu", "7tsf", "7umv", "7up3", - "7uq3", "7wpw", "7xek", "7xij"]) - return (dir_heterodata, suffix_heterodata, - names_subset_train, names_subset_val, names_subset_test) + names = { + Split.train : set(["6t88", "6vs3", "6wtn", "6yqv", "7amc", "7bmi", + "7cuo", "7d5c", "7din", "7fha", "7jnb", "7k0v", + "7kb1", "7km8", "7l7c", "7lcu", "7msr", "7my1", + "7n6f", "7np6"]), + Split.val : set(["7nr6", "7oeo", "7oli", "7oso", "7p5t", "7q5i", "7qhl", + "7rh3", "7rzl", "7sgv"]), + Split.test : set(["7sne", "7t2i", "7tbu", "7tsf", "7umv", "7up3", + "7uq3", "7wpw", "7xek", "7xij"]) + } + return (dir_heterodata, suffix_heterodata, names) def _create_ScoreModelWDS_impl(tmp_path_factory, get_diffdock_score_model_heterodata): - (dir_heterodata, suffix_heterodata, - names_subset_train, names_subset_val, names_subset_test) =\ + (dir_heterodata, suffix_heterodata, names) =\ get_diffdock_score_model_heterodata prefix_dir_tars_wds = tmp_path_factory.mktemp( "diffdock_score_model_tars_wds").as_posix() @@ -56,11 +56,11 @@ def _create_ScoreModelWDS_impl(tmp_path_factory, n_tars_wds = 4 seed_rng_shfl = 822782392 data_module = ScoreModelWDS(dir_heterodata, suffix_heterodata, - prefix_dir_tars_wds, names_subset_train, - names_subset_val, local_batch_size, + prefix_dir_tars_wds, names[Split.train], + names[Split.val], local_batch_size, global_batch_size, n_workers_dataloader, n_tars_wds=n_tars_wds, - names_subset_test=names_subset_test, + names_subset_test=names[Split.test], seed_rng_shfl=seed_rng_shfl) return data_module, prefix_dir_tars_wds diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py index 0b6669d0d4..63e88b08ab 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py @@ -14,8 +14,7 @@ # limitations under the License. import glob -import multiprocessing as mp -import sys +import pytest import torch import lightning @@ -23,88 +22,72 @@ from bionemo.contrib.data.molecule.diffdock.datamodule import Split -def test_ScoreModelWDS_init(get_diffdock_score_model_heterodata, +@pytest.mark.parametrize("split", [s for s in Split]) +def test_ScoreModelWDS_init(split, get_diffdock_score_model_heterodata, create_ScoreModelWDS): - (_, _, names_subset_train, names_subset_val, - names_subset_test) = get_diffdock_score_model_heterodata + name_split = str(split).split('.')[1] + (_, _, names) = get_diffdock_score_model_heterodata data_module, prefix_dir_tars_wds = create_ScoreModelWDS - assert data_module._sizes[Split.train] == len(names_subset_train),\ - f"Wrong train set size: expected {len(names_subset_train)}"\ - f"but got {data_module._sizes[Split.train]}" - assert data_module._sizes[Split.val] == len(names_subset_val),\ - f"Wrong val set size: expected {len(names_subset_val)}"\ - f"but got {data_module._sizes[Split.val]}" - assert data_module._sizes[Split.test] == len(names_subset_test),\ - f"Wrong test set size: expected {len(names_subset_test)} "\ - f"but got {data_module._sizes[Split.test]}" - assert data_module._dirs_tars_wds[Split.train] ==\ - f"{prefix_dir_tars_wds}train",\ - f"Wrong tar files directory: expected {prefix_dir_tars_wds}train "\ - f"but got {data_module._dirs_tars_wds[Split.train]}" - assert data_module._dirs_tars_wds[Split.val] ==\ - f"{prefix_dir_tars_wds}val",\ - f"Wrong tar files directory: expected {prefix_dir_tars_wds}val "\ - f"but got {data_module._dirs_tars_wds[Split.val]}" - assert data_module._dirs_tars_wds[Split.test] ==\ - f"{prefix_dir_tars_wds}test",\ - f"Wrong tar files directory: expected {prefix_dir_tars_wds}test "\ - f"but got {data_module._dirs_tars_wds[Split.test]}" + assert data_module._sizes[split] == len(names[split]),\ + f"Wrong {split}-set size: expected {len(names[split])}"\ + f"but got {data_module._sizes[split]}" + assert data_module._dirs_tars_wds[split] ==\ + f"{prefix_dir_tars_wds}{name_split}",\ + f"Wrong tar files directory: expected {prefix_dir_tars_wds}{split} "\ + f"but got {data_module._dirs_tars_wds[split]}" -def test_ScoreModelWDS_prepare_data(create_ScoreModelWDS): +@pytest.mark.parametrize("split", [s for s in Split]) +def test_ScoreModelWDS_prepare_data(split, create_ScoreModelWDS): data_module, _ = create_ScoreModelWDS # LightningDataModule.prepare_data() is supposed to be called from the main # process in a Lightning-managed multi-process context so we can call it in # a single process data_module.prepare_data() - files_tars_train = glob.glob( - f"{data_module._dirs_tars_wds[Split.train]}/"\ - f"{data_module._prefix_tars_wds}-*.tar") - assert len(files_tars_train) >= data_module._n_tars_wds,\ - f"Wrong num of train tar files: expected {data_module._n_tars_wds}"\ - f"got {len(files_tars_train)}" - files_tars_val = glob.glob( - f"{data_module._dirs_tars_wds[Split.val]}/"\ - f"{data_module._prefix_tars_wds}-*.tar") - assert len(files_tars_val) >= data_module._n_tars_wds,\ - f"Wrong num of val tar files: expected {data_module._n_tars_wds}"\ - f"got {len(files_tars_val)}" - files_tars_test = glob.glob( - f"{data_module._dirs_tars_wds[Split.test]}/"\ - f"{data_module._prefix_tars_wds}-*.tar") - assert len(files_tars_test) >= data_module._n_tars_wds,\ - f"Wrong num of test tar files: expected {data_module._n_tars_wds}"\ - f"got {len(files_tars_test)}" + files_tars = sorted(glob.glob( + f"{data_module._dirs_tars_wds[split]}/"\ + f"{data_module._prefix_tars_wds}-*.tar")) + assert len(files_tars) >= data_module._n_tars_wds,\ + f"Wrong num of {split}-set tar files: "\ + f"expected {data_module._n_tars_wds} "\ + f"got {len(files_tars)}" - -def test_ScoreModelWDS_setup_dataset(create_ScoreModelWDS, create_another_ScoreModelWDS): +@pytest.mark.parametrize("split", [s for s in Split]) +def test_ScoreModelWDS_setup_dataset(split, create_ScoreModelWDS, create_another_ScoreModelWDS): data_modules= [create_ScoreModelWDS[0], create_another_ScoreModelWDS[0]] lists_complex_name = [] lists_pos_ligand = [] - stage = "fit" for m in data_modules: m.prepare_data() - m.setup(stage) + # run through all the possible stages first to setup all the correps. + # dataset objects + m.setup("fit") + m.setup("test") lightning.seed_everything(2823828) names = [] pos_ligand = [] - for sample in m._dataset_train: + for sample in m._dataset[split]: + if isinstance(sample, list): + assert len(sample) == 1,\ + "Uncollated sample batch returned as list" + sample = sample[0] names.append(sample.name) pos_ligand.append(sample["ligand"].pos) lists_complex_name.append(names) lists_pos_ligand.append(pos_ligand) - assert len(lists_complex_name[0]) > 0, "No names in dataset" + assert len(lists_complex_name[0]) > 0,\ + "No names in {split} dataset" assert lists_complex_name[0] == lists_complex_name[1],\ - f"Inconsistent sample name from data module instances: "\ + f"Inconsistent sample name in {split}-set from data module instances: "\ f"{lists_complex_name[0]} \n\nvs.\n\n"\ f"{lists_complex_name[1]}" assert len(lists_pos_ligand[0]) > 0, "No ligand position found in dataset" assert len(lists_pos_ligand[0]) == len(lists_pos_ligand[1]),\ - "Inconsistent number of ligand position from data module instances: "\ - f"{len(lists_pos_ligand[0])} \n\nvs.\n\n"\ + f"Inconsistent number of ligand position in {split}-set from data "\ + f"module instances: {len(lists_pos_ligand[0])} \n\nvs.\n\n"\ f"{len(lists_pos_ligand[1])}" for i in range(len(lists_pos_ligand[0])): pos_0 = lists_pos_ligand[0][i] @@ -112,5 +95,5 @@ def test_ScoreModelWDS_setup_dataset(create_ScoreModelWDS, create_another_ScoreM torch.testing.assert_close(pos_0, pos_1, msg=lambda m : f"Inconsistent ligand position in the " - f"{i}'th sample/batch between two data " - f"module instances:\n\n{m}") + f"{i}'th sample/batch of {split}-set " + f"between two data module instances:\n\n{m}") From 8420700245e2dda1c575f22f09a797b41070d35b Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Sat, 3 Aug 2024 00:20:06 +0000 Subject: [PATCH 10/70] Enhancement: move NoiseTransform to GenerateNoise in the diffusion module GenerateNoise is now a webdataset composable (instead of a BaseTransform) that maps a generator to another --- .../data/molecule/diffdock/datamodule.py | 18 +++--- .../contrib/data/molecule/diffdock/utils.py | 54 ----------------- .../molecule/diffdock/utils/diffusion.py | 58 +++++++++++++++++++ 3 files changed, 68 insertions(+), 62 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index 1fac6625be..18d1529814 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -25,8 +25,11 @@ from torch_geometric.loader.dataloader import Collater import webdataset as wds -from bionemo.contrib.data.molecule.diffdock.utils import pickles_to_tars, NoiseTransform, SizeAwareBatching, estimate_size -from bionemo.contrib.model.molecule.diffdock.utils.diffusion import t_to_sigma +from bionemo.contrib.data.molecule.diffdock.utils import ( + pickles_to_tars, SizeAwareBatching, estimate_size + ) +from bionemo.contrib.model.molecule.diffdock.utils.diffusion import ( + t_to_sigma, GenerateNoise) class Split(Enum): @@ -225,17 +228,16 @@ def _setup_wds(self, split : Split) -> wds.WebDataset: .extract_keys(f"*.{self._suffix_heterodata}") ) if is_train: - self._xform_train = ( - NoiseTransform(partial(t_to_sigma, + dataset = (dataset.compose( + GenerateNoise(partial(t_to_sigma, self._tr_sigma_min, self._tr_sigma_max, self._rot_sigma_min, self._rot_sigma_max, self._tor_sigma_min, self._tor_sigma_max), - self._no_torsion, self._is_all_atom)) - dataset = (dataset - .compose(partial(self._xform_train.apply_noise_iter, - keep_pos=(split == Split.val))) + self._no_torsion, + self._is_all_atom, + copy_ref_pos=(split == Split.val))) ) # sandwiched here to mirror the original DiffDock FW implementation size = self._sizes[split] diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py index 9f0e933b6c..d4b63bde99 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py @@ -185,57 +185,3 @@ def __call__(self, data: Batch) -> Generator[Batch, None, None]: batch = [sample] batch_size = sample_size - - -class NoiseTransform(BaseTransform): - """Apply forward diffusion on the ligand - - Args: - t_to_sigma (Callable): Callable to embed time - no_torsion (bool): if not to perturb ligand torsion degrees - all_atom (bool): # all atom or coarse grained/residue for protein - """ - - def __init__(self, t_to_sigma: Callable, no_torsion: bool, all_atom: bool): - self.t_to_sigma = t_to_sigma - self.no_torsion = no_torsion - self.all_atom = all_atom - - def __call__(self, data): - t = np.random.uniform() - t_tr, t_rot, t_tor = t, t, t - return self.apply_noise(data, t_tr, t_rot, t_tor) - - def apply_noise(self, data, t_tr, t_rot, t_tor, tr_update=None, rot_update=None, torsion_updates=None): - if not torch.is_tensor(data["ligand"].pos): - data["ligand"].pos = random.choice(data["ligand"].pos) - - tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(t_tr, t_rot, t_tor) - set_time(data, t_tr, t_rot, t_tor, 1, self.all_atom, device=None) - - tr_update = torch.normal(mean=0, std=tr_sigma, size=(1, 3)) if tr_update is None else tr_update - rot_update = so3.sample_vec(eps=rot_sigma) if rot_update is None else rot_update - torsion_updates = ( - np.random.normal(loc=0.0, scale=tor_sigma, size=data["ligand"].edge_mask.sum()) - if torsion_updates is None - else torsion_updates - ) - torsion_updates = None if self.no_torsion else torsion_updates - modify_conformer( - data, - tr_update, - torch.from_numpy(rot_update).float(), - None if data["ligand"].edge_mask.sum() == 0 else torsion_updates, - ) - - data.tr_score = -tr_update / tr_sigma**2 - data.rot_score = torch.from_numpy(so3.score_vec(vec=rot_update, eps=rot_sigma)).float().unsqueeze(0) - data.tor_score = None if self.no_torsion else torch.from_numpy(torus.score(torsion_updates, tor_sigma)).float() - data.tor_sigma_edge = None if self.no_torsion else np.ones(data["ligand"].edge_mask.sum()) * tor_sigma - return data - - def apply_noise_iter(self, source, keep_pos=False): - for (data,) in source: - if keep_pos: - data["ligand"].aligned_pos = deepcopy(data["ligand"].pos) - yield self.__call__(data) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py index 9b080e6b33..edb19a7641 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py @@ -9,12 +9,16 @@ # its affiliates is strictly prohibited. import math +import random +from copy import deepcopy +from typing import Callable import numpy as np import torch import torch.nn.functional as F from torch import nn +from bionemo.contrib.model.molecule.diffdock.utils import so3, torus from bionemo.contrib.model.molecule.diffdock.utils.geometry import axis_angle_to_matrix, rigid_transform_Kabsch_3D_torch from bionemo.contrib.model.molecule.diffdock.utils.torsion import modify_conformer_torsion_angles @@ -145,3 +149,57 @@ def set_time(complex_graphs, t_tr, t_rot, t_tor, batchsize, all_atoms, device): "rot": t_rot * torch.ones(complex_graphs["atom"].num_nodes).to(device), "tor": t_tor * torch.ones(complex_graphs["atom"].num_nodes).to(device), } + + +class GenerateNoise: + """Apply forward diffusion on the ligand + + Args: + t_to_sigma (Callable): Callable to embed time + no_torsion (bool): if not to perturb ligand torsion degrees + all_atom (bool): all atom or coarse grained/residue for protein + copy_ref_pos (bool): whether or not make a copy of the input ligand position + """ + + def __init__(self, t_to_sigma: Callable, no_torsion: bool, all_atom: bool, + copy_ref_pos: bool = False): + self.t_to_sigma = t_to_sigma + self.no_torsion = no_torsion + self.all_atom = all_atom + self._copy_ref_pos = copy_ref_pos + + def __call__(self, source): + for (data,) in source: + if self._copy_ref_pos: + data["ligand"].aligned_pos = deepcopy(data["ligand"].pos) + t = np.random.uniform() + t_tr, t_rot, t_tor = t, t, t + yield self.apply_noise(data, t_tr, t_rot, t_tor) + + def apply_noise(self, data, t_tr, t_rot, t_tor, tr_update=None, rot_update=None, torsion_updates=None): + if not torch.is_tensor(data["ligand"].pos): + data["ligand"].pos = random.choice(data["ligand"].pos) + + tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(t_tr, t_rot, t_tor) + set_time(data, t_tr, t_rot, t_tor, 1, self.all_atom, device=None) + + tr_update = torch.normal(mean=0, std=tr_sigma, size=(1, 3)) if tr_update is None else tr_update + rot_update = so3.sample_vec(eps=rot_sigma) if rot_update is None else rot_update + torsion_updates = ( + np.random.normal(loc=0.0, scale=tor_sigma, size=data["ligand"].edge_mask.sum()) + if torsion_updates is None + else torsion_updates + ) + torsion_updates = None if self.no_torsion else torsion_updates + modify_conformer( + data, + tr_update, + torch.from_numpy(rot_update).float(), + None if data["ligand"].edge_mask.sum() == 0 else torsion_updates, + ) + + data.tr_score = -tr_update / tr_sigma**2 + data.rot_score = torch.from_numpy(so3.score_vec(vec=rot_update, eps=rot_sigma)).float().unsqueeze(0) + data.tor_score = None if self.no_torsion else torch.from_numpy(torus.score(torsion_updates, tor_sigma)).float() + data.tor_sigma_edge = None if self.no_torsion else np.ones(data["ligand"].edge_mask.sum()) * tor_sigma + return data From 662ce27cb9c552c0c0318238dc86d1ab53b4c200 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Sat, 3 Aug 2024 00:32:37 +0000 Subject: [PATCH 11/70] BugFix: always GenerateNoise in score model data --- .../data/molecule/diffdock/datamodule.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index 18d1529814..8b761e99b2 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -227,18 +227,14 @@ def _setup_wds(self, split : Split) -> wds.WebDataset: .decode() .extract_keys(f"*.{self._suffix_heterodata}") ) - if is_train: - dataset = (dataset.compose( - GenerateNoise(partial(t_to_sigma, - self._tr_sigma_min, self._tr_sigma_max, - self._rot_sigma_min, - self._rot_sigma_max, - self._tor_sigma_min, - self._tor_sigma_max), - self._no_torsion, - self._is_all_atom, - copy_ref_pos=(split == Split.val))) - ) + dataset = dataset.compose( + GenerateNoise(partial(t_to_sigma, + self._tr_sigma_min, self._tr_sigma_max, + self._rot_sigma_min, self._rot_sigma_max, + self._tor_sigma_min, + self._tor_sigma_max), + self._no_torsion, self._is_all_atom, + copy_ref_pos=(split == Split.val))) # sandwiched here to mirror the original DiffDock FW implementation size = self._sizes[split] # FIXME: remove this with_length since it's overriden later anyway From 99d8e6eb7ee7b0b518fd5d5239347d982c5d7f4c Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Sat, 3 Aug 2024 00:43:31 +0000 Subject: [PATCH 12/70] Regress: with_length -> with_epoch ... to be consistent with the FW v1.0 behavior --- .../src/bionemo/contrib/data/molecule/diffdock/datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index 8b761e99b2..513c3e67b2 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -258,7 +258,7 @@ def _setup_wds(self, split : Split) -> wds.WebDataset: max_total_size=0.85 * torch.cuda.get_device_properties("cuda:0").total_memory / 2**20, size_fn=estimate_size, ) - dataset = dataset.compose(f_batching).with_length(n_batches) + dataset = dataset.compose(f_batching).with_epoch(n_batches) if is_train: dataset = dataset.select(lambda x: len(x) > 1) From e604a49fc8cda01f82d3c9bcf355b086b4789f9b Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Mon, 5 Aug 2024 22:26:53 +0000 Subject: [PATCH 13/70] BugFix: define length on SizeAwareBatching dataset --- .../bionemo/contrib/data/molecule/diffdock/datamodule.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index 513c3e67b2..80ff429c52 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -237,8 +237,6 @@ def _setup_wds(self, split : Split) -> wds.WebDataset: copy_ref_pos=(split == Split.val))) # sandwiched here to mirror the original DiffDock FW implementation size = self._sizes[split] - # FIXME: remove this with_length since it's overriden later anyway - dataset = dataset.with_length(size) if is_train: dataset = dataset.shuffle(size=5000, rng=random.Random(self._seed_rng_shfl)) @@ -258,7 +256,10 @@ def _setup_wds(self, split : Split) -> wds.WebDataset: max_total_size=0.85 * torch.cuda.get_device_properties("cuda:0").total_memory / 2**20, size_fn=estimate_size, ) - dataset = dataset.compose(f_batching).with_epoch(n_batches) + dataset = (dataset.compose(f_batching) + .with_epoch(n_batches) + .with_length(n_batches) + ) if is_train: dataset = dataset.select(lambda x: len(x) > 1) From 8e236d6c7988c47b39d78c6737376a05cd187730 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Mon, 5 Aug 2024 22:27:36 +0000 Subject: [PATCH 14/70] Test: setup data loader --- .../contrib/data/test_diffdock_datamodule.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py index 63e88b08ab..1ae7f4a18e 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py @@ -97,3 +97,60 @@ def test_ScoreModelWDS_setup_dataset(split, create_ScoreModelWDS, create_another f"Inconsistent ligand position in the " f"{i}'th sample/batch of {split}-set " f"between two data module instances:\n\n{m}") + + +@pytest.mark.parametrize("split", [s for s in Split]) +def test_ScoreModelWDS_setup_dataloader(split, create_ScoreModelWDS, create_another_ScoreModelWDS): + data_modules= [create_ScoreModelWDS[0], create_another_ScoreModelWDS[0]] + lists_complex_name = [] + lists_pos_ligand = [] + for m in data_modules: + m.prepare_data() + # run through all the possible stages first to setup all the correps. + # dataset objects + m.setup("fit") + m.setup("test") + lightning.seed_everything(2823828) + names = [] + pos_ligand = [] + loader = None + if split == Split.train: + loader = m.train_dataloader() + elif split == Split.val: + loader = m.val_dataloader() + elif split == Split.test: + loader = m.test_dataloader() + else: + raise RuntimeError(f"Test for split {split} not implemented") + assert loader is not None, "dataloader not instantated" + for sample in loader: + if isinstance(sample, list): + assert len(sample) == 1,\ + "Uncollated sample batch returned as list" + sample = sample[0] + names.append(sample.name) + pos_ligand.append(sample["ligand"].pos) + lists_complex_name.append(names) + lists_pos_ligand.append(pos_ligand) + + assert len(lists_complex_name[0]) > 0,\ + "No names in {split} dataloader" + assert lists_complex_name[0] == lists_complex_name[1],\ + f"Inconsistent sample name in {split}-set from data module instances: "\ + f"{lists_complex_name[0]} \n\nvs.\n\n"\ + f"{lists_complex_name[1]}" + + assert len(lists_pos_ligand[0]) > 0,\ + "No ligand position found in dataloader" + assert len(lists_pos_ligand[0]) == len(lists_pos_ligand[1]),\ + f"Inconsistent number of ligand position in {split}-set from data "\ + f"module instances: {len(lists_pos_ligand[0])} \n\nvs.\n\n"\ + f"{len(lists_pos_ligand[1])}" + for i in range(len(lists_pos_ligand[0])): + pos_0 = lists_pos_ligand[0][i] + pos_1 = lists_pos_ligand[1][i] + torch.testing.assert_close(pos_0, pos_1, + msg=lambda m : + f"Inconsistent ligand position in the " + f"{i}'th sample/batch of {split}-set " + f"between two data module instances:\n\n{m}") From 0e4c7bdcc3f63d857ad16d7003abe6d7dac38ebf Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Mon, 5 Aug 2024 23:25:04 +0000 Subject: [PATCH 15/70] Enhancement: factor out GenerateNoise outside of data module ScoreModelWDS now takes a dict: Split -> Generator that is passed down to the setup() to the respective wds.compose() calls so it no longer maintains the generator's specific parameters --- .../data/molecule/diffdock/datamodule.py | 55 +++++-------------- .../molecule/diffdock/utils/diffusion.py | 17 ++++-- .../tests/bionemo/contrib/data/conftest.py | 20 +++++++ 3 files changed, 46 insertions(+), 46 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index 80ff429c52..e50ddce8f3 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -14,11 +14,10 @@ # limitations under the License. from enum import Enum, auto -from functools import partial import glob import pickle import random -from typing import Set, Optional, Tuple +from typing import Dict, Generator, Set, Optional, Tuple import lightning as L import torch from torch_geometric.data.hetero_data import HeteroData @@ -28,8 +27,6 @@ from bionemo.contrib.data.molecule.diffdock.utils import ( pickles_to_tars, SizeAwareBatching, estimate_size ) -from bionemo.contrib.model.molecule.diffdock.utils.diffusion import ( - t_to_sigma, GenerateNoise) class Split(Enum): @@ -47,15 +44,12 @@ def __init__(self, dir_heterodata : str, suffix_heterodata : str, prefix_dir_tars_wds : str, names_subset_train : Set[str], names_subset_val : Set[str], local_batch_size : int, global_batch_size : int, n_workers_dataloader : int, - tr_sigma_minmax : Tuple[float, float] - = (0.1, 19), rot_sigma_minmax : Tuple[float, float] = (0.03, - 1.55), - tor_sigma_minmax : Optional[Tuple[float, float]] = (0.0314, - 3.14), - is_all_atom : bool = False, apply_size_control : Tuple[bool, - bool, - bool] = - (True, False, False), pin_memory_dataloader : bool = True, + xform_gen_wds : Optional[Dict[Split, Generator[HeteroData, + None, None]]] = + None, + apply_size_control : Tuple[bool, bool, bool] = (True, False, + False), + pin_memory_dataloader : bool = True, prefix_tars_wds : str = "heterographs", n_tars_wds : Optional[int] = None, names_subset_test : Optional[Set[str]] = None, seed_rng_shfl : int = 0): @@ -84,15 +78,11 @@ def __init__(self, dir_heterodata : str, suffix_heterodata : str, data loading time for shuffling Kwargs: - tr_sigma_minmax (Tuple[float, float]): min and max sigma for the - translational component during diffusion - rot_sigma_minmax (Tuple[float, float]): min and max sigma for the - rotational component during diffusion - tor_sigma_minmax (Optional[Tuple[float, float]]): min and max sigma - for the torsional component during diffusion - is_all_atom (bool): whether to treat the data as all-atom system - during noise transformation - apply_size_control(Tuple[bool, bool, bool]): whether to use + xform_gen_wds (Optional[Dict[Split, Generator[HeteroData, None, + None]]]): a dictionary of webdatast composable, i.e., functor that + maps a generator to another generator that transforms the data + sample, for different splits + apply_size_control (Tuple[bool, bool, bool]): whether to use SizeAwareBatching for the respective train, val and test data pin_memory_dataloader (bool): whether to use pin memory in pytorch dataloader @@ -127,16 +117,7 @@ def __init__(self, dir_heterodata : str, suffix_heterodata : str, Split.test : f"{self._prefix_dir_tars_wds}test", } - self._tr_sigma_min, self._tr_sigma_max = tr_sigma_minmax - self._rot_sigma_min, self._rot_sigma_max = rot_sigma_minmax - self._tor_sigma_min, self._tor_sigma_max = (None, None) - self._no_torsion = True - if tor_sigma_minmax is not None: - self._tor_sigma_min, self._tor_sigma_max = tor_sigma_minmax - self._no_torsion = False - # TODO: the all-atom arg to set_time should be inferred from the - # complex_graph arg so we don't have to pass it all-the-way down - self._is_all_atom = is_all_atom + self._xform_gen_wds = xform_gen_wds self._local_batch_size = local_batch_size self._global_batch_size = global_batch_size @@ -227,14 +208,8 @@ def _setup_wds(self, split : Split) -> wds.WebDataset: .decode() .extract_keys(f"*.{self._suffix_heterodata}") ) - dataset = dataset.compose( - GenerateNoise(partial(t_to_sigma, - self._tr_sigma_min, self._tr_sigma_max, - self._rot_sigma_min, self._rot_sigma_max, - self._tor_sigma_min, - self._tor_sigma_max), - self._no_torsion, self._is_all_atom, - copy_ref_pos=(split == Split.val))) + if self._xform_gen_wds is not None: + dataset = dataset.compose(self._xform_gen_wds[split]) # sandwiched here to mirror the original DiffDock FW implementation size = self._sizes[split] if is_train: diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py index edb19a7641..79b42a75f9 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py @@ -11,20 +11,22 @@ import math import random from copy import deepcopy -from typing import Callable +from typing import Callable, Generator, Tuple import numpy as np import torch import torch.nn.functional as F from torch import nn +from torch_geometric.data.hetero_data import HeteroData from bionemo.contrib.model.molecule.diffdock.utils import so3, torus from bionemo.contrib.model.molecule.diffdock.utils.geometry import axis_angle_to_matrix, rigid_transform_Kabsch_3D_torch from bionemo.contrib.model.molecule.diffdock.utils.torsion import modify_conformer_torsion_angles -def t_to_sigma(t_tr, t_rot, t_tor, tr_sigma_min, tr_sigma_max, rot_sigma_min, - rot_sigma_max, tor_sigma_min, tor_sigma_max): +def t_to_sigma(tr_sigma_min, tr_sigma_max, rot_sigma_min, + rot_sigma_max, tor_sigma_min, tor_sigma_max, + t_tr, t_rot, t_tor): tr_sigma = tr_sigma_min ** (1 - t_tr) * tr_sigma_max**t_tr rot_sigma = rot_sigma_min ** (1 - t_rot) * rot_sigma_max**t_rot tor_sigma = tor_sigma_min ** (1 - t_tor) * tor_sigma_max**t_tor @@ -161,14 +163,17 @@ class GenerateNoise: copy_ref_pos (bool): whether or not make a copy of the input ligand position """ - def __init__(self, t_to_sigma: Callable, no_torsion: bool, all_atom: bool, - copy_ref_pos: bool = False): + def __init__(self, t_to_sigma: Callable[[float, float, float], Tuple[float, + float, + float]], + no_torsion: bool, all_atom: bool, copy_ref_pos: bool = False): self.t_to_sigma = t_to_sigma self.no_torsion = no_torsion self.all_atom = all_atom self._copy_ref_pos = copy_ref_pos - def __call__(self, source): + def __call__(self, source : Generator[HeteroData, None, None]) \ + -> Generator[HeteroData, None, None]: for (data,) in source: if self._copy_ref_pos: data["ligand"].aligned_pos = deepcopy(data["ligand"].pos) diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py index fc7a4daf43..a5d284be5d 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py @@ -15,7 +15,10 @@ import os import pytest +from functools import partial +from bionemo.contrib.model.molecule.diffdock.utils.diffusion import ( + t_to_sigma, GenerateNoise) from bionemo.contrib.data.molecule.diffdock.datamodule import Split, ScoreModelWDS @@ -50,6 +53,22 @@ def _create_ScoreModelWDS_impl(tmp_path_factory, get_diffdock_score_model_heterodata prefix_dir_tars_wds = tmp_path_factory.mktemp( "diffdock_score_model_tars_wds").as_posix() + tr_sigma_min, tr_sigma_max = (0.1, 19) + rot_sigma_min, rot_sigma_max = (0.03, 1.55) + tor_sigma_min, tor_sigma_max = (0.0314, 3.14) + is_all_atom = False + no_torsion = False + sigma_t = partial(t_to_sigma, tr_sigma_min, + tr_sigma_max, rot_sigma_min, rot_sigma_max, + tor_sigma_min, tor_sigma_max) + generateNoise = { + Split.train : GenerateNoise(sigma_t, no_torsion, is_all_atom, + copy_ref_pos=False), + Split.val: GenerateNoise(sigma_t, no_torsion, is_all_atom, + copy_ref_pos=True), + Split.test: GenerateNoise(sigma_t, no_torsion, is_all_atom, + copy_ref_pos=False), + } local_batch_size = 2 global_batch_size = 2 n_workers_dataloader = 2 @@ -59,6 +78,7 @@ def _create_ScoreModelWDS_impl(tmp_path_factory, prefix_dir_tars_wds, names[Split.train], names[Split.val], local_batch_size, global_batch_size, n_workers_dataloader, + generateNoise, n_tars_wds=n_tars_wds, names_subset_test=names[Split.test], seed_rng_shfl=seed_rng_shfl) From 88035d61046ab4cc0119de7f44827a4cfeac4ecd Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Mon, 5 Aug 2024 23:44:51 +0000 Subject: [PATCH 16/70] Enhancement: use a dict: Split -> List[str] as the input subset complex names ... replacing the original seperate arguments for each split --- .../data/molecule/diffdock/datamodule.py | 76 +++++++------------ .../contrib/data/molecule/diffdock/utils.py | 4 +- .../tests/bionemo/contrib/data/conftest.py | 22 +++--- 3 files changed, 38 insertions(+), 64 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index e50ddce8f3..b25fb843b2 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -17,7 +17,7 @@ import glob import pickle import random -from typing import Dict, Generator, Set, Optional, Tuple +from typing import Dict, Generator, List, Set, Optional, Tuple import lightning as L import torch from torch_geometric.data.hetero_data import HeteroData @@ -41,18 +41,17 @@ class ScoreModelWDS(L.LightningDataModule): dataloader""" def __init__(self, dir_heterodata : str, suffix_heterodata : str, - prefix_dir_tars_wds : str, names_subset_train : Set[str], - names_subset_val : Set[str], local_batch_size : int, - global_batch_size : int, n_workers_dataloader : int, - xform_gen_wds : Optional[Dict[Split, Generator[HeteroData, - None, None]]] = - None, - apply_size_control : Tuple[bool, bool, bool] = (True, False, - False), - pin_memory_dataloader : bool = True, - prefix_tars_wds : str = "heterographs", - n_tars_wds : Optional[int] = None, names_subset_test : - Optional[Set[str]] = None, seed_rng_shfl : int = 0): + prefix_dir_tars_wds : str, names_subset : Dict[Split, + List[str]], + local_batch_size : int, global_batch_size : int, + n_workers_dataloader : int, xform_gen_wds : + Optional[Dict[Split, Generator[HeteroData, None, None]]] = + None, apply_size_control : Tuple[bool, bool, bool] = (True, + False, + False), + pin_memory_dataloader : bool = True, prefix_tars_wds : str = + "heterographs", n_tars_wds : Optional[int] = None, + seed_rng_shfl : int = 0): """constructor Args: @@ -65,10 +64,8 @@ def __init__(self, dir_heterodata : str, suffix_heterodata : str, webdataset tar files. The actual directories storing the train, val and test sets will be suffixed with "train", "val" and "test" respectively. - names_subset_train (Set[str]): list of complex names to be included - in the training data - names_subset_val (Set[str]): list of complex names to be included - in the validation data + names_subset (Dict[Split, List[str]]): list of complex names to be + included in each of the split local_batch_size (int): size of batch for each node global_batch_size (int): size of batch summing across nodes in Data Distributed Parallel, i.e., local_batch_size * n_nodes @@ -88,8 +85,6 @@ def __init__(self, dir_heterodata : str, suffix_heterodata : str, dataloader prefix_tars_wds (str): name prefix to output webdataset tar files n_tars_wds (int): attempt to create at least this number of webdataset shards - names_subset_test (Optional[Set[str]]): list of complex names to be included - in the test data """ @@ -100,15 +95,17 @@ def __init__(self, dir_heterodata : str, suffix_heterodata : str, self._n_tars_wds = n_tars_wds self._prefix_dir_tars_wds = prefix_dir_tars_wds self._prefix_tars_wds = prefix_tars_wds - self._names_subset_train = names_subset_train - self._names_subset_val = names_subset_val - self._names_subset_test = names_subset_test + + keys_subset = names_subset.keys() + if not (Split.train in keys_subset and Split.val in keys_subset): + raise RuntimeError("Input names_subset must be defined for the "\ + "train and val splits") + + self._names_subset = names_subset self._sizes = { - Split.train : len(self._names_subset_train), - Split.val : len(self._names_subset_val), - Split.test : len(self._names_subset_test) if - self._names_subset_test is not None else None, + split : len(self._names_subset[split]) for split in + self._names_subset.keys() } self._dirs_tars_wds = { @@ -157,36 +154,17 @@ def prepare_data(self) -> None: Returns: None """ - # create wds shards (tar files) for train set - pickles_to_tars(self._dir_heterodata, - self._suffix_heterodata, - self._names_subset_train, - self._dirs_tars_wds[Split.train], - self._prefix_tars_wds, - self._complex_graph_to_tar, - min_num_shards=self._n_tars_wds) - - # create wds shards (tar files) for val set - pickles_to_tars(self._dir_heterodata, - self._suffix_heterodata, - self._names_subset_val, - self._dirs_tars_wds[Split.val], - self._prefix_tars_wds, - self._complex_graph_to_tar, - min_num_shards=self._n_tars_wds) - - if self._names_subset_test is not None: - # create wds shards (tar files) for test set + for split in self._names_subset.keys(): + # create wds shards (tar files) for train set pickles_to_tars(self._dir_heterodata, self._suffix_heterodata, - self._names_subset_test, - self._dirs_tars_wds[Split.test], + self._names_subset[split], + self._dirs_tars_wds[split], self._prefix_tars_wds, self._complex_graph_to_tar, min_num_shards=self._n_tars_wds) - def _setup_wds(self, split : Split) -> wds.WebDataset: """setup webdataset and webloader. This is called by setup() diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py index d4b63bde99..ef8359fed6 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py @@ -36,7 +36,7 @@ def pickles_to_tars( dir_input: str, input_suffix: str, - input_prefix_subset: set, + input_prefix_subset: List[str], dir_output: str, output_prefix: str, func_output_data: Callable = lambda data: {"data": pickle.dumps(data)}, @@ -51,7 +51,7 @@ def pickles_to_tars( Args: dir_input (str): Input directory input_suffix (str): Input pickle file name suffix - input_prefix_subset (set): Input subset of pickle files' prefix + input_prefix_subset (List[str]): Input subset of pickle files' prefix dir_output (str): Output directory output_prefix (str): Output tar file name prefix func_output_data (Callable) : function that maps data to a dictionary diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py index a5d284be5d..8414e97597 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py @@ -35,14 +35,13 @@ def get_diffdock_score_model_heterodata(get_path): dir_heterodata = f"{dir_data}/molecule/diffdock/heterodata" suffix_heterodata = "heterodata.pyd" names = { - Split.train : set(["6t88", "6vs3", "6wtn", "6yqv", "7amc", "7bmi", - "7cuo", "7d5c", "7din", "7fha", "7jnb", "7k0v", - "7kb1", "7km8", "7l7c", "7lcu", "7msr", "7my1", - "7n6f", "7np6"]), - Split.val : set(["7nr6", "7oeo", "7oli", "7oso", "7p5t", "7q5i", "7qhl", - "7rh3", "7rzl", "7sgv"]), - Split.test : set(["7sne", "7t2i", "7tbu", "7tsf", "7umv", "7up3", - "7uq3", "7wpw", "7xek", "7xij"]) + Split.train : ["6t88", "6vs3", "6wtn", "6yqv", "7amc", "7bmi", "7cuo", + "7d5c", "7din", "7fha", "7jnb", "7k0v", "7kb1", "7km8", + "7l7c", "7lcu", "7msr", "7my1", "7n6f", "7np6"], + Split.val : ["7nr6", "7oeo", "7oli", "7oso", "7p5t", "7q5i", "7qhl", + "7rh3", "7rzl", "7sgv"], + Split.test : ["7sne", "7t2i", "7tbu", "7tsf", "7umv", "7up3", "7uq3", + "7wpw", "7xek", "7xij"] } return (dir_heterodata, suffix_heterodata, names) @@ -75,12 +74,9 @@ def _create_ScoreModelWDS_impl(tmp_path_factory, n_tars_wds = 4 seed_rng_shfl = 822782392 data_module = ScoreModelWDS(dir_heterodata, suffix_heterodata, - prefix_dir_tars_wds, names[Split.train], - names[Split.val], local_batch_size, + prefix_dir_tars_wds, names, local_batch_size, global_batch_size, n_workers_dataloader, - generateNoise, - n_tars_wds=n_tars_wds, - names_subset_test=names[Split.test], + generateNoise, n_tars_wds=n_tars_wds, seed_rng_shfl=seed_rng_shfl) return data_module, prefix_dir_tars_wds From 377f276bb8007aacc095cbfd8bbfd81a7e119294 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Mon, 5 Aug 2024 23:58:23 +0000 Subject: [PATCH 17/70] BugFix: check _xform_gen_wds's avail before composing --- .../src/bionemo/contrib/data/molecule/diffdock/datamodule.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index b25fb843b2..23c3758197 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -186,7 +186,8 @@ def _setup_wds(self, split : Split) -> wds.WebDataset: .decode() .extract_keys(f"*.{self._suffix_heterodata}") ) - if self._xform_gen_wds is not None: + if (self._xform_gen_wds is not None and + self._xform_gen_wds[split] is not None): dataset = dataset.compose(self._xform_gen_wds[split]) # sandwiched here to mirror the original DiffDock FW implementation size = self._sizes[split] From 7c5add10ca2762577ff6898fe605c287b783bc07 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 6 Aug 2024 02:53:52 +0000 Subject: [PATCH 18/70] Enhancement: move the batching from wds to webloader ... to keep to the standard dataset and dataloader behaviors This doesn't do any batching at all in the webdataset object because batch copy from the workers probably won't worth the extra overhead of first making a list and then yielding from it, which in turn because the copying of HeteroData (and the list thereof) will incur multiple memcpy calls anyway given its dictionary nature --- .../data/molecule/diffdock/datamodule.py | 68 +++++++++---------- .../contrib/data/test_diffdock_datamodule.py | 5 +- 2 files changed, 32 insertions(+), 41 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index 23c3758197..c53fff393f 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -189,34 +189,9 @@ def _setup_wds(self, split : Split) -> wds.WebDataset: if (self._xform_gen_wds is not None and self._xform_gen_wds[split] is not None): dataset = dataset.compose(self._xform_gen_wds[split]) - # sandwiched here to mirror the original DiffDock FW implementation - size = self._sizes[split] if is_train: - dataset = dataset.shuffle(size=5000, + dataset = dataset.shuffle(size=16, rng=random.Random(self._seed_rng_shfl)) - n_batches = ((size + self._global_batch_size - 1) - // self._global_batch_size) - if not self._use_dynamic_batch_size[split]: - dataset = ( - dataset.batched(self._local_batch_size, - collation_fn=Collater(dataset=[], - follow_batch=None, - exclude_keys=None)) - .with_epoch(n_batches) - .with_length(n_batches) - ) - else: - f_batching = SizeAwareBatching( - max_total_size=0.85 * torch.cuda.get_device_properties("cuda:0").total_memory / 2**20, - size_fn=estimate_size, - ) - dataset = (dataset.compose(f_batching) - .with_epoch(n_batches) - .with_length(n_batches) - ) - if is_train: - dataset = dataset.select(lambda x: len(x) > 1) - return dataset def setup(self, stage: str) -> None: @@ -237,23 +212,42 @@ def setup(self, stage: str) -> None: raise NotImplementedError("Data setup with stage = {stage}\ is not implmented") - def _setup_dataloader(self, dataset : wds.WebDataset) -> wds.WebLoader: - """wrap the input dataset into a WebLoader + def _setup_dataloader(self, split : Split) -> wds.WebLoader: + """setup the dataloader for the input dataset split Args: - dataset (wds.WebDataset): input dataset object + split (Split): input split type Returns: WebLoader object """ - if not hasattr(dataset, "__len__"): - raise RuntimeError("Input dataset object doesn't have length") - n_batches = len(dataset) + dataset = self._dataset[split] + n_samples = len(self._names_subset[split]) + n_batches = ((n_samples + self._global_batch_size - 1) + // self._global_batch_size) loader = wds.WebLoader(dataset, num_workers=self._n_workers_dataloader, pin_memory=self._pin_memory_dataloader, - collate_fn=lambda x: x[0], - ).with_length(n_batches).with_epoch(n_batches) + batch_size=None + ).shuffle(5000, rng=random.Random(self._seed_rng_shfl)) + + if not self._use_dynamic_batch_size[split]: + loader = loader.batched( + self._local_batch_size, collation_fn=Collater(dataset=[], + follow_batch=None, + exclude_keys=None) + ) + else: + f_batching = SizeAwareBatching( + max_total_size=0.85 * torch.cuda.get_device_properties("cuda:0").total_memory / 2**20, + size_fn=estimate_size, + ) + loader = loader.compose(f_batching) + + if split == Split.train: + loader = loader.select(lambda x: len(x) > 1) + + loader = loader.with_epoch(n_batches) # strange features required by nemo optimizer lr_scheduler loader.dataset = dataset # seems like only length is used, webloader doesn't have this attr @@ -265,19 +259,19 @@ def _setup_dataloader(self, dataset : wds.WebDataset) -> wds.WebLoader: def train_dataloader(self) -> wds.WebLoader: assert self._dataset[Split.train] is not None,\ f"dataset for train has not been setup" - return self._setup_dataloader(self._dataset[Split.train]) + return self._setup_dataloader(Split.train) def val_dataloader(self) -> wds.WebLoader: assert self._dataset[Split.val] is not None,\ f"dataset for val has not been setup" - return self._setup_dataloader(self._dataset[Split.val]) + return self._setup_dataloader(Split.val) def test_dataloader(self) -> wds.WebLoader: assert self._dataset[Split.test] is not None,\ f"dataset for test has not been setup" - return self._setup_dataloader(self._dataset[Split.test]) + return self._setup_dataloader(Split.test) def predict_dataloader(self) -> wds.WebLoader: diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py index 1ae7f4a18e..f5460d095a 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py @@ -18,6 +18,7 @@ import torch import lightning +from torch_geometric.data import HeteroData from bionemo.contrib.data.molecule.diffdock.datamodule import Split @@ -68,10 +69,6 @@ def test_ScoreModelWDS_setup_dataset(split, create_ScoreModelWDS, create_another names = [] pos_ligand = [] for sample in m._dataset[split]: - if isinstance(sample, list): - assert len(sample) == 1,\ - "Uncollated sample batch returned as list" - sample = sample[0] names.append(sample.name) pos_ligand.append(sample["ligand"].pos) lists_complex_name.append(names) From 69620fbd05718955775e6ae0e79ec223e7a25fcc Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 6 Aug 2024 03:19:58 +0000 Subject: [PATCH 19/70] Test: assert sample and batch type --- .../contrib/data/test_diffdock_datamodule.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py index f5460d095a..9a0a50fe14 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py @@ -14,11 +14,12 @@ # limitations under the License. import glob +from numpy import isin import pytest import torch import lightning -from torch_geometric.data import HeteroData +from torch_geometric.data import Batch, HeteroData from bionemo.contrib.data.molecule.diffdock.datamodule import Split @@ -69,6 +70,8 @@ def test_ScoreModelWDS_setup_dataset(split, create_ScoreModelWDS, create_another names = [] pos_ligand = [] for sample in m._dataset[split]: + assert isinstance(sample, HeteroData),\ + "Sample yield from dataset is not PyG HeteroData" names.append(sample.name) pos_ligand.append(sample["ligand"].pos) lists_complex_name.append(names) @@ -120,13 +123,14 @@ def test_ScoreModelWDS_setup_dataloader(split, create_ScoreModelWDS, create_anot else: raise RuntimeError(f"Test for split {split} not implemented") assert loader is not None, "dataloader not instantated" - for sample in loader: - if isinstance(sample, list): - assert len(sample) == 1,\ - "Uncollated sample batch returned as list" - sample = sample[0] - names.append(sample.name) - pos_ligand.append(sample["ligand"].pos) + for samples in loader: + # PyG's HeteroDataBatch is Batch inherited from HeteroData + assert isinstance(samples, Batch),\ + f"Sample object is not PyG Batch" + assert isinstance(samples, HeteroData),\ + f"Sample object is not PyG HeteroData" + names.append(samples.name) + pos_ligand.append(samples["ligand"].pos) lists_complex_name.append(names) lists_pos_ligand.append(pos_ligand) From 6ee03c081789d0288ac812ddcd428ca595a15258 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 6 Aug 2024 03:59:57 +0000 Subject: [PATCH 20/70] Enhancement: factor out the batching method as input webloader composable --- .../data/molecule/diffdock/datamodule.py | 67 ++++++++----------- .../tests/bionemo/contrib/data/conftest.py | 26 ++++++- 2 files changed, 54 insertions(+), 39 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index c53fff393f..221d9096e0 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -17,15 +17,13 @@ import glob import pickle import random -from typing import Dict, Generator, List, Set, Optional, Tuple +from typing import Dict, Generator, List, Optional import lightning as L -import torch from torch_geometric.data.hetero_data import HeteroData -from torch_geometric.loader.dataloader import Collater import webdataset as wds from bionemo.contrib.data.molecule.diffdock.utils import ( - pickles_to_tars, SizeAwareBatching, estimate_size + pickles_to_tars ) @@ -44,13 +42,14 @@ def __init__(self, dir_heterodata : str, suffix_heterodata : str, prefix_dir_tars_wds : str, names_subset : Dict[Split, List[str]], local_batch_size : int, global_batch_size : int, - n_workers_dataloader : int, xform_gen_wds : - Optional[Dict[Split, Generator[HeteroData, None, None]]] = - None, apply_size_control : Tuple[bool, bool, bool] = (True, - False, - False), - pin_memory_dataloader : bool = True, prefix_tars_wds : str = - "heterographs", n_tars_wds : Optional[int] = None, + n_workers_dataloader : int, + pipeline_wds : Optional[Dict[Split, Generator[HeteroData, None, + None]]] = None, + pipeline_prebatch_wld : Optional[Dict[Split, + Generator[HeteroData, + None, None]]] = + None, pin_memory_dataloader : bool = True, prefix_tars_wds : + str = "heterographs", n_tars_wds : Optional[int] = None, seed_rng_shfl : int = 0): """constructor @@ -75,12 +74,16 @@ def __init__(self, dir_heterodata : str, suffix_heterodata : str, data loading time for shuffling Kwargs: - xform_gen_wds (Optional[Dict[Split, Generator[HeteroData, None, - None]]]): a dictionary of webdatast composable, i.e., functor that - maps a generator to another generator that transforms the data - sample, for different splits - apply_size_control (Tuple[bool, bool, bool]): whether to use - SizeAwareBatching for the respective train, val and test data + pipeline_wds (Optional[Dict[Split, Generator[HeteroData, None, + None]]]): a dictionary of webdatast composable, i.e., functor + that maps a generator to another generator that transforms the + data sample yield from the dataset object, for different splits + pipeline_prebatch_wld (Optional[Dict[Split, Generator[HeteroData, + None, None]]]): a dictionary of webloader composable, i.e., + functor that maps a generator to another generator that + transforms the data sample yield from the WebLoader object, for + different splits. NOTE: this is applied before batching is yield + from the WebLoader pin_memory_dataloader (bool): whether to use pin memory in pytorch dataloader prefix_tars_wds (str): name prefix to output webdataset tar files @@ -114,15 +117,11 @@ def __init__(self, dir_heterodata : str, suffix_heterodata : str, Split.test : f"{self._prefix_dir_tars_wds}test", } - self._xform_gen_wds = xform_gen_wds + self._pipeline_wds = pipeline_wds + self._pipeline_prebatch_wld = pipeline_prebatch_wld self._local_batch_size = local_batch_size self._global_batch_size = global_batch_size - self._use_dynamic_batch_size = { - Split.train : apply_size_control[0], - Split.val : apply_size_control[1], - Split.test : apply_size_control[2], - } self._n_workers_dataloader = n_workers_dataloader self._pin_memory_dataloader = pin_memory_dataloader self._seed_rng_shfl = seed_rng_shfl @@ -186,9 +185,9 @@ def _setup_wds(self, split : Split) -> wds.WebDataset: .decode() .extract_keys(f"*.{self._suffix_heterodata}") ) - if (self._xform_gen_wds is not None and - self._xform_gen_wds[split] is not None): - dataset = dataset.compose(self._xform_gen_wds[split]) + if (self._pipeline_wds is not None and + self._pipeline_wds[split] is not None): + dataset = dataset.compose(self._pipeline_wds[split]) if is_train: dataset = dataset.shuffle(size=16, rng=random.Random(self._seed_rng_shfl)) @@ -231,18 +230,10 @@ def _setup_dataloader(self, split : Split) -> wds.WebLoader: batch_size=None ).shuffle(5000, rng=random.Random(self._seed_rng_shfl)) - if not self._use_dynamic_batch_size[split]: - loader = loader.batched( - self._local_batch_size, collation_fn=Collater(dataset=[], - follow_batch=None, - exclude_keys=None) - ) - else: - f_batching = SizeAwareBatching( - max_total_size=0.85 * torch.cuda.get_device_properties("cuda:0").total_memory / 2**20, - size_fn=estimate_size, - ) - loader = loader.compose(f_batching) + if (self._pipeline_prebatch_wld is not None and + self._pipeline_prebatch_wld[split] is not None): + loader = loader.compose( + self._pipeline_prebatch_wld[split]) if split == Split.train: loader = loader.select(lambda x: len(x) > 1) diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py index 8414e97597..d12ea78be5 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py @@ -16,9 +16,16 @@ import os import pytest from functools import partial +import torch +from torch_geometric.loader.data_list_loader import collate_fn +from torch_geometric.loader.dataloader import Collater +from webdataset.filters import batched from bionemo.contrib.model.molecule.diffdock.utils.diffusion import ( t_to_sigma, GenerateNoise) +from bionemo.contrib.data.molecule.diffdock.utils import ( + SizeAwareBatching, estimate_size + ) from bionemo.contrib.data.molecule.diffdock.datamodule import Split, ScoreModelWDS @@ -60,6 +67,7 @@ def _create_ScoreModelWDS_impl(tmp_path_factory, sigma_t = partial(t_to_sigma, tr_sigma_min, tr_sigma_max, rot_sigma_min, rot_sigma_max, tor_sigma_min, tor_sigma_max) + # webdataset pipeline generateNoise = { Split.train : GenerateNoise(sigma_t, no_torsion, is_all_atom, copy_ref_pos=False), @@ -70,13 +78,29 @@ def _create_ScoreModelWDS_impl(tmp_path_factory, } local_batch_size = 2 global_batch_size = 2 + size_cuda_mem = (0.85 * + torch.cuda.get_device_properties("cuda:0").total_memory / + 2**20) + batch_pyg = batched(local_batch_size, + collation_fn=Collater(dataset=[], follow_batch=None, + exclude_keys=None)) + # WebLoader pipeline + pipelines_wdl_batch = { + Split.train : SizeAwareBatching( + max_total_size=size_cuda_mem, + size_fn=estimate_size), + Split.val : batch_pyg, + Split.test : batch_pyg, + } n_workers_dataloader = 2 n_tars_wds = 4 seed_rng_shfl = 822782392 data_module = ScoreModelWDS(dir_heterodata, suffix_heterodata, prefix_dir_tars_wds, names, local_batch_size, global_batch_size, n_workers_dataloader, - generateNoise, n_tars_wds=n_tars_wds, + pipeline_wds=generateNoise, + pipeline_prebatch_wld=pipelines_wdl_batch, + n_tars_wds=n_tars_wds, seed_rng_shfl=seed_rng_shfl) return data_module, prefix_dir_tars_wds From 598a4aeee7a01f33f7c01435b89aa33258c69538 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 6 Aug 2024 04:15:46 +0000 Subject: [PATCH 21/70] Enhancement: remove local_batch_size ... since the batching is expected to be passed as a wds pipeline method --- .../contrib/data/molecule/diffdock/datamodule.py | 15 +++++++++------ .../tests/bionemo/contrib/data/conftest.py | 4 ++-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index 221d9096e0..4ac239892a 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -41,8 +41,7 @@ class ScoreModelWDS(L.LightningDataModule): def __init__(self, dir_heterodata : str, suffix_heterodata : str, prefix_dir_tars_wds : str, names_subset : Dict[Split, List[str]], - local_batch_size : int, global_batch_size : int, - n_workers_dataloader : int, + global_batch_size : int, n_workers_dataloader : int, pipeline_wds : Optional[Dict[Split, Generator[HeteroData, None, None]]] = None, pipeline_prebatch_wld : Optional[Dict[Split, @@ -65,9 +64,14 @@ def __init__(self, dir_heterodata : str, suffix_heterodata : str, respectively. names_subset (Dict[Split, List[str]]): list of complex names to be included in each of the split - local_batch_size (int): size of batch for each node global_batch_size (int): size of batch summing across nodes in Data - Distributed Parallel, i.e., local_batch_size * n_nodes + Distributed Parallel, i.e., local_batch_size * n_nodes. NOTE: + this data module doesn't rely on the input `global_batch_size` + for batching the samples. The batching is supposed to be done as + a part of the input `pipeline_prebatch_wld`. `global_batch_size` + is only used to compute a (pseudo-) epoch length for the data + loader so that the loader yield approximately n_samples // + global_batch_size batches n_workers_dataloader (int): number of data loading workers (passed to pytorch dataloader) seed_rng_shfl (int): seed to the random number generators used in @@ -120,7 +124,6 @@ def __init__(self, dir_heterodata : str, suffix_heterodata : str, self._pipeline_wds = pipeline_wds self._pipeline_prebatch_wld = pipeline_prebatch_wld - self._local_batch_size = local_batch_size self._global_batch_size = global_batch_size self._n_workers_dataloader = n_workers_dataloader self._pin_memory_dataloader = pin_memory_dataloader @@ -242,7 +245,7 @@ def _setup_dataloader(self, split : Split) -> wds.WebLoader: # strange features required by nemo optimizer lr_scheduler loader.dataset = dataset # seems like only length is used, webloader doesn't have this attr - loader.batch_size = self._local_batch_size + loader.batch_size = self._global_batch_size loader.drop_last = False return loader diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py index d12ea78be5..37da7ee7ec 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py @@ -96,8 +96,8 @@ def _create_ScoreModelWDS_impl(tmp_path_factory, n_tars_wds = 4 seed_rng_shfl = 822782392 data_module = ScoreModelWDS(dir_heterodata, suffix_heterodata, - prefix_dir_tars_wds, names, local_batch_size, - global_batch_size, n_workers_dataloader, + prefix_dir_tars_wds, names, global_batch_size, + n_workers_dataloader, pipeline_wds=generateNoise, pipeline_prebatch_wld=pipelines_wdl_batch, n_tars_wds=n_tars_wds, From 213095b856cfbb21c1de3a71159e63aa3e60437e Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 6 Aug 2024 04:19:55 +0000 Subject: [PATCH 22/70] BugFix: remove unused random.seed() call --- .../src/bionemo/contrib/data/molecule/diffdock/datamodule.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index 4ac239892a..6bbb15fc25 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -176,7 +176,6 @@ def _setup_wds(self, split : Split) -> wds.WebDataset: Returns: WebDataset """ - random.seed(self._seed_rng_shfl) is_train = split == Split.train urls = sorted(glob.glob( f"{self._dirs_tars_wds[split]}/{self._prefix_tars_wds}-*.tar") From 1d8f162b2bfcb0fa57b3edf5118b3a0d51fa1106 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 6 Aug 2024 17:53:02 +0000 Subject: [PATCH 23/70] Enhancement: rename ScoreModelWDS -> PickledDataWDS --- .../data/molecule/diffdock/datamodule.py | 71 +++++++++---------- .../tests/bionemo/contrib/data/conftest.py | 7 +- 2 files changed, 40 insertions(+), 38 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index 6bbb15fc25..0a79ed54e5 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -17,7 +17,7 @@ import glob import pickle import random -from typing import Dict, Generator, List, Optional +from typing import Any, Dict, Generator, List, Optional import lightning as L from torch_geometric.data.hetero_data import HeteroData import webdataset as wds @@ -33,31 +33,30 @@ class Split(Enum): test = auto() -class ScoreModelWDS(L.LightningDataModule): +class PickledDataWDS(L.LightningDataModule): - """lightning APIs to process score model data and setup dataset and - dataloader""" + """lightning APIs to process pickled data into webdataset tar files and + setup dataset and dataloader""" - def __init__(self, dir_heterodata : str, suffix_heterodata : str, + def __init__(self, dir_pickled : str, suffix_pickled : str, prefix_dir_tars_wds : str, names_subset : Dict[Split, List[str]], global_batch_size : int, n_workers_dataloader : int, - pipeline_wds : Optional[Dict[Split, Generator[HeteroData, None, + pipeline_wds : Optional[Dict[Split, Generator[Any, None, None]]] = None, - pipeline_prebatch_wld : Optional[Dict[Split, - Generator[HeteroData, - None, None]]] = - None, pin_memory_dataloader : bool = True, prefix_tars_wds : - str = "heterographs", n_tars_wds : Optional[int] = None, + pipeline_prebatch_wld : Optional[Dict[Split, Generator[Any, + None, + None]]] + = None, pin_memory_dataloader : bool = True, prefix_tars_wds : + str = "wdshards", n_tars_wds : Optional[int] = None, seed_rng_shfl : int = 0): """constructor Args: - dir_heterodata (str): input directory of PyG HeteroData pickled - files - suffix_heterodata (str): filename suffix of the input data in - dir_heterodata. This is also used as the key mapped to the - tarballed HeteroData object in the webdataset + dir_pickled (str): input directory of pickled data files + suffix_pickled (str): filename suffix of the input data in + dir_pickled. This is also used as the key mapped to the + tarballed pickled object in the webdataset prefix_dir_tars_wds (str): directory name prefix to store the output webdataset tar files. The actual directories storing the train, val and test sets will be suffixed with "train", "val" and "test" @@ -74,31 +73,31 @@ def __init__(self, dir_heterodata : str, suffix_heterodata : str, global_batch_size batches n_workers_dataloader (int): number of data loading workers (passed to pytorch dataloader) - seed_rng_shfl (int): seed to the random number generators used in - data loading time for shuffling Kwargs: - pipeline_wds (Optional[Dict[Split, Generator[HeteroData, None, - None]]]): a dictionary of webdatast composable, i.e., functor + pipeline_wds (Optional[Dict[Split, Generator[Any, None, None]]]): a + dictionary of webdatast composable, i.e., functor that maps a + generator to another generator that transforms the data sample + yield from the dataset object, for different splits + pipeline_prebatch_wld (Optional[Dict[Split, Generator[Any, None, + None]]]): a dictionary of webloader composable, i.e., functor that maps a generator to another generator that transforms the - data sample yield from the dataset object, for different splits - pipeline_prebatch_wld (Optional[Dict[Split, Generator[HeteroData, - None, None]]]): a dictionary of webloader composable, i.e., - functor that maps a generator to another generator that - transforms the data sample yield from the WebLoader object, for - different splits. NOTE: this is applied before batching is yield - from the WebLoader - pin_memory_dataloader (bool): whether to use pin memory in pytorch - dataloader + data sample yield from the WebLoader object, for different + splits. NOTE: this is applied before batching is yield from the + WebLoader pin_memory_dataloader (bool): whether to use pin + memory in pytorch dataloader prefix_tars_wds (str): name prefix to output webdataset tar files - n_tars_wds (int): attempt to create at least this number of webdataset shards + n_tars_wds (int): attempt to create at least this number of + webdataset shards + seed_rng_shfl (int): seed to the random number generators used in + data loading time for shuffling """ super().__init__() - self._dir_heterodata = dir_heterodata - self._suffix_heterodata = suffix_heterodata + self._dir_pickled = dir_pickled + self._suffix_pickled = suffix_pickled self._n_tars_wds = n_tars_wds self._prefix_dir_tars_wds = prefix_dir_tars_wds self._prefix_tars_wds = prefix_tars_wds @@ -145,7 +144,7 @@ def _complex_graph_to_tar(self, complex_graph : HeteroData): """ return { "__key__": complex_graph.name.replace(".", "-"), - self._suffix_heterodata: pickle.dumps(complex_graph) + self._suffix_pickled: pickle.dumps(complex_graph) } @@ -158,8 +157,8 @@ def prepare_data(self) -> None: """ for split in self._names_subset.keys(): # create wds shards (tar files) for train set - pickles_to_tars(self._dir_heterodata, - self._suffix_heterodata, + pickles_to_tars(self._dir_pickled, + self._suffix_pickled, self._names_subset[split], self._dirs_tars_wds[split], self._prefix_tars_wds, @@ -185,7 +184,7 @@ def _setup_wds(self, split : Split) -> wds.WebDataset: nodesplitter=wds.split_by_node, seed=self._seed_rng_shfl) .decode() - .extract_keys(f"*.{self._suffix_heterodata}") + .extract_keys(f"*.{self._suffix_pickled}") ) if (self._pipeline_wds is not None and self._pipeline_wds[split] is not None): diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py index 37da7ee7ec..e20e910804 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py @@ -26,7 +26,9 @@ from bionemo.contrib.data.molecule.diffdock.utils import ( SizeAwareBatching, estimate_size ) -from bionemo.contrib.data.molecule.diffdock.datamodule import Split, ScoreModelWDS +from bionemo.contrib.data.molecule.diffdock.datamodule import ( + Split, PickledDataWDS + ) @pytest.fixture(scope="module") @@ -95,11 +97,12 @@ def _create_ScoreModelWDS_impl(tmp_path_factory, n_workers_dataloader = 2 n_tars_wds = 4 seed_rng_shfl = 822782392 - data_module = ScoreModelWDS(dir_heterodata, suffix_heterodata, + data_module = PickledDataWDS(dir_heterodata, suffix_heterodata, prefix_dir_tars_wds, names, global_batch_size, n_workers_dataloader, pipeline_wds=generateNoise, pipeline_prebatch_wld=pipelines_wdl_batch, + prefix_tars_wds="heterographs", n_tars_wds=n_tars_wds, seed_rng_shfl=seed_rng_shfl) return data_module, prefix_dir_tars_wds From add78709ed01d2046810357d75032885b4b94faf Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 6 Aug 2024 18:43:00 +0000 Subject: [PATCH 24/70] Enhancement: internalize the wds archive formatter the HeteroData-specific wds tar archive formatter is replaced by a generic version defaulted in pickles_to_tars(). Since the current data module archive one pickled data for each sample anyway, using the default is sufficient. Also add more doc for the default formatter usage --- .../data/molecule/diffdock/datamodule.py | 36 +++++++---------- .../contrib/data/molecule/diffdock/utils.py | 40 +++++++++++++++---- 2 files changed, 47 insertions(+), 29 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index 0a79ed54e5..94d5a7d930 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -15,11 +15,9 @@ from enum import Enum, auto import glob -import pickle import random from typing import Any, Dict, Generator, List, Optional import lightning as L -from torch_geometric.data.hetero_data import HeteroData import webdataset as wds from bionemo.contrib.data.molecule.diffdock.utils import ( @@ -36,7 +34,16 @@ class Split(Enum): class PickledDataWDS(L.LightningDataModule): """lightning APIs to process pickled data into webdataset tar files and - setup dataset and dataloader""" + setup dataset and dataloader. This data module takes a directory of pickled + data files, data filename prefixes for train/val/test splits, data filename + suffixes and prepare webdataset tar files by globbing the specific pickeld + data files {dir_pickled}/{name_subset[split]}.{suffix_pickled} and outputing + to webdataset tar file with the dict structure: {"__key__" : + name.replace(".", "-"), suffix_pickled : pickled.dumps(data) }. NOTE: this + assumes only one pickled file is processed for each sample. In its setup() + function, it creates the webdataset object chaining up the input + `pipeline_wds` workflow. In its train/val/test_dataloader(), it creates the + WebLoader object chaining up the `pipeline_prebatch_wld` workflow""" def __init__(self, dir_pickled : str, suffix_pickled : str, prefix_dir_tars_wds : str, names_subset : Dict[Split, @@ -132,26 +139,14 @@ def __init__(self, dir_pickled : str, suffix_pickled : str, self._dataset = dict() - def _complex_graph_to_tar(self, complex_graph : HeteroData): - """map input complex graph to webdataset tar file conforming to its - format requirement - - Args: - complex_graph (HeteroData): input complex graph - - Returns: webdataset tar file segment (dict) - - """ - return { - "__key__": complex_graph.name.replace(".", "-"), - self._suffix_pickled: pickle.dumps(complex_graph) - } - - def prepare_data(self) -> None: """This is called only by the main process by the Lightning workflow. Do not rely on this data module object's state update here as there is no - way to communicate the state update to other subprocesses + way to communicate the state update to other subprocesses. The + `pickles_to_tars` function goes through the data name prefixes in the + different splits, read the corresponding pickled file and output a + webdataset tar archive with the dict structure: {"__key__" : + name.replace(".", "-"), suffix_pickled : pickled.dumps(data) }. Returns: None """ @@ -162,7 +157,6 @@ def prepare_data(self) -> None: self._names_subset[split], self._dirs_tars_wds[split], self._prefix_tars_wds, - self._complex_graph_to_tar, min_num_shards=self._n_tars_wds) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py index ef8359fed6..f532340202 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py @@ -16,7 +16,7 @@ import os import pickle import random -from typing import Any, Callable, Generator, List, Optional +from typing import Any, Dict, Callable, Generator, List, Optional from copy import deepcopy from nemo.utils import logging @@ -39,7 +39,9 @@ def pickles_to_tars( input_prefix_subset: List[str], dir_output: str, output_prefix: str, - func_output_data: Callable = lambda data: {"data": pickle.dumps(data)}, + func_output_data: Callable[[str, str, Any], Dict[str, Any]] = + lambda prefix, suffix, data: { "__key__": prefix, + suffix: pickle.dumps(data) }, min_num_shards: Optional[int] = None, ) -> None: """Convert a subset of pickle files from a directory to Webdataset tar files @@ -48,18 +50,37 @@ def pickles_to_tars( Output path and name pattern: f"{dir_output}/{output_prefix}-%06d.tar" + The webdataset tar archive is specified by the dictionary: + { + "__key__" : sample_filename_preifx, + sample_filename_suffix_1 : data_1, + sample_filename_suffix_2 : data_2, + ... + } + so that parsing the tar archive is equivalent of reading + {sample_filename_preifx}.{sample_filename_suffix_1} etc. + + Here, the assumption is that there is only one sample data file, whose name + prefix is given in each of the elements of `input_prefix_subset` and whose + name suffix is given by `input_suffix`. Per the webdataset file format + specification, the `sample_filename_preifx` can't contain dots '.' so this + function removes it for the user by calling .replace(".", "-") on the + elements of `input_prefix_subset` + Args: dir_input (str): Input directory input_suffix (str): Input pickle file name suffix input_prefix_subset (List[str]): Input subset of pickle files' prefix dir_output (str): Output directory output_prefix (str): Output tar file name prefix - func_output_data (Callable) : function that maps data to a dictionary - to be output in the tar files + func_output_data (Callable[[str, str, Any], Dict[str, Any]]) : function + that maps the name prefix, name suffix and data object to a + webdataset tar archive dictionary. Refer to the webdataset github + repo for the archive file format specification. min_num_shards (int) : create at least this number of tar files. - WebDataset has bugs when reading small number of tar files in a - multi-node lightening + DDP setting so this option can be used to - guarantee the tar file counts + WebDataset has bugs when reading small number of tar files in a + multi-node lightening + DDP setting so this option can be used to + guarantee the tar file counts Returns: None @@ -85,7 +106,10 @@ def pickles_to_tars( for name in input_prefix_subset: try: data = pickle.load(open(os.path.join(dir_input, f"{name}.{input_suffix}"), "rb")) - sample = func_output_data(data) + # the prefix name shouldn't contain any "." per webdataset's + # specification + sample = func_output_data(name.replace(".", "-"), + input_suffix, data) except ModuleNotFoundError as e: logging.error(f"Dependency for parsing input pickle data not "\ f"found: {e}") From 7bbc65173ca45f8d831e7880ccf0a39c815185a3 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 6 Aug 2024 18:58:21 +0000 Subject: [PATCH 25/70] Test: rename score model test data directory --- .../bionemo-contrib/tests/bionemo/contrib/data/conftest.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py index e20e910804..abf3b28408 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py @@ -41,7 +41,8 @@ def get_path(request): @pytest.fixture(scope="module") def get_diffdock_score_model_heterodata(get_path): _, dir_data = get_path - dir_heterodata = f"{dir_data}/molecule/diffdock/heterodata" + dir_heterodata =\ + f"{dir_data}/molecule/diffdock/pyg_heterodata_pickled/score_model" suffix_heterodata = "heterodata.pyd" names = { Split.train : ["6t88", "6vs3", "6wtn", "6yqv", "7amc", "7bmi", "7cuo", From 78c1f805885293a75e05da7b054f6eb3689eb526 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Wed, 7 Aug 2024 16:42:20 +0000 Subject: [PATCH 26/70] Test: confidence model data module --- .../contrib/data/molecule/diffdock/utils.py | 116 +++++++++++++- .../tests/bionemo/contrib/data/conftest.py | 96 +++++++++-- .../contrib/data/test_diffdock_datamodule.py | 150 ++++++++++++++++-- 3 files changed, 338 insertions(+), 24 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py index f532340202..4df9bc335c 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py @@ -16,22 +16,21 @@ import os import pickle import random -from typing import Any, Dict, Callable, Generator, List, Optional -from copy import deepcopy +import math +from typing import ( + Any, Dict, Callable, Generator, List, Optional,Union, Iterable + ) +from omegaconf.listconfig import ListConfig from nemo.utils import logging import torch from torch_geometric.data import HeteroData from torch_geometric.data.batch import Batch from torch_geometric.loader.dataloader import Collater -from torch_geometric.transforms import BaseTransform import numpy as np import webdataset as wds -from bionemo.contrib.model.molecule.diffdock.utils.diffusion import modify_conformer, set_time -from bionemo.contrib.model.molecule.diffdock.utils import so3, torus - def pickles_to_tars( dir_input: str, @@ -209,3 +208,108 @@ def __call__(self, data: Batch) -> Generator[Batch, None, None]: batch = [sample] batch_size = sample_size + + +class SelectPoseAndLabelData: + """A WebDataset composable to select one ligand poses from multiple ones and + label confidence model training data by RMSD threshold""" + + def __init__( + self, + rmsd_classification_cutoff: Union[float, ListConfig], + samples_per_complex: int, + balance: bool, + all_atoms: bool, + seed : int = 0 + ): + """constructor + + Args: + rmsd_classification_cutoff (Union[float, ListConfig]): RMSD classification cutoff(s) + samples_per_complex (int): how many inference runs were done per complex + balance (bool): whether to do balance sampling + all_atoms (bool): whether the confidence model is all-atom + seed (int): random number generator seed + + Returns: + + """ + self.rmsd_classification_cutoff = rmsd_classification_cutoff + self.samples_per_complex = samples_per_complex + self.balance = balance + self.all_atoms = all_atoms + self._seed = seed + + def __call__(self, data: Iterable) -> Generator[HeteroData, None, None]: + """Map the input data iterator to another one that label the input data + + Args: + data (Iterable): Input data iterator + + Returns: + + """ + random.seed(self._seed) + for (complex_graph,) in data: + positions, rmsds = complex_graph.ligand_data + + if self.balance: + if isinstance(self.rmsd_classification_cutoff, ListConfig): + raise ValueError("a list for rmsd_classification_cutoff can only be used with balance=False") + # FIXME: should allow random.seed + label = random.randint(0, 1) + success = rmsds < self.rmsd_classification_cutoff + n_success = np.count_nonzero(success) + if label == 0 and n_success != self.samples_per_complex: + # sample negative complex + sample = random.randint(0, self.samples_per_complex - n_success - 1) + lig_pos = positions[~success][sample] + complex_graph["ligand"].pos = torch.from_numpy(lig_pos) + else: + # sample positive complex + if n_success > 0: # if no successful sample returns the matched complex + sample = random.randint(0, n_success - 1) + lig_pos = positions[success][sample] + complex_graph["ligand"].pos = torch.from_numpy(lig_pos) + complex_graph.y = torch.tensor(label).float() + else: + sample = random.randint(0, self.samples_per_complex - 1) + complex_graph["ligand"].pos = torch.from_numpy(positions[sample]) + ids = (rmsds[sample] < + self.rmsd_classification_cutoff).astype(int) + complex_graph.y = torch.tensor(ids).float().unsqueeze(0) + if isinstance(self.rmsd_classification_cutoff, ListConfig): + complex_graph.y_binned = torch.tensor( + np.logical_and( + rmsds[sample] < self.rmsd_classification_cutoff + [math.inf], + rmsds[sample] >= [0] + self.rmsd_classification_cutoff, + ), + dtype=torch.float, + ).unsqueeze(0) + complex_graph.y = ( + torch.tensor(rmsds[sample] < self.rmsd_classification_cutoff[0]).unsqueeze(0).float() + ) + complex_graph.rmsd = torch.tensor(rmsds[sample]).unsqueeze(0).float() + + complex_graph["ligand"].node_t = { + "tr": 0 * torch.ones(complex_graph["ligand"].num_nodes), + "rot": 0 * torch.ones(complex_graph["ligand"].num_nodes), + "tor": 0 * torch.ones(complex_graph["ligand"].num_nodes), + } + complex_graph["receptor"].node_t = { + "tr": 0 * torch.ones(complex_graph["receptor"].num_nodes), + "rot": 0 * torch.ones(complex_graph["receptor"].num_nodes), + "tor": 0 * torch.ones(complex_graph["receptor"].num_nodes), + } + if self.all_atoms: + complex_graph["atom"].node_t = { + "tr": 0 * torch.ones(complex_graph["atom"].num_nodes), + "rot": 0 * torch.ones(complex_graph["atom"].num_nodes), + "tor": 0 * torch.ones(complex_graph["atom"].num_nodes), + } + complex_graph.complex_t = { + "tr": 0 * torch.ones(1), + "rot": 0 * torch.ones(1), + "tor": 0 * torch.ones(1), + } + yield complex_graph diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py index abf3b28408..54f91605a5 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py @@ -24,7 +24,7 @@ from bionemo.contrib.model.molecule.diffdock.utils.diffusion import ( t_to_sigma, GenerateNoise) from bionemo.contrib.data.molecule.diffdock.utils import ( - SizeAwareBatching, estimate_size + SizeAwareBatching, estimate_size, SelectPoseAndLabelData ) from bionemo.contrib.data.molecule.diffdock.datamodule import ( Split, PickledDataWDS @@ -56,8 +56,26 @@ def get_diffdock_score_model_heterodata(get_path): return (dir_heterodata, suffix_heterodata, names) -def _create_ScoreModelWDS_impl(tmp_path_factory, - get_diffdock_score_model_heterodata): +@pytest.fixture(scope="module") +def get_diffdock_confidence_model_heterodata(get_path): + _, dir_data = get_path + dir_heterodata =\ + f"{dir_data}/molecule/diffdock/pyg_heterodata_pickled/confidence_model" + suffix_heterodata = "heterodata.pyd" + names = { + Split.train : ["6t88", "6vs3", "6wtn", "6yqv", "7amc", "7bmi", "7cuo", + "7d5c", "7din", "7fha", "7jnb", "7k0v", "7kb1", "7km8", + "7l7c", "7lcu", "7msr", "7my1", "7n6f", "7np6"], + Split.val : ["7nr6", "7oeo", "7oli", "7oso", "7p5t", "7q5i", "7qhl", + "7rh3", "7rzl", "7sgv"], + Split.test : ["7sne", "7t2i", "7tbu", "7tsf", "7umv", "7up3", "7uq3", + "7wpw", "7xek", "7xij"] + } + return (dir_heterodata, suffix_heterodata, names) + + +def _create_datamodule_score_model_impl(tmp_path_factory, + get_diffdock_score_model_heterodata): (dir_heterodata, suffix_heterodata, names) =\ get_diffdock_score_model_heterodata prefix_dir_tars_wds = tmp_path_factory.mktemp( @@ -110,12 +128,72 @@ def _create_ScoreModelWDS_impl(tmp_path_factory, @pytest.fixture(scope="module") -def create_ScoreModelWDS(tmp_path_factory, get_diffdock_score_model_heterodata): - return _create_ScoreModelWDS_impl(tmp_path_factory, - get_diffdock_score_model_heterodata) +def create_datamodule_score_model(tmp_path_factory, + get_diffdock_score_model_heterodata): + return _create_datamodule_score_model_impl(tmp_path_factory, + get_diffdock_score_model_heterodata) + + +@pytest.fixture(scope="module") +def create_another_datamodule_score_model(tmp_path_factory, + get_diffdock_score_model_heterodata): + return _create_datamodule_score_model_impl(tmp_path_factory, + get_diffdock_score_model_heterodata) + + +def _create_datamodule_confidence_model_impl(tmp_path_factory, + get_diffdock_confidence_model_heterodata): + (dir_heterodata, suffix_heterodata, names) =\ + get_diffdock_confidence_model_heterodata + prefix_dir_tars_wds = tmp_path_factory.mktemp( + "diffdock_confidence_model_tars_wds").as_posix() + # webdataset pipeline + rmsd_classification_cutoff = 2.0 + samples_per_complex = 7 + balance = False + is_all_atom = True + seed_rng_shfl = 822782392 + select_pose = SelectPoseAndLabelData(rmsd_classification_cutoff, + samples_per_complex, balance, + is_all_atom, seed=seed_rng_shfl) + pipeline_wds = { + Split.train : select_pose, + Split.val : select_pose, + Split.test : select_pose, + } + local_batch_size = 2 + global_batch_size = 2 + batch_pyg = batched(local_batch_size, + collation_fn=Collater(dataset=[], follow_batch=None, + exclude_keys=None)) + # WebLoader pipeline + pipelines_wdl_batch = { + Split.train : batch_pyg, + Split.val : batch_pyg, + Split.test : batch_pyg, + } + n_workers_dataloader = 2 + n_tars_wds = 4 + data_module = PickledDataWDS(dir_heterodata, suffix_heterodata, + prefix_dir_tars_wds, names, global_batch_size, + n_workers_dataloader, + pipeline_wds=pipeline_wds, + pipeline_prebatch_wld=pipelines_wdl_batch, + prefix_tars_wds="heterographs", + n_tars_wds=n_tars_wds, + seed_rng_shfl=seed_rng_shfl) + return data_module, prefix_dir_tars_wds + + +@pytest.fixture(scope="module") +def create_datamodule_confidence_model(tmp_path_factory, + get_diffdock_confidence_model_heterodata): + return _create_datamodule_confidence_model_impl(tmp_path_factory, + get_diffdock_confidence_model_heterodata) @pytest.fixture(scope="module") -def create_another_ScoreModelWDS(tmp_path_factory, get_diffdock_score_model_heterodata): - return _create_ScoreModelWDS_impl(tmp_path_factory, - get_diffdock_score_model_heterodata) +def create_another_datamodule_confidence_model(tmp_path_factory, + get_diffdock_confidence_model_heterodata): + return _create_datamodule_confidence_model_impl(tmp_path_factory, + get_diffdock_confidence_model_heterodata) diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py index 9a0a50fe14..b52d30d819 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py @@ -25,11 +25,11 @@ @pytest.mark.parametrize("split", [s for s in Split]) -def test_ScoreModelWDS_init(split, get_diffdock_score_model_heterodata, - create_ScoreModelWDS): +def test_datamodule_score_model_init(split, get_diffdock_score_model_heterodata, + create_datamodule_score_model): name_split = str(split).split('.')[1] (_, _, names) = get_diffdock_score_model_heterodata - data_module, prefix_dir_tars_wds = create_ScoreModelWDS + data_module, prefix_dir_tars_wds = create_datamodule_score_model assert data_module._sizes[split] == len(names[split]),\ f"Wrong {split}-set size: expected {len(names[split])}"\ f"but got {data_module._sizes[split]}" @@ -40,8 +40,23 @@ def test_ScoreModelWDS_init(split, get_diffdock_score_model_heterodata, @pytest.mark.parametrize("split", [s for s in Split]) -def test_ScoreModelWDS_prepare_data(split, create_ScoreModelWDS): - data_module, _ = create_ScoreModelWDS +def test_datamodule_confidence_model_init(split, get_diffdock_confidence_model_heterodata, + create_datamodule_confidence_model): + name_split = str(split).split('.')[1] + (_, _, names) = get_diffdock_confidence_model_heterodata + data_module, prefix_dir_tars_wds = create_datamodule_confidence_model + assert data_module._sizes[split] == len(names[split]),\ + f"Wrong {split}-set size: expected {len(names[split])}"\ + f"but got {data_module._sizes[split]}" + assert data_module._dirs_tars_wds[split] ==\ + f"{prefix_dir_tars_wds}{name_split}",\ + f"Wrong tar files directory: expected {prefix_dir_tars_wds}{split} "\ + f"but got {data_module._dirs_tars_wds[split]}" + + +@pytest.mark.parametrize("split", [s for s in Split]) +def test_datamodule_score_model_prepare_data(split, create_datamodule_score_model): + data_module, _ = create_datamodule_score_model # LightningDataModule.prepare_data() is supposed to be called from the main # process in a Lightning-managed multi-process context so we can call it in # a single process @@ -54,10 +69,69 @@ def test_ScoreModelWDS_prepare_data(split, create_ScoreModelWDS): f"expected {data_module._n_tars_wds} "\ f"got {len(files_tars)}" +@pytest.mark.parametrize("split", [s for s in Split]) +def test_datamodule_confidence_model_prepare_data(split, create_datamodule_confidence_model): + data_module, _ = create_datamodule_confidence_model + # LightningDataModule.prepare_data() is supposed to be called from the main + # process in a Lightning-managed multi-process context so we can call it in + # a single process + data_module.prepare_data() + files_tars = sorted(glob.glob( + f"{data_module._dirs_tars_wds[split]}/"\ + f"{data_module._prefix_tars_wds}-*.tar")) + assert len(files_tars) >= data_module._n_tars_wds,\ + f"Wrong num of {split}-set tar files: "\ + f"expected {data_module._n_tars_wds} "\ + f"got {len(files_tars)}" + + +@pytest.mark.parametrize("split", [s for s in Split]) +def test_datamodule_score_model_setup_dataset(split, create_datamodule_score_model, create_another_datamodule_score_model): + data_modules= [create_datamodule_score_model[0], create_another_datamodule_score_model[0]] + lists_complex_name = [] + lists_pos_ligand = [] + for m in data_modules: + m.prepare_data() + # run through all the possible stages first to setup all the correps. + # dataset objects + m.setup("fit") + m.setup("test") + lightning.seed_everything(2823828) + names = [] + pos_ligand = [] + for sample in m._dataset[split]: + assert isinstance(sample, HeteroData),\ + "Sample yield from dataset is not PyG HeteroData" + names.append(sample.name) + pos_ligand.append(sample["ligand"].pos) + lists_complex_name.append(names) + lists_pos_ligand.append(pos_ligand) + + assert len(lists_complex_name[0]) > 0,\ + "No names in {split} dataset" + assert lists_complex_name[0] == lists_complex_name[1],\ + f"Inconsistent sample name in {split}-set from data module instances: "\ + f"{lists_complex_name[0]} \n\nvs.\n\n"\ + f"{lists_complex_name[1]}" + + assert len(lists_pos_ligand[0]) > 0, "No ligand position found in dataset" + assert len(lists_pos_ligand[0]) == len(lists_pos_ligand[1]),\ + f"Inconsistent number of ligand position in {split}-set from data "\ + f"module instances: {len(lists_pos_ligand[0])} \n\nvs.\n\n"\ + f"{len(lists_pos_ligand[1])}" + for i in range(len(lists_pos_ligand[0])): + pos_0 = lists_pos_ligand[0][i] + pos_1 = lists_pos_ligand[1][i] + torch.testing.assert_close(pos_0, pos_1, + msg=lambda m : + f"Inconsistent ligand position in the " + f"{i}'th sample/batch of {split}-set " + f"between two data module instances:\n\n{m}") + @pytest.mark.parametrize("split", [s for s in Split]) -def test_ScoreModelWDS_setup_dataset(split, create_ScoreModelWDS, create_another_ScoreModelWDS): - data_modules= [create_ScoreModelWDS[0], create_another_ScoreModelWDS[0]] +def test_datamodule_confidence_model_setup_dataset(split, create_datamodule_confidence_model, create_another_datamodule_confidence_model): + data_modules= [create_datamodule_confidence_model[0], create_another_datamodule_confidence_model[0]] lists_complex_name = [] lists_pos_ligand = [] for m in data_modules: @@ -100,8 +174,66 @@ def test_ScoreModelWDS_setup_dataset(split, create_ScoreModelWDS, create_another @pytest.mark.parametrize("split", [s for s in Split]) -def test_ScoreModelWDS_setup_dataloader(split, create_ScoreModelWDS, create_another_ScoreModelWDS): - data_modules= [create_ScoreModelWDS[0], create_another_ScoreModelWDS[0]] +def test_datamodule_score_model_setup_dataloader(split, create_datamodule_score_model, create_another_datamodule_score_model): + data_modules= [create_datamodule_score_model[0], create_another_datamodule_score_model[0]] + lists_complex_name = [] + lists_pos_ligand = [] + for m in data_modules: + m.prepare_data() + # run through all the possible stages first to setup all the correps. + # dataset objects + m.setup("fit") + m.setup("test") + lightning.seed_everything(2823828) + names = [] + pos_ligand = [] + loader = None + if split == Split.train: + loader = m.train_dataloader() + elif split == Split.val: + loader = m.val_dataloader() + elif split == Split.test: + loader = m.test_dataloader() + else: + raise RuntimeError(f"Test for split {split} not implemented") + assert loader is not None, "dataloader not instantated" + for samples in loader: + # PyG's HeteroDataBatch is Batch inherited from HeteroData + assert isinstance(samples, Batch),\ + f"Sample object is not PyG Batch" + assert isinstance(samples, HeteroData),\ + f"Sample object is not PyG HeteroData" + names.append(samples.name) + pos_ligand.append(samples["ligand"].pos) + lists_complex_name.append(names) + lists_pos_ligand.append(pos_ligand) + + assert len(lists_complex_name[0]) > 0,\ + "No names in {split} dataloader" + assert lists_complex_name[0] == lists_complex_name[1],\ + f"Inconsistent sample name in {split}-set from data module instances: "\ + f"{lists_complex_name[0]} \n\nvs.\n\n"\ + f"{lists_complex_name[1]}" + + assert len(lists_pos_ligand[0]) > 0,\ + "No ligand position found in dataloader" + assert len(lists_pos_ligand[0]) == len(lists_pos_ligand[1]),\ + f"Inconsistent number of ligand position in {split}-set from data "\ + f"module instances: {len(lists_pos_ligand[0])} \n\nvs.\n\n"\ + f"{len(lists_pos_ligand[1])}" + for i in range(len(lists_pos_ligand[0])): + pos_0 = lists_pos_ligand[0][i] + pos_1 = lists_pos_ligand[1][i] + torch.testing.assert_close(pos_0, pos_1, + msg=lambda m : + f"Inconsistent ligand position in the " + f"{i}'th sample/batch of {split}-set " + f"between two data module instances:\n\n{m}") + + +@pytest.mark.parametrize("split", [s for s in Split]) +def test_datamodule_confidence_model_setup_dataloader(split, create_datamodule_confidence_model, create_another_datamodule_confidence_model): + data_modules= [create_datamodule_confidence_model[0], create_another_datamodule_confidence_model[0]] lists_complex_name = [] lists_pos_ligand = [] for m in data_modules: From e13b325071a35d68fe700c4e6bdc6dbbea669506 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Wed, 7 Aug 2024 22:38:33 +0000 Subject: [PATCH 27/70] Test: parametrize tests with model type ... so to reduce the amount of boilerplate code in tests --- .../contrib/data/molecule/diffdock/utils.py | 2 +- .../tests/bionemo/contrib/data/conftest.py | 90 +++++----- .../contrib/data/test_diffdock_datamodule.py | 157 ++---------------- 3 files changed, 53 insertions(+), 196 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py index 4df9bc335c..9ca329d596 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py @@ -115,7 +115,7 @@ def pickles_to_tars( raise e except Exception as e: logging.error(f"Failed to write {name} into tar files due to error {e}") - continue + raise e sink.write(sample) diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py index 54f91605a5..11abc1ee33 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum, auto import os import pytest from functools import partial @@ -38,29 +39,18 @@ def get_path(request): return dir_test, dir_data -@pytest.fixture(scope="module") -def get_diffdock_score_model_heterodata(get_path): - _, dir_data = get_path - dir_heterodata =\ - f"{dir_data}/molecule/diffdock/pyg_heterodata_pickled/score_model" - suffix_heterodata = "heterodata.pyd" - names = { - Split.train : ["6t88", "6vs3", "6wtn", "6yqv", "7amc", "7bmi", "7cuo", - "7d5c", "7din", "7fha", "7jnb", "7k0v", "7kb1", "7km8", - "7l7c", "7lcu", "7msr", "7my1", "7n6f", "7np6"], - Split.val : ["7nr6", "7oeo", "7oli", "7oso", "7p5t", "7q5i", "7qhl", - "7rh3", "7rzl", "7sgv"], - Split.test : ["7sne", "7t2i", "7tbu", "7tsf", "7umv", "7up3", "7uq3", - "7wpw", "7xek", "7xij"] - } - return (dir_heterodata, suffix_heterodata, names) +class DiffDockModel(Enum): + score = auto() + confidence = auto() -@pytest.fixture(scope="module") -def get_diffdock_confidence_model_heterodata(get_path): +@pytest.fixture(scope="module", params=[m for m in DiffDockModel]) +def get_diffdock_heterodata(get_path, request): _, dir_data = get_path + model = request.param + name_model = str(model).split(".")[-1] dir_heterodata =\ - f"{dir_data}/molecule/diffdock/pyg_heterodata_pickled/confidence_model" + f"{dir_data}/molecule/diffdock/pyg_heterodata_pickled/{name_model}_model" suffix_heterodata = "heterodata.pyd" names = { Split.train : ["6t88", "6vs3", "6wtn", "6yqv", "7amc", "7bmi", "7cuo", @@ -71,15 +61,13 @@ def get_diffdock_confidence_model_heterodata(get_path): Split.test : ["7sne", "7t2i", "7tbu", "7tsf", "7umv", "7up3", "7uq3", "7wpw", "7xek", "7xij"] } - return (dir_heterodata, suffix_heterodata, names) + return (dir_heterodata, suffix_heterodata, names, model) -def _create_datamodule_score_model_impl(tmp_path_factory, - get_diffdock_score_model_heterodata): - (dir_heterodata, suffix_heterodata, names) =\ - get_diffdock_score_model_heterodata +def _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, + suffix_heterodata, names): prefix_dir_tars_wds = tmp_path_factory.mktemp( - "diffdock_score_model_tars_wds").as_posix() + f"diffdock_score_model_tars_wds").as_posix() tr_sigma_min, tr_sigma_max = (0.1, 19) rot_sigma_min, rot_sigma_max = (0.03, 1.55) tor_sigma_min, tor_sigma_max = (0.0314, 3.14) @@ -127,24 +115,8 @@ def _create_datamodule_score_model_impl(tmp_path_factory, return data_module, prefix_dir_tars_wds -@pytest.fixture(scope="module") -def create_datamodule_score_model(tmp_path_factory, - get_diffdock_score_model_heterodata): - return _create_datamodule_score_model_impl(tmp_path_factory, - get_diffdock_score_model_heterodata) - - -@pytest.fixture(scope="module") -def create_another_datamodule_score_model(tmp_path_factory, - get_diffdock_score_model_heterodata): - return _create_datamodule_score_model_impl(tmp_path_factory, - get_diffdock_score_model_heterodata) - - -def _create_datamodule_confidence_model_impl(tmp_path_factory, - get_diffdock_confidence_model_heterodata): - (dir_heterodata, suffix_heterodata, names) =\ - get_diffdock_confidence_model_heterodata +def _create_datamodule_confidence_model_impl(tmp_path_factory, dir_heterodata, + suffix_heterodata, names): prefix_dir_tars_wds = tmp_path_factory.mktemp( "diffdock_confidence_model_tars_wds").as_posix() # webdataset pipeline @@ -186,14 +158,30 @@ def _create_datamodule_confidence_model_impl(tmp_path_factory, @pytest.fixture(scope="module") -def create_datamodule_confidence_model(tmp_path_factory, - get_diffdock_confidence_model_heterodata): - return _create_datamodule_confidence_model_impl(tmp_path_factory, - get_diffdock_confidence_model_heterodata) +def create_datamodule(tmp_path_factory, get_diffdock_heterodata): + dir_heterodata, suffix_heterodata, names, model = get_diffdock_heterodata + if model == DiffDockModel.score: + return _create_datamodule_score_model_impl(tmp_path_factory, + dir_heterodata, + suffix_heterodata, + names) + elif model == DiffDockModel.confidence: + return _create_datamodule_confidence_model_impl(tmp_path_factory, + dir_heterodata, + suffix_heterodata, + names) @pytest.fixture(scope="module") -def create_another_datamodule_confidence_model(tmp_path_factory, - get_diffdock_confidence_model_heterodata): - return _create_datamodule_confidence_model_impl(tmp_path_factory, - get_diffdock_confidence_model_heterodata) +def create_another_datamodule(tmp_path_factory, get_diffdock_heterodata): + dir_heterodata, suffix_heterodata, names, model = get_diffdock_heterodata + if model == DiffDockModel.score: + return _create_datamodule_score_model_impl(tmp_path_factory, + dir_heterodata, + suffix_heterodata, + names) + elif model == DiffDockModel.confidence: + return _create_datamodule_confidence_model_impl(tmp_path_factory, + dir_heterodata, + suffix_heterodata, + names) diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py index b52d30d819..9e5d6293a3 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py @@ -25,53 +25,24 @@ @pytest.mark.parametrize("split", [s for s in Split]) -def test_datamodule_score_model_init(split, get_diffdock_score_model_heterodata, - create_datamodule_score_model): +def test_datamodule_init(split, get_diffdock_heterodata, create_datamodule): name_split = str(split).split('.')[1] - (_, _, names) = get_diffdock_score_model_heterodata - data_module, prefix_dir_tars_wds = create_datamodule_score_model + (_, _, names, model) = get_diffdock_heterodata + data_module, prefix_dir_tars_wds = create_datamodule assert data_module._sizes[split] == len(names[split]),\ - f"Wrong {split}-set size: expected {len(names[split])}"\ + f"Wrong {split}-set size for {model} model: "\ + f"expected {len(names[split])} "\ f"but got {data_module._sizes[split]}" assert data_module._dirs_tars_wds[split] ==\ f"{prefix_dir_tars_wds}{name_split}",\ - f"Wrong tar files directory: expected {prefix_dir_tars_wds}{split} "\ + f"Wrong tar files directory for {model} model: "\ + f"expected {prefix_dir_tars_wds}{split} "\ f"but got {data_module._dirs_tars_wds[split]}" @pytest.mark.parametrize("split", [s for s in Split]) -def test_datamodule_confidence_model_init(split, get_diffdock_confidence_model_heterodata, - create_datamodule_confidence_model): - name_split = str(split).split('.')[1] - (_, _, names) = get_diffdock_confidence_model_heterodata - data_module, prefix_dir_tars_wds = create_datamodule_confidence_model - assert data_module._sizes[split] == len(names[split]),\ - f"Wrong {split}-set size: expected {len(names[split])}"\ - f"but got {data_module._sizes[split]}" - assert data_module._dirs_tars_wds[split] ==\ - f"{prefix_dir_tars_wds}{name_split}",\ - f"Wrong tar files directory: expected {prefix_dir_tars_wds}{split} "\ - f"but got {data_module._dirs_tars_wds[split]}" - - -@pytest.mark.parametrize("split", [s for s in Split]) -def test_datamodule_score_model_prepare_data(split, create_datamodule_score_model): - data_module, _ = create_datamodule_score_model - # LightningDataModule.prepare_data() is supposed to be called from the main - # process in a Lightning-managed multi-process context so we can call it in - # a single process - data_module.prepare_data() - files_tars = sorted(glob.glob( - f"{data_module._dirs_tars_wds[split]}/"\ - f"{data_module._prefix_tars_wds}-*.tar")) - assert len(files_tars) >= data_module._n_tars_wds,\ - f"Wrong num of {split}-set tar files: "\ - f"expected {data_module._n_tars_wds} "\ - f"got {len(files_tars)}" - -@pytest.mark.parametrize("split", [s for s in Split]) -def test_datamodule_confidence_model_prepare_data(split, create_datamodule_confidence_model): - data_module, _ = create_datamodule_confidence_model +def test_datamodule_prepare_data(split, create_datamodule): + data_module, _ = create_datamodule # LightningDataModule.prepare_data() is supposed to be called from the main # process in a Lightning-managed multi-process context so we can call it in # a single process @@ -86,52 +57,8 @@ def test_datamodule_confidence_model_prepare_data(split, create_datamodule_confi @pytest.mark.parametrize("split", [s for s in Split]) -def test_datamodule_score_model_setup_dataset(split, create_datamodule_score_model, create_another_datamodule_score_model): - data_modules= [create_datamodule_score_model[0], create_another_datamodule_score_model[0]] - lists_complex_name = [] - lists_pos_ligand = [] - for m in data_modules: - m.prepare_data() - # run through all the possible stages first to setup all the correps. - # dataset objects - m.setup("fit") - m.setup("test") - lightning.seed_everything(2823828) - names = [] - pos_ligand = [] - for sample in m._dataset[split]: - assert isinstance(sample, HeteroData),\ - "Sample yield from dataset is not PyG HeteroData" - names.append(sample.name) - pos_ligand.append(sample["ligand"].pos) - lists_complex_name.append(names) - lists_pos_ligand.append(pos_ligand) - - assert len(lists_complex_name[0]) > 0,\ - "No names in {split} dataset" - assert lists_complex_name[0] == lists_complex_name[1],\ - f"Inconsistent sample name in {split}-set from data module instances: "\ - f"{lists_complex_name[0]} \n\nvs.\n\n"\ - f"{lists_complex_name[1]}" - - assert len(lists_pos_ligand[0]) > 0, "No ligand position found in dataset" - assert len(lists_pos_ligand[0]) == len(lists_pos_ligand[1]),\ - f"Inconsistent number of ligand position in {split}-set from data "\ - f"module instances: {len(lists_pos_ligand[0])} \n\nvs.\n\n"\ - f"{len(lists_pos_ligand[1])}" - for i in range(len(lists_pos_ligand[0])): - pos_0 = lists_pos_ligand[0][i] - pos_1 = lists_pos_ligand[1][i] - torch.testing.assert_close(pos_0, pos_1, - msg=lambda m : - f"Inconsistent ligand position in the " - f"{i}'th sample/batch of {split}-set " - f"between two data module instances:\n\n{m}") - - -@pytest.mark.parametrize("split", [s for s in Split]) -def test_datamodule_confidence_model_setup_dataset(split, create_datamodule_confidence_model, create_another_datamodule_confidence_model): - data_modules= [create_datamodule_confidence_model[0], create_another_datamodule_confidence_model[0]] +def test_datamodule_setup_dataset(split, create_datamodule, create_another_datamodule): + data_modules= [create_datamodule[0], create_another_datamodule[0]] lists_complex_name = [] lists_pos_ligand = [] for m in data_modules: @@ -174,66 +101,8 @@ def test_datamodule_confidence_model_setup_dataset(split, create_datamodule_conf @pytest.mark.parametrize("split", [s for s in Split]) -def test_datamodule_score_model_setup_dataloader(split, create_datamodule_score_model, create_another_datamodule_score_model): - data_modules= [create_datamodule_score_model[0], create_another_datamodule_score_model[0]] - lists_complex_name = [] - lists_pos_ligand = [] - for m in data_modules: - m.prepare_data() - # run through all the possible stages first to setup all the correps. - # dataset objects - m.setup("fit") - m.setup("test") - lightning.seed_everything(2823828) - names = [] - pos_ligand = [] - loader = None - if split == Split.train: - loader = m.train_dataloader() - elif split == Split.val: - loader = m.val_dataloader() - elif split == Split.test: - loader = m.test_dataloader() - else: - raise RuntimeError(f"Test for split {split} not implemented") - assert loader is not None, "dataloader not instantated" - for samples in loader: - # PyG's HeteroDataBatch is Batch inherited from HeteroData - assert isinstance(samples, Batch),\ - f"Sample object is not PyG Batch" - assert isinstance(samples, HeteroData),\ - f"Sample object is not PyG HeteroData" - names.append(samples.name) - pos_ligand.append(samples["ligand"].pos) - lists_complex_name.append(names) - lists_pos_ligand.append(pos_ligand) - - assert len(lists_complex_name[0]) > 0,\ - "No names in {split} dataloader" - assert lists_complex_name[0] == lists_complex_name[1],\ - f"Inconsistent sample name in {split}-set from data module instances: "\ - f"{lists_complex_name[0]} \n\nvs.\n\n"\ - f"{lists_complex_name[1]}" - - assert len(lists_pos_ligand[0]) > 0,\ - "No ligand position found in dataloader" - assert len(lists_pos_ligand[0]) == len(lists_pos_ligand[1]),\ - f"Inconsistent number of ligand position in {split}-set from data "\ - f"module instances: {len(lists_pos_ligand[0])} \n\nvs.\n\n"\ - f"{len(lists_pos_ligand[1])}" - for i in range(len(lists_pos_ligand[0])): - pos_0 = lists_pos_ligand[0][i] - pos_1 = lists_pos_ligand[1][i] - torch.testing.assert_close(pos_0, pos_1, - msg=lambda m : - f"Inconsistent ligand position in the " - f"{i}'th sample/batch of {split}-set " - f"between two data module instances:\n\n{m}") - - -@pytest.mark.parametrize("split", [s for s in Split]) -def test_datamodule_confidence_model_setup_dataloader(split, create_datamodule_confidence_model, create_another_datamodule_confidence_model): - data_modules= [create_datamodule_confidence_model[0], create_another_datamodule_confidence_model[0]] +def test_datamodule_setup_dataloader(split, create_datamodule, create_another_datamodule): + data_modules= [create_datamodule[0], create_another_datamodule[0]] lists_complex_name = [] lists_pos_ligand = [] for m in data_modules: From c8a128b8167008e15186217f45ddf1a82bd92092 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Thu, 8 Aug 2024 00:37:25 +0000 Subject: [PATCH 28/70] Enhancement: factor out the webdataset-specific method into a base data module class ... so that the pickled file processing PickledDataWDS class inherits from it --- .../data/molecule/diffdock/datamodule.py | 255 ++++++++++++------ .../tests/bionemo/contrib/data/conftest.py | 16 +- .../contrib/data/test_diffdock_datamodule.py | 5 +- 3 files changed, 185 insertions(+), 91 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index 94d5a7d930..50ac681dba 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -16,7 +16,7 @@ from enum import Enum, auto import glob import random -from typing import Any, Dict, Generator, List, Optional +from typing import Any, Dict, Generator, Iterable, List, Optional import lightning as L import webdataset as wds @@ -31,45 +31,33 @@ class Split(Enum): test = auto() -class PickledDataWDS(L.LightningDataModule): +class WDSModule(L.LightningDataModule): - """lightning APIs to process pickled data into webdataset tar files and - setup dataset and dataloader. This data module takes a directory of pickled - data files, data filename prefixes for train/val/test splits, data filename - suffixes and prepare webdataset tar files by globbing the specific pickeld - data files {dir_pickled}/{name_subset[split]}.{suffix_pickled} and outputing - to webdataset tar file with the dict structure: {"__key__" : - name.replace(".", "-"), suffix_pickled : pickled.dumps(data) }. NOTE: this - assumes only one pickled file is processed for each sample. In its setup() - function, it creates the webdataset object chaining up the input - `pipeline_wds` workflow. In its train/val/test_dataloader(), it creates the - WebLoader object chaining up the `pipeline_prebatch_wld` workflow""" + """lightning data module for using webdataset tar files to setup dataset and + dataloader. This data module takes a dictionary: Split -> tar file + directory. In its setup() function, it creates the webdataset object + chaining up the input `pipeline_wds` workflow. In its + train/val/test_dataloader(), it creates the WebLoader object chaining up the + `pipeline_prebatch_wld` workflow""" - def __init__(self, dir_pickled : str, suffix_pickled : str, - prefix_dir_tars_wds : str, names_subset : Dict[Split, - List[str]], - global_batch_size : int, n_workers_dataloader : int, + def __init__(self, dirs_tars_wds : Dict[Split, str], n_samples : Dict[Split, + int], + global_batch_size : int, suffix_keys_wds : Iterable[str], + prefix_tars_wds : str = "wdshards", pipeline_wds : Optional[Dict[Split, Generator[Any, None, None]]] = None, pipeline_prebatch_wld : Optional[Dict[Split, Generator[Any, None, None]]] - = None, pin_memory_dataloader : bool = True, prefix_tars_wds : - str = "wdshards", n_tars_wds : Optional[int] = None, - seed_rng_shfl : int = 0): + = None, n_workers_dataloader : int = 0, pin_memory_dataloader : + bool = True, seed_rng_shfl : int = 0): """constructor Args: - dir_pickled (str): input directory of pickled data files - suffix_pickled (str): filename suffix of the input data in - dir_pickled. This is also used as the key mapped to the - tarballed pickled object in the webdataset - prefix_dir_tars_wds (str): directory name prefix to store the output - webdataset tar files. The actual directories storing the train, val - and test sets will be suffixed with "train", "val" and "test" - respectively. - names_subset (Dict[Split, List[str]]): list of complex names to be - included in each of the split + dirs_tars_wds (Dict[Split, str]): input dictionary: Split -> tar file + directory that contains the webdataset tar files for each split + n_samples (Dict[Split, int]): input dictionary: Split -> number of + data samples for each split global_batch_size (int): size of batch summing across nodes in Data Distributed Parallel, i.e., local_batch_size * n_nodes. NOTE: this data module doesn't rely on the input `global_batch_size` @@ -78,24 +66,31 @@ def __init__(self, dir_pickled : str, suffix_pickled : str, is only used to compute a (pseudo-) epoch length for the data loader so that the loader yield approximately n_samples // global_batch_size batches - n_workers_dataloader (int): number of data loading workers (passed - to pytorch dataloader) - + suffix_keys_wds (Iterable): a set of keys each corresponding to a + data object in the webdataset tar file dictionary. The data + objects of these keys will be extracted and tupled for each + sample in the tar files Kwargs: + prefix_tars_wds (str): name prefix of the input webdataset tar + files. The input tar files are globbed by + "{dirs_tars_wds[split]}/{prefix_tars_wds}-*.tar" pipeline_wds (Optional[Dict[Split, Generator[Any, None, None]]]): a dictionary of webdatast composable, i.e., functor that maps a generator to another generator that transforms the data sample - yield from the dataset object, for different splits + yield from the dataset object, for different splits. For + example, this can be used to transform the sample in the worker + before sending it to the main process of the dataloader pipeline_prebatch_wld (Optional[Dict[Split, Generator[Any, None, None]]]): a dictionary of webloader composable, i.e., functor that maps a generator to another generator that transforms the data sample yield from the WebLoader object, for different - splits. NOTE: this is applied before batching is yield from the - WebLoader pin_memory_dataloader (bool): whether to use pin - memory in pytorch dataloader - prefix_tars_wds (str): name prefix to output webdataset tar files - n_tars_wds (int): attempt to create at least this number of - webdataset shards + splits. For example, this can be used for batching the samples. + NOTE: this is applied before batching is yield from the + WebLoader + n_workers_dataloader (int): number of data loading workers (passed + to pytorch dataloader) + pin_memory_dataloader (bool): whether to use pin memory in pytorch + dataloader seed_rng_shfl (int): seed to the random number generators used in data loading time for shuffling @@ -103,34 +98,28 @@ def __init__(self, dir_pickled : str, suffix_pickled : str, """ super().__init__() - self._dir_pickled = dir_pickled - self._suffix_pickled = suffix_pickled - self._n_tars_wds = n_tars_wds - self._prefix_dir_tars_wds = prefix_dir_tars_wds - self._prefix_tars_wds = prefix_tars_wds + self._dirs_tars_wds = dirs_tars_wds - keys_subset = names_subset.keys() + keys_subset = self._dirs_tars_wds.keys() if not (Split.train in keys_subset and Split.val in keys_subset): - raise RuntimeError("Input names_subset must be defined for the "\ + raise RuntimeError("Input dirs_tars_wds must be defined for the "\ "train and val splits") - self._names_subset = names_subset + if n_samples.keys() != keys_subset: + raise RuntimeError(f"Input n_samples has different keys than " + f"dirs_tars_wds: {n_samples.keys()} vs " + f"{keys_subset}" + ) - self._sizes = { - split : len(self._names_subset[split]) for split in - self._names_subset.keys() - } + self._n_samples= n_samples - self._dirs_tars_wds = { - Split.train : f"{self._prefix_dir_tars_wds}train", - Split.val : f"{self._prefix_dir_tars_wds}val", - Split.test : f"{self._prefix_dir_tars_wds}test", - } + self._global_batch_size = global_batch_size + self._suffix_keys_wds = suffix_keys_wds + self._prefix_tars_wds = prefix_tars_wds self._pipeline_wds = pipeline_wds self._pipeline_prebatch_wld = pipeline_prebatch_wld - self._global_batch_size = global_batch_size self._n_workers_dataloader = n_workers_dataloader self._pin_memory_dataloader = pin_memory_dataloader self._seed_rng_shfl = seed_rng_shfl @@ -142,23 +131,11 @@ def __init__(self, dir_pickled : str, suffix_pickled : str, def prepare_data(self) -> None: """This is called only by the main process by the Lightning workflow. Do not rely on this data module object's state update here as there is no - way to communicate the state update to other subprocesses. The - `pickles_to_tars` function goes through the data name prefixes in the - different splits, read the corresponding pickled file and output a - webdataset tar archive with the dict structure: {"__key__" : - name.replace(".", "-"), suffix_pickled : pickled.dumps(data) }. + way to communicate the state update to other subprocesses. Returns: None """ - for split in self._names_subset.keys(): - # create wds shards (tar files) for train set - pickles_to_tars(self._dir_pickled, - self._suffix_pickled, - self._names_subset[split], - self._dirs_tars_wds[split], - self._prefix_tars_wds, - min_num_shards=self._n_tars_wds) - + pass def _setup_wds(self, split : Split) -> wds.WebDataset: """setup webdataset and webloader. This is called by setup() @@ -169,6 +146,9 @@ def _setup_wds(self, split : Split) -> wds.WebDataset: Returns: WebDataset """ + if not split in self._dirs_tars_wds.keys(): + raise RuntimeError(f"_setup_wds() is called with {split} " + f"split that doesn't have the input tar dir") is_train = split == Split.train urls = sorted(glob.glob( f"{self._dirs_tars_wds[split]}/{self._prefix_tars_wds}-*.tar") @@ -178,7 +158,7 @@ def _setup_wds(self, split : Split) -> wds.WebDataset: nodesplitter=wds.split_by_node, seed=self._seed_rng_shfl) .decode() - .extract_keys(f"*.{self._suffix_pickled}") + .extract_keys(f"*.{self._suffix_keys_wds}") ) if (self._pipeline_wds is not None and self._pipeline_wds[split] is not None): @@ -215,8 +195,11 @@ def _setup_dataloader(self, split : Split) -> wds.WebLoader: Returns: WebLoader object """ + if self._dataset[split] is None: + raise RuntimeError(f"_setup_dataloader() is called with {split} " + f"split without setting up the corresp. dataset") dataset = self._dataset[split] - n_samples = len(self._names_subset[split]) + n_samples = self._n_samples[split] n_batches = ((n_samples + self._global_batch_size - 1) // self._global_batch_size) loader = wds.WebLoader(dataset, @@ -243,22 +226,134 @@ def _setup_dataloader(self, split : Split) -> wds.WebLoader: def train_dataloader(self) -> wds.WebLoader: - assert self._dataset[Split.train] is not None,\ - f"dataset for train has not been setup" return self._setup_dataloader(Split.train) def val_dataloader(self) -> wds.WebLoader: - assert self._dataset[Split.val] is not None,\ - f"dataset for val has not been setup" return self._setup_dataloader(Split.val) def test_dataloader(self) -> wds.WebLoader: - assert self._dataset[Split.test] is not None,\ - f"dataset for test has not been setup" return self._setup_dataloader(Split.test) def predict_dataloader(self) -> wds.WebLoader: raise NotImplementedError("predict dataloader not implemented") + + +class PickledDataWDS(WDSModule): + + """lightning APIs to process pickled data into webdataset tar files and + setup dataset and dataloader. This data module takes a directory of pickled + data files, data filename prefixes for train/val/test splits, data filename + suffixes and prepare webdataset tar files by globbing the specific pickeld + data files {dir_pickled}/{name_subset[split]}.{suffix_pickled} and outputing + to webdataset tar file with the dict structure: {"__key__" : + name.replace(".", "-"), suffix_pickled : pickled.dumps(data) }. NOTE: this + assumes only one pickled file is processed for each sample. In its setup() + function, it creates the webdataset object chaining up the input + `pipeline_wds` workflow. In its train/val/test_dataloader(), it creates the + WebLoader object chaining up the `pipeline_prebatch_wld` workflow""" + + def __init__(self, dir_pickled : str, suffix_pickled : str, names_subset : + Dict[Split, List[str]], prefix_dir_tars_wds : str, + global_batch_size : int, prefix_tars_wds : str = "wdshards", + n_tars_wds : Optional[int] = None, + pipeline_wds : Optional[Dict[Split, Generator[Any, None, + None]]] = None, + pipeline_prebatch_wld : Optional[Dict[Split, Generator[Any, + None, + None]]] + = None, n_workers_dataloader : int = 0, pin_memory_dataloader : + bool = True, seed_rng_shfl : int = 0): + """constructor + + Args: + dir_pickled (str): input directory of pickled data files + suffix_pickled (str): filename suffix of the input data in + dir_pickled. This is also used as the key mapped to the + tarballed pickled object in the webdataset + names_subset (Dict[Split, List[str]]): list of complex names to be + included in each of the split + prefix_dir_tars_wds (str): directory name prefix to store the output + webdataset tar files. The actual directories storing the train, val + and test sets will be suffixed with "train", "val" and "test" + respectively. + global_batch_size (int): size of batch summing across nodes in Data + Distributed Parallel, i.e., local_batch_size * n_nodes. NOTE: + this data module doesn't rely on the input `global_batch_size` + for batching the samples. The batching is supposed to be done as + a part of the input `pipeline_prebatch_wld`. `global_batch_size` + is only used to compute a (pseudo-) epoch length for the data + loader so that the loader yield approximately n_samples // + global_batch_size batches + + Kwargs: + prefix_tars_wds (str): name prefix to output webdataset tar files + n_tars_wds (int): attempt to create at least this number of + webdataset shards + pipeline_wds (Optional[Dict[Split, Generator[Any, None, None]]]): a + dictionary of webdatast composable, i.e., functor that maps a + generator to another generator that transforms the data sample + yield from the dataset object, for different splits + pipeline_prebatch_wld (Optional[Dict[Split, Generator[Any, None, + None]]]): a dictionary of webloader composable, i.e., functor + that maps a generator to another generator that transforms the + data sample yield from the WebLoader object, for different + splits. NOTE: this is applied before batching is yield from the + WebLoader + n_workers_dataloader (int): number of data loading workers (passed + to pytorch dataloader) + pin_memory_dataloader (bool): whether to use pin memory in pytorch + dataloader + seed_rng_shfl (int): seed to the random number generators used in + data loading time for shuffling + + + """ + super().__init__( + { + split : f"{prefix_dir_tars_wds}{str(split).split('.')[-1]}" + for split in names_subset.keys() + }, + { + split : len(names_subset[split]) for split in + names_subset.keys() + }, + global_batch_size, + suffix_pickled, + prefix_tars_wds=prefix_tars_wds, + pipeline_wds=pipeline_wds, + pipeline_prebatch_wld=pipeline_prebatch_wld, + n_workers_dataloader=n_workers_dataloader, + pin_memory_dataloader=pin_memory_dataloader, + seed_rng_shfl=seed_rng_shfl + ) + + self._dir_pickled = dir_pickled + self._suffix_pickled = suffix_pickled + self._prefix_dir_tars_wds = prefix_dir_tars_wds + + self._names_subset = names_subset + + self._n_tars_wds = n_tars_wds + + def prepare_data(self) -> None: + """This is called only by the main process by the Lightning workflow. Do + not rely on this data module object's state update here as there is no + way to communicate the state update to other subprocesses. The + `pickles_to_tars` function goes through the data name prefixes in the + different splits, read the corresponding pickled file and output a + webdataset tar archive with the dict structure: {"__key__" : + name.replace(".", "-"), suffix_pickled : pickled.dumps(data) }. + + Returns: None + """ + for split in self._names_subset.keys(): + # create wds shards (tar files) for train set + pickles_to_tars(self._dir_pickled, + self._suffix_pickled, + self._names_subset[split], + self._dirs_tars_wds[split], + self._prefix_tars_wds, + min_num_shards=self._n_tars_wds) diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py index 11abc1ee33..c933de4849 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py @@ -105,12 +105,12 @@ def _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, n_tars_wds = 4 seed_rng_shfl = 822782392 data_module = PickledDataWDS(dir_heterodata, suffix_heterodata, - prefix_dir_tars_wds, names, global_batch_size, - n_workers_dataloader, - pipeline_wds=generateNoise, - pipeline_prebatch_wld=pipelines_wdl_batch, + names, prefix_dir_tars_wds, global_batch_size, prefix_tars_wds="heterographs", n_tars_wds=n_tars_wds, + pipeline_wds=generateNoise, + pipeline_prebatch_wld=pipelines_wdl_batch, + n_workers_dataloader=n_workers_dataloader, seed_rng_shfl=seed_rng_shfl) return data_module, prefix_dir_tars_wds @@ -147,12 +147,12 @@ def _create_datamodule_confidence_model_impl(tmp_path_factory, dir_heterodata, n_workers_dataloader = 2 n_tars_wds = 4 data_module = PickledDataWDS(dir_heterodata, suffix_heterodata, - prefix_dir_tars_wds, names, global_batch_size, - n_workers_dataloader, - pipeline_wds=pipeline_wds, - pipeline_prebatch_wld=pipelines_wdl_batch, + names, prefix_dir_tars_wds, global_batch_size, prefix_tars_wds="heterographs", n_tars_wds=n_tars_wds, + pipeline_wds=pipeline_wds, + pipeline_prebatch_wld=pipelines_wdl_batch, + n_workers_dataloader=n_workers_dataloader, seed_rng_shfl=seed_rng_shfl) return data_module, prefix_dir_tars_wds diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py index 9e5d6293a3..eb0adb6b8d 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py @@ -14,7 +14,6 @@ # limitations under the License. import glob -from numpy import isin import pytest import torch @@ -29,10 +28,10 @@ def test_datamodule_init(split, get_diffdock_heterodata, create_datamodule): name_split = str(split).split('.')[1] (_, _, names, model) = get_diffdock_heterodata data_module, prefix_dir_tars_wds = create_datamodule - assert data_module._sizes[split] == len(names[split]),\ + assert data_module._n_samples[split] == len(names[split]),\ f"Wrong {split}-set size for {model} model: "\ f"expected {len(names[split])} "\ - f"but got {data_module._sizes[split]}" + f"but got {data_module._n_samples[split]}" assert data_module._dirs_tars_wds[split] ==\ f"{prefix_dir_tars_wds}{name_split}",\ f"Wrong tar files directory for {model} model: "\ From 80149e20048eaba463fb97bc054748258a4cd978 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Thu, 8 Aug 2024 20:15:00 +0000 Subject: [PATCH 29/70] Enhancement: reduce args redundancy in PickledDataWDS by passing *args and **kwargs to the parent class --- .../data/molecule/diffdock/datamodule.py | 79 +++++-------------- .../tests/bionemo/contrib/data/conftest.py | 27 ++++--- 2 files changed, 38 insertions(+), 68 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index 50ac681dba..e66e18a306 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -42,15 +42,15 @@ class WDSModule(L.LightningDataModule): def __init__(self, dirs_tars_wds : Dict[Split, str], n_samples : Dict[Split, int], - global_batch_size : int, suffix_keys_wds : Iterable[str], + suffix_keys_wds : Iterable[str], global_batch_size : int, prefix_tars_wds : str = "wdshards", pipeline_wds : Optional[Dict[Split, Generator[Any, None, None]]] = None, pipeline_prebatch_wld : Optional[Dict[Split, Generator[Any, None, None]]] - = None, n_workers_dataloader : int = 0, pin_memory_dataloader : - bool = True, seed_rng_shfl : int = 0): + = None, seed_rng_shfl : int = 0, + kwargs_dl : Optional[Dict[Split, Dict[str, str]]] = None): """constructor Args: @@ -58,6 +58,10 @@ def __init__(self, dirs_tars_wds : Dict[Split, str], n_samples : Dict[Split, directory that contains the webdataset tar files for each split n_samples (Dict[Split, int]): input dictionary: Split -> number of data samples for each split + suffix_keys_wds (Iterable): a set of keys each corresponding to a + data object in the webdataset tar file dictionary. The data + objects of these keys will be extracted and tupled for each + sample in the tar files global_batch_size (int): size of batch summing across nodes in Data Distributed Parallel, i.e., local_batch_size * n_nodes. NOTE: this data module doesn't rely on the input `global_batch_size` @@ -66,10 +70,6 @@ def __init__(self, dirs_tars_wds : Dict[Split, str], n_samples : Dict[Split, is only used to compute a (pseudo-) epoch length for the data loader so that the loader yield approximately n_samples // global_batch_size batches - suffix_keys_wds (Iterable): a set of keys each corresponding to a - data object in the webdataset tar file dictionary. The data - objects of these keys will be extracted and tupled for each - sample in the tar files Kwargs: prefix_tars_wds (str): name prefix of the input webdataset tar files. The input tar files are globbed by @@ -87,12 +87,10 @@ def __init__(self, dirs_tars_wds : Dict[Split, str], n_samples : Dict[Split, splits. For example, this can be used for batching the samples. NOTE: this is applied before batching is yield from the WebLoader - n_workers_dataloader (int): number of data loading workers (passed - to pytorch dataloader) - pin_memory_dataloader (bool): whether to use pin memory in pytorch - dataloader seed_rng_shfl (int): seed to the random number generators used in data loading time for shuffling + kwargs_dl (Optional[Dict[Split, Dict[str, str]]]): kwargs for data + loader, e.g., num_workers, of each split """ @@ -120,10 +118,10 @@ def __init__(self, dirs_tars_wds : Dict[Split, str], n_samples : Dict[Split, self._pipeline_wds = pipeline_wds self._pipeline_prebatch_wld = pipeline_prebatch_wld - self._n_workers_dataloader = n_workers_dataloader - self._pin_memory_dataloader = pin_memory_dataloader self._seed_rng_shfl = seed_rng_shfl + self._kwargs_dl = kwargs_dl + # to be created later in setup self._dataset = dict() @@ -202,10 +200,9 @@ def _setup_dataloader(self, split : Split) -> wds.WebLoader: n_samples = self._n_samples[split] n_batches = ((n_samples + self._global_batch_size - 1) // self._global_batch_size) - loader = wds.WebLoader(dataset, - num_workers=self._n_workers_dataloader, - pin_memory=self._pin_memory_dataloader, - batch_size=None + kwargs = self._kwargs_dl[split] if self._kwargs_dl is not None else None + loader = wds.WebLoader(dataset, batch_size=None, + **(kwargs if kwargs is not None else {}) ).shuffle(5000, rng=random.Random(self._seed_rng_shfl)) if (self._pipeline_prebatch_wld is not None and @@ -256,16 +253,8 @@ class PickledDataWDS(WDSModule): WebLoader object chaining up the `pipeline_prebatch_wld` workflow""" def __init__(self, dir_pickled : str, suffix_pickled : str, names_subset : - Dict[Split, List[str]], prefix_dir_tars_wds : str, - global_batch_size : int, prefix_tars_wds : str = "wdshards", - n_tars_wds : Optional[int] = None, - pipeline_wds : Optional[Dict[Split, Generator[Any, None, - None]]] = None, - pipeline_prebatch_wld : Optional[Dict[Split, Generator[Any, - None, - None]]] - = None, n_workers_dataloader : int = 0, pin_memory_dataloader : - bool = True, seed_rng_shfl : int = 0): + Dict[Split, List[str]], prefix_dir_tars_wds : str, *args, + n_tars_wds : Optional[int] = None, **kwargs): """constructor Args: @@ -279,35 +268,12 @@ def __init__(self, dir_pickled : str, suffix_pickled : str, names_subset : webdataset tar files. The actual directories storing the train, val and test sets will be suffixed with "train", "val" and "test" respectively. - global_batch_size (int): size of batch summing across nodes in Data - Distributed Parallel, i.e., local_batch_size * n_nodes. NOTE: - this data module doesn't rely on the input `global_batch_size` - for batching the samples. The batching is supposed to be done as - a part of the input `pipeline_prebatch_wld`. `global_batch_size` - is only used to compute a (pseudo-) epoch length for the data - loader so that the loader yield approximately n_samples // - global_batch_size batches + *args: arguments passed to the parent WDSModule Kwargs: - prefix_tars_wds (str): name prefix to output webdataset tar files n_tars_wds (int): attempt to create at least this number of webdataset shards - pipeline_wds (Optional[Dict[Split, Generator[Any, None, None]]]): a - dictionary of webdatast composable, i.e., functor that maps a - generator to another generator that transforms the data sample - yield from the dataset object, for different splits - pipeline_prebatch_wld (Optional[Dict[Split, Generator[Any, None, - None]]]): a dictionary of webloader composable, i.e., functor - that maps a generator to another generator that transforms the - data sample yield from the WebLoader object, for different - splits. NOTE: this is applied before batching is yield from the - WebLoader - n_workers_dataloader (int): number of data loading workers (passed - to pytorch dataloader) - pin_memory_dataloader (bool): whether to use pin memory in pytorch - dataloader - seed_rng_shfl (int): seed to the random number generators used in - data loading time for shuffling + **kwargs: arguments passed to the parent WDSModule """ @@ -320,14 +286,9 @@ def __init__(self, dir_pickled : str, suffix_pickled : str, names_subset : split : len(names_subset[split]) for split in names_subset.keys() }, - global_batch_size, suffix_pickled, - prefix_tars_wds=prefix_tars_wds, - pipeline_wds=pipeline_wds, - pipeline_prebatch_wld=pipeline_prebatch_wld, - n_workers_dataloader=n_workers_dataloader, - pin_memory_dataloader=pin_memory_dataloader, - seed_rng_shfl=seed_rng_shfl + *args, + **kwargs ) self._dir_pickled = dir_pickled diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py index c933de4849..aaa7f8a374 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py @@ -18,7 +18,6 @@ import pytest from functools import partial import torch -from torch_geometric.loader.data_list_loader import collate_fn from torch_geometric.loader.dataloader import Collater from webdataset.filters import batched @@ -101,17 +100,22 @@ def _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, Split.val : batch_pyg, Split.test : batch_pyg, } - n_workers_dataloader = 2 n_tars_wds = 4 seed_rng_shfl = 822782392 + kwargs_dl = { + Split.train : {'num_workers' : 2}, + Split.val : {'num_workers' : 2}, + Split.test : {'num_workers' : 2}, + } data_module = PickledDataWDS(dir_heterodata, suffix_heterodata, names, prefix_dir_tars_wds, global_batch_size, - prefix_tars_wds="heterographs", n_tars_wds=n_tars_wds, + prefix_tars_wds="heterographs", pipeline_wds=generateNoise, pipeline_prebatch_wld=pipelines_wdl_batch, - n_workers_dataloader=n_workers_dataloader, - seed_rng_shfl=seed_rng_shfl) + seed_rng_shfl=seed_rng_shfl, + kwargs_dl=kwargs_dl + ) return data_module, prefix_dir_tars_wds @@ -144,16 +148,21 @@ def _create_datamodule_confidence_model_impl(tmp_path_factory, dir_heterodata, Split.val : batch_pyg, Split.test : batch_pyg, } - n_workers_dataloader = 2 n_tars_wds = 4 + kwargs_dl = { + Split.train : {'num_workers' : 2}, + Split.val : {'num_workers' : 2}, + Split.test : {'num_workers' : 2}, + } data_module = PickledDataWDS(dir_heterodata, suffix_heterodata, names, prefix_dir_tars_wds, global_batch_size, - prefix_tars_wds="heterographs", n_tars_wds=n_tars_wds, + prefix_tars_wds="heterographs", pipeline_wds=pipeline_wds, pipeline_prebatch_wld=pipelines_wdl_batch, - n_workers_dataloader=n_workers_dataloader, - seed_rng_shfl=seed_rng_shfl) + seed_rng_shfl=seed_rng_shfl, + kwargs_dl=kwargs_dl + ) return data_module, prefix_dir_tars_wds From cc911b025866b643627ad8de36cadb94ca737717 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Thu, 8 Aug 2024 22:30:42 +0000 Subject: [PATCH 30/70] Enhancement: add integration tests in lightning trainer workflows --- .../data/molecule/diffdock/datamodule.py | 6 ++- .../tests/bionemo/contrib/data/conftest.py | 37 +++++++++++++++++++ .../contrib/data/test_diffdock_datamodule.py | 32 ++++++++++++++++ 3 files changed, 73 insertions(+), 2 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index e66e18a306..e2595c529d 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -178,11 +178,13 @@ def setup(self, stage: str) -> None: if stage == "fit": self._dataset[Split.train] = self._setup_wds(Split.train) self._dataset[Split.val] = self._setup_wds(Split.val) + elif stage == "validate": + self._dataset[Split.val] = self._setup_wds(Split.val) elif stage == "test": self._dataset[Split.test] = self._setup_wds(Split.test) else: - raise NotImplementedError("Data setup with stage = {stage}\ - is not implmented") + raise NotImplementedError(f"Data setup with stage = {stage} "\ + f"is not implmented") def _setup_dataloader(self, split : Split) -> wds.WebLoader: """setup the dataloader for the input dataset split diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py index aaa7f8a374..6508780581 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py @@ -15,11 +15,13 @@ from enum import Enum, auto import os +from typing import Any import pytest from functools import partial import torch from torch_geometric.loader.dataloader import Collater from webdataset.filters import batched +import lightning as L from bionemo.contrib.model.molecule.diffdock.utils.diffusion import ( t_to_sigma, GenerateNoise) @@ -194,3 +196,38 @@ def create_another_datamodule(tmp_path_factory, get_diffdock_heterodata): dir_heterodata, suffix_heterodata, names) + + +class ModelTestDiffDock(L.LightningModule): + def __init__(self) -> None: + super().__init__() + self._model = torch.nn.Linear(3, 3) + self._samples = { split : [] for split in Split } + + + def forward(self, x): + return self._model(x["ligand"].pos) + + def training_step(self, batch): + self._samples[Split.train].append(batch.name) + loss = self(batch).sum() + return loss + + def validation_step(self, batch, batch_index): + self._samples[Split.val].append(batch.name) + return torch.zeros(1) + + def test_step(self, batch, batch_index): + self._samples[Split.test].append(batch.name) + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=2e-4) + return optimizer + + +@pytest.fixture(scope="function") +def create_trainer_and_model(): + trainer = L.Trainer(max_epochs=1, accelerator="gpu", + devices=1, val_check_interval=1) + model = ModelTestDiffDock() + return trainer, model diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py index eb0adb6b8d..c26e04981f 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py @@ -155,3 +155,35 @@ def test_datamodule_setup_dataloader(split, create_datamodule, create_another_da f"Inconsistent ligand position in the " f"{i}'th sample/batch of {split}-set " f"between two data module instances:\n\n{m}") + + +@pytest.mark.parametrize("split", [s for s in Split]) +def test_datamodule_in_lightning(split, create_datamodule, + create_another_datamodule, + create_trainer_and_model): + data_modules= [create_datamodule[0], create_another_datamodule[0]] + trainer, model = create_trainer_and_model + # get the list of samples from the loader + lightning.seed_everything(2823828) + data_modules[0].prepare_data() + stage = None + if split == Split.train: + stage = "fit" + elif split == Split.val: + stage = "validate" + elif split == Split.test: + stage = "test" + else: + raise RuntimeError(f"{split} split not implemented") + data_modules[0].setup(stage) + # get the list of samples from the workflow + get_dataloader = getattr(data_modules[0], + f"{str(split).split('.')[-1]}_dataloader") + loader = get_dataloader() + samples = [] + for sample in loader: + samples.append(sample.name) + lightning.seed_everything(2823828) + workflow = getattr(trainer, stage) + workflow(model, data_modules[1]) + assert model._samples[split] == samples From 6541bc25ac47e0a29e6378d6af5daebc9e843c96 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Thu, 8 Aug 2024 23:40:32 +0000 Subject: [PATCH 31/70] Enhancement: implement predict stage --- .../data/molecule/diffdock/datamodule.py | 7 ++-- .../tests/bionemo/contrib/data/conftest.py | 5 ++- .../contrib/data/test_diffdock_datamodule.py | 33 ++++++++++++------- 3 files changed, 28 insertions(+), 17 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index e2595c529d..3504a6868f 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -99,9 +99,6 @@ def __init__(self, dirs_tars_wds : Dict[Split, str], n_samples : Dict[Split, self._dirs_tars_wds = dirs_tars_wds keys_subset = self._dirs_tars_wds.keys() - if not (Split.train in keys_subset and Split.val in keys_subset): - raise RuntimeError("Input dirs_tars_wds must be defined for the "\ - "train and val splits") if n_samples.keys() != keys_subset: raise RuntimeError(f"Input n_samples has different keys than " @@ -182,6 +179,8 @@ def setup(self, stage: str) -> None: self._dataset[Split.val] = self._setup_wds(Split.val) elif stage == "test": self._dataset[Split.test] = self._setup_wds(Split.test) + elif stage == "predict": + self._dataset[Split.test] = self._setup_wds(Split.test) else: raise NotImplementedError(f"Data setup with stage = {stage} "\ f"is not implmented") @@ -237,7 +236,7 @@ def test_dataloader(self) -> wds.WebLoader: def predict_dataloader(self) -> wds.WebLoader: - raise NotImplementedError("predict dataloader not implemented") + return self._setup_dataloader(Split.test) class PickledDataWDS(WDSModule): diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py index 6508780581..49a37c92c2 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py @@ -15,7 +15,6 @@ from enum import Enum, auto import os -from typing import Any import pytest from functools import partial import torch @@ -220,6 +219,10 @@ def validation_step(self, batch, batch_index): def test_step(self, batch, batch_index): self._samples[Split.test].append(batch.name) + def predict_step(self, batch, batch_index): + self._samples[Split.test].append(batch.name) + return torch.zeros(1) + def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=2e-4) return optimizer diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py index c26e04981f..b7d4cb56d2 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum, auto import glob import pytest import torch @@ -157,8 +158,15 @@ def test_datamodule_setup_dataloader(split, create_datamodule, create_another_da f"between two data module instances:\n\n{m}") -@pytest.mark.parametrize("split", [s for s in Split]) -def test_datamodule_in_lightning(split, create_datamodule, +class Stage(Enum): + fit = auto() + validate = auto() + test = auto() + predict = auto() + + +@pytest.mark.parametrize("stage", [s for s in Stage]) +def test_datamodule_in_lightning(stage, create_datamodule, create_another_datamodule, create_trainer_and_model): data_modules= [create_datamodule[0], create_another_datamodule[0]] @@ -166,16 +174,17 @@ def test_datamodule_in_lightning(split, create_datamodule, # get the list of samples from the loader lightning.seed_everything(2823828) data_modules[0].prepare_data() - stage = None - if split == Split.train: - stage = "fit" - elif split == Split.val: - stage = "validate" - elif split == Split.test: - stage = "test" + split = None + if stage == Stage.fit: + split = Split.train + elif stage == Stage.validate: + split = Split.val + elif stage == Stage.test or stage == Stage.predict: + split = Split.test else: - raise RuntimeError(f"{split} split not implemented") - data_modules[0].setup(stage) + raise RuntimeError(f"{stage} stage not implemented") + name_stage = str(stage).split(".")[-1] + data_modules[0].setup(name_stage) # get the list of samples from the workflow get_dataloader = getattr(data_modules[0], f"{str(split).split('.')[-1]}_dataloader") @@ -184,6 +193,6 @@ def test_datamodule_in_lightning(split, create_datamodule, for sample in loader: samples.append(sample.name) lightning.seed_everything(2823828) - workflow = getattr(trainer, stage) + workflow = getattr(trainer, name_stage) workflow(model, data_modules[1]) assert model._samples[split] == samples From bc009795dd94e1bc4c0d19d418673f3bb646ca3c Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Sat, 10 Aug 2024 00:42:31 +0000 Subject: [PATCH 32/70] BugFix: remove select for sample size > 1 which was an unnecessary implicit assumption --- .../bionemo/contrib/data/molecule/diffdock/datamodule.py | 3 --- .../src/bionemo/contrib/data/molecule/diffdock/utils.py | 8 ++++++++ .../tests/bionemo/contrib/data/conftest.py | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index 3504a6868f..96ab30d81b 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -211,9 +211,6 @@ def _setup_dataloader(self, split : Split) -> wds.WebLoader: loader = loader.compose( self._pipeline_prebatch_wld[split]) - if split == Split.train: - loader = loader.select(lambda x: len(x) > 1) - loader = loader.with_epoch(n_batches) # strange features required by nemo optimizer lr_scheduler diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py index 9ca329d596..3fc21b769d 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py @@ -181,11 +181,13 @@ def __init__( max_total_size: int, size_fn: Callable[[HeteroData], int], collate_fn: Callable[[List[Any]], Any] = Collater(dataset=None, follow_batch=None, exclude_keys=None), + no_single_sample : bool = True ): self.max_total_size = max_total_size self.size_fn = size_fn self.collate_fn = collate_fn self.cached_sizes = {} + self.no_single_sample = no_single_sample def __call__(self, data: Batch) -> Generator[Batch, None, None]: batch_size = 0 @@ -202,6 +204,12 @@ def __call__(self, data: Batch) -> Generator[Batch, None, None]: batch.append(sample) batch_size += sample_size else: + if self.no_single_sample and len(batch) <= 1: + # memory size requirement is met but there is less than 2 + # samples in the batch so skip + batch = [sample] + batch_size = sample_size + continue if self.collate_fn is not None: batch = self.collate_fn(batch) yield batch diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py index 49a37c92c2..5f50804ade 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py @@ -97,7 +97,7 @@ def _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, pipelines_wdl_batch = { Split.train : SizeAwareBatching( max_total_size=size_cuda_mem, - size_fn=estimate_size), + size_fn=estimate_size, no_single_sample=True), Split.val : batch_pyg, Split.test : batch_pyg, } From 975c91e1754a4b973571a73500e96d92d83ccc75 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Sat, 10 Aug 2024 00:46:47 +0000 Subject: [PATCH 33/70] Enhancement: remove legacy dataloader config ... of setting dataset object and batch_size etc --- .../src/bionemo/contrib/data/molecule/diffdock/datamodule.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index 96ab30d81b..b863dd3b39 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -213,10 +213,6 @@ def _setup_dataloader(self, split : Split) -> wds.WebLoader: loader = loader.with_epoch(n_batches) - # strange features required by nemo optimizer lr_scheduler - loader.dataset = dataset # seems like only length is used, webloader doesn't have this attr - loader.batch_size = self._global_batch_size - loader.drop_last = False return loader From a3b07e27e509f9cc131723ee08ddbbb6f03dd87a Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Sat, 10 Aug 2024 01:35:33 +0000 Subject: [PATCH 34/70] Enhancement: allow input an iterable of generators passed as pipelines webdataset's compose() already support taking a *args of pipelines --- .../data/molecule/diffdock/datamodule.py | 56 +++++++++++-------- .../tests/bionemo/contrib/data/conftest.py | 11 +++- 2 files changed, 42 insertions(+), 25 deletions(-) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py index b863dd3b39..15a922ba47 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py @@ -16,7 +16,7 @@ from enum import Enum, auto import glob import random -from typing import Any, Dict, Generator, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Union import lightning as L import webdataset as wds @@ -44,11 +44,12 @@ def __init__(self, dirs_tars_wds : Dict[Split, str], n_samples : Dict[Split, int], suffix_keys_wds : Iterable[str], global_batch_size : int, prefix_tars_wds : str = "wdshards", - pipeline_wds : Optional[Dict[Split, Generator[Any, None, - None]]] = None, - pipeline_prebatch_wld : Optional[Dict[Split, Generator[Any, - None, - None]]] + pipeline_wds : Optional[Dict[Split, + Union[Iterable[Iterable[Any]], + Iterable[Any]]]] = + None, pipeline_prebatch_wld : Optional[Dict[Split, + Union[Iterable[Iterable[Any]], + Iterable[Any]]]] = None, seed_rng_shfl : int = 0, kwargs_dl : Optional[Dict[Split, Dict[str, str]]] = None): """constructor @@ -74,19 +75,22 @@ def __init__(self, dirs_tars_wds : Dict[Split, str], n_samples : Dict[Split, prefix_tars_wds (str): name prefix of the input webdataset tar files. The input tar files are globbed by "{dirs_tars_wds[split]}/{prefix_tars_wds}-*.tar" - pipeline_wds (Optional[Dict[Split, Generator[Any, None, None]]]): a - dictionary of webdatast composable, i.e., functor that maps a - generator to another generator that transforms the data sample - yield from the dataset object, for different splits. For - example, this can be used to transform the sample in the worker - before sending it to the main process of the dataloader - pipeline_prebatch_wld (Optional[Dict[Split, Generator[Any, None, - None]]]): a dictionary of webloader composable, i.e., functor - that maps a generator to another generator that transforms the - data sample yield from the WebLoader object, for different - splits. For example, this can be used for batching the samples. - NOTE: this is applied before batching is yield from the - WebLoader + pipeline_wds (Optional[Dict[Split, Union[Iterable[Iterable[Any]], + Iterable[Any]]]]): a dictionary of webdatast composable, i.e., + functor that maps a iterator to another iterator that + transforms the data sample yield from the dataset object, for + different splits, or an iterable to such a sequence of such + iterators. For example, this can be used to transform the + sample in the worker before sending it to the main process of + the dataloader + pipeline_prebatch_wld (Optional[Dict[Split, + Union[Iterable[Iterable[Any]], Iterable[Any]]]]): a dictionary + of webloader composable, i.e., functor that maps a iterator to + another iterator that transforms the data sample yield from the + WebLoader object, for different splits, or an iterable to a + seuqnence of such iterators. For example, this can be used for + batching the samples. NOTE: this is applied before batching is + yield from the WebLoader seed_rng_shfl (int): seed to the random number generators used in data loading time for shuffling kwargs_dl (Optional[Dict[Split, Dict[str, str]]]): kwargs for data @@ -157,7 +161,11 @@ def _setup_wds(self, split : Split) -> wds.WebDataset: ) if (self._pipeline_wds is not None and self._pipeline_wds[split] is not None): - dataset = dataset.compose(self._pipeline_wds[split]) + if isinstance(self._pipeline_wds[split], + Iterable): + dataset = dataset.compose(*self._pipeline_wds[split]) + else: + dataset = dataset.compose(self._pipeline_wds[split]) if is_train: dataset = dataset.shuffle(size=16, rng=random.Random(self._seed_rng_shfl)) @@ -208,8 +216,12 @@ def _setup_dataloader(self, split : Split) -> wds.WebLoader: if (self._pipeline_prebatch_wld is not None and self._pipeline_prebatch_wld[split] is not None): - loader = loader.compose( - self._pipeline_prebatch_wld[split]) + if isinstance(self._pipeline_prebatch_wld[split], Iterable): + loader = loader.compose( + *self._pipeline_prebatch_wld[split]) + else: + loader = loader.compose( + self._pipeline_prebatch_wld[split]) loader = loader.with_epoch(n_batches) diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py index 5f50804ade..1690e81c49 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py +++ b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py @@ -15,6 +15,7 @@ from enum import Enum, auto import os +from typing import Any, Iterable import pytest from functools import partial import torch @@ -64,6 +65,10 @@ def get_diffdock_heterodata(get_path, request): return (dir_heterodata, suffix_heterodata, names, model) +def no_op_gen(it : Iterable[Any]): + yield from it + + def _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, suffix_heterodata, names): prefix_dir_tars_wds = tmp_path_factory.mktemp( @@ -78,8 +83,8 @@ def _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, tor_sigma_min, tor_sigma_max) # webdataset pipeline generateNoise = { - Split.train : GenerateNoise(sigma_t, no_torsion, is_all_atom, - copy_ref_pos=False), + Split.train : [GenerateNoise(sigma_t, no_torsion, is_all_atom, + copy_ref_pos=False), no_op_gen], Split.val: GenerateNoise(sigma_t, no_torsion, is_all_atom, copy_ref_pos=True), Split.test: GenerateNoise(sigma_t, no_torsion, is_all_atom, @@ -98,7 +103,7 @@ def _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, Split.train : SizeAwareBatching( max_total_size=size_cuda_mem, size_fn=estimate_size, no_single_sample=True), - Split.val : batch_pyg, + Split.val : [batch_pyg, no_op_gen], Split.test : batch_pyg, } n_tars_wds = 4 From 2c8a2c8db43dceefcd09032bc5261ff30b38b242 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 13 Aug 2024 18:22:07 +0000 Subject: [PATCH 35/70] Enhancement: move the webdataset data module into bionemo-core ... and move the diffdock routines into bionemo-diffdock --- .../data/molecule/diffdock/__init__.py | 14 - .../tests/bionemo/contrib/data/conftest.py | 241 ------------------ .../src/bionemo/core/data}/datamodule.py | 144 +++++------ .../src/bionemo/core/data/utils.py | 108 ++++++++ sub-packages/bionemo-diffdock/LICENSE | 202 +++++++++++++++ sub-packages/bionemo-diffdock/README.md | 6 + .../bionemo-diffdock/_requirements-test.txt | 1 + .../bionemo-diffdock/_requirements.txt | 1 + sub-packages/bionemo-diffdock/pyproject.toml | 26 ++ .../bionemo-diffdock/requirements.txt | 2 + .../src/bionemo/diffdock/utils/data.py} | 124 +-------- .../src/bionemo}/diffdock/utils/diffusion.py | 42 ++- .../src/bionemo}/diffdock/utils/geometry.py | 16 ++ .../src/bionemo}/diffdock/utils/so3.py | 16 ++ .../src/bionemo}/diffdock/utils/torsion.py | 16 ++ .../src/bionemo}/diffdock/utils/torus.py | 16 ++ .../tests/bionemo/diffdock/data/conftest.py | 232 +++++++++++++++++ .../data/test_diffdock_datamodule.py | 129 +++++----- 18 files changed, 811 insertions(+), 525 deletions(-) delete mode 100644 sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/__init__.py delete mode 100644 sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py rename sub-packages/{bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock => bionemo-core/src/bionemo/core/data}/datamodule.py (73%) create mode 100644 sub-packages/bionemo-core/src/bionemo/core/data/utils.py create mode 100644 sub-packages/bionemo-diffdock/LICENSE create mode 100644 sub-packages/bionemo-diffdock/README.md create mode 100644 sub-packages/bionemo-diffdock/_requirements-test.txt create mode 100644 sub-packages/bionemo-diffdock/_requirements.txt create mode 100644 sub-packages/bionemo-diffdock/pyproject.toml create mode 100644 sub-packages/bionemo-diffdock/requirements.txt rename sub-packages/{bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py => bionemo-diffdock/src/bionemo/diffdock/utils/data.py} (65%) rename sub-packages/{bionemo-contrib/src/bionemo/contrib/model/molecule => bionemo-diffdock/src/bionemo}/diffdock/utils/diffusion.py (85%) rename sub-packages/{bionemo-contrib/src/bionemo/contrib/model/molecule => bionemo-diffdock/src/bionemo}/diffdock/utils/geometry.py (86%) rename sub-packages/{bionemo-contrib/src/bionemo/contrib/model/molecule => bionemo-diffdock/src/bionemo}/diffdock/utils/so3.py (90%) rename sub-packages/{bionemo-contrib/src/bionemo/contrib/model/molecule => bionemo-diffdock/src/bionemo}/diffdock/utils/torsion.py (85%) rename sub-packages/{bionemo-contrib/src/bionemo/contrib/model/molecule => bionemo-diffdock/src/bionemo}/diffdock/utils/torus.py (84%) create mode 100644 sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py rename sub-packages/{bionemo-contrib/tests/bionemo/contrib => bionemo-diffdock/tests/bionemo/diffdock}/data/test_diffdock_datamodule.py (62%) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/__init__.py b/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/__init__.py deleted file mode 100644 index 25e6abfbc5..0000000000 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py b/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py deleted file mode 100644 index 1690e81c49..0000000000 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/conftest.py +++ /dev/null @@ -1,241 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from enum import Enum, auto -import os -from typing import Any, Iterable -import pytest -from functools import partial -import torch -from torch_geometric.loader.dataloader import Collater -from webdataset.filters import batched -import lightning as L - -from bionemo.contrib.model.molecule.diffdock.utils.diffusion import ( - t_to_sigma, GenerateNoise) -from bionemo.contrib.data.molecule.diffdock.utils import ( - SizeAwareBatching, estimate_size, SelectPoseAndLabelData - ) -from bionemo.contrib.data.molecule.diffdock.datamodule import ( - Split, PickledDataWDS - ) - - -@pytest.fixture(scope="module") -def get_path(request): - dir_test = os.path.dirname(request.module.__file__) - dir_data = f"{dir_test}/test_data" - return dir_test, dir_data - - -class DiffDockModel(Enum): - score = auto() - confidence = auto() - - -@pytest.fixture(scope="module", params=[m for m in DiffDockModel]) -def get_diffdock_heterodata(get_path, request): - _, dir_data = get_path - model = request.param - name_model = str(model).split(".")[-1] - dir_heterodata =\ - f"{dir_data}/molecule/diffdock/pyg_heterodata_pickled/{name_model}_model" - suffix_heterodata = "heterodata.pyd" - names = { - Split.train : ["6t88", "6vs3", "6wtn", "6yqv", "7amc", "7bmi", "7cuo", - "7d5c", "7din", "7fha", "7jnb", "7k0v", "7kb1", "7km8", - "7l7c", "7lcu", "7msr", "7my1", "7n6f", "7np6"], - Split.val : ["7nr6", "7oeo", "7oli", "7oso", "7p5t", "7q5i", "7qhl", - "7rh3", "7rzl", "7sgv"], - Split.test : ["7sne", "7t2i", "7tbu", "7tsf", "7umv", "7up3", "7uq3", - "7wpw", "7xek", "7xij"] - } - return (dir_heterodata, suffix_heterodata, names, model) - - -def no_op_gen(it : Iterable[Any]): - yield from it - - -def _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, - suffix_heterodata, names): - prefix_dir_tars_wds = tmp_path_factory.mktemp( - f"diffdock_score_model_tars_wds").as_posix() - tr_sigma_min, tr_sigma_max = (0.1, 19) - rot_sigma_min, rot_sigma_max = (0.03, 1.55) - tor_sigma_min, tor_sigma_max = (0.0314, 3.14) - is_all_atom = False - no_torsion = False - sigma_t = partial(t_to_sigma, tr_sigma_min, - tr_sigma_max, rot_sigma_min, rot_sigma_max, - tor_sigma_min, tor_sigma_max) - # webdataset pipeline - generateNoise = { - Split.train : [GenerateNoise(sigma_t, no_torsion, is_all_atom, - copy_ref_pos=False), no_op_gen], - Split.val: GenerateNoise(sigma_t, no_torsion, is_all_atom, - copy_ref_pos=True), - Split.test: GenerateNoise(sigma_t, no_torsion, is_all_atom, - copy_ref_pos=False), - } - local_batch_size = 2 - global_batch_size = 2 - size_cuda_mem = (0.85 * - torch.cuda.get_device_properties("cuda:0").total_memory / - 2**20) - batch_pyg = batched(local_batch_size, - collation_fn=Collater(dataset=[], follow_batch=None, - exclude_keys=None)) - # WebLoader pipeline - pipelines_wdl_batch = { - Split.train : SizeAwareBatching( - max_total_size=size_cuda_mem, - size_fn=estimate_size, no_single_sample=True), - Split.val : [batch_pyg, no_op_gen], - Split.test : batch_pyg, - } - n_tars_wds = 4 - seed_rng_shfl = 822782392 - kwargs_dl = { - Split.train : {'num_workers' : 2}, - Split.val : {'num_workers' : 2}, - Split.test : {'num_workers' : 2}, - } - data_module = PickledDataWDS(dir_heterodata, suffix_heterodata, - names, prefix_dir_tars_wds, global_batch_size, - n_tars_wds=n_tars_wds, - prefix_tars_wds="heterographs", - pipeline_wds=generateNoise, - pipeline_prebatch_wld=pipelines_wdl_batch, - seed_rng_shfl=seed_rng_shfl, - kwargs_dl=kwargs_dl - ) - return data_module, prefix_dir_tars_wds - - -def _create_datamodule_confidence_model_impl(tmp_path_factory, dir_heterodata, - suffix_heterodata, names): - prefix_dir_tars_wds = tmp_path_factory.mktemp( - "diffdock_confidence_model_tars_wds").as_posix() - # webdataset pipeline - rmsd_classification_cutoff = 2.0 - samples_per_complex = 7 - balance = False - is_all_atom = True - seed_rng_shfl = 822782392 - select_pose = SelectPoseAndLabelData(rmsd_classification_cutoff, - samples_per_complex, balance, - is_all_atom, seed=seed_rng_shfl) - pipeline_wds = { - Split.train : select_pose, - Split.val : select_pose, - Split.test : select_pose, - } - local_batch_size = 2 - global_batch_size = 2 - batch_pyg = batched(local_batch_size, - collation_fn=Collater(dataset=[], follow_batch=None, - exclude_keys=None)) - # WebLoader pipeline - pipelines_wdl_batch = { - Split.train : batch_pyg, - Split.val : batch_pyg, - Split.test : batch_pyg, - } - n_tars_wds = 4 - kwargs_dl = { - Split.train : {'num_workers' : 2}, - Split.val : {'num_workers' : 2}, - Split.test : {'num_workers' : 2}, - } - data_module = PickledDataWDS(dir_heterodata, suffix_heterodata, - names, prefix_dir_tars_wds, global_batch_size, - n_tars_wds=n_tars_wds, - prefix_tars_wds="heterographs", - pipeline_wds=pipeline_wds, - pipeline_prebatch_wld=pipelines_wdl_batch, - seed_rng_shfl=seed_rng_shfl, - kwargs_dl=kwargs_dl - ) - return data_module, prefix_dir_tars_wds - - -@pytest.fixture(scope="module") -def create_datamodule(tmp_path_factory, get_diffdock_heterodata): - dir_heterodata, suffix_heterodata, names, model = get_diffdock_heterodata - if model == DiffDockModel.score: - return _create_datamodule_score_model_impl(tmp_path_factory, - dir_heterodata, - suffix_heterodata, - names) - elif model == DiffDockModel.confidence: - return _create_datamodule_confidence_model_impl(tmp_path_factory, - dir_heterodata, - suffix_heterodata, - names) - - -@pytest.fixture(scope="module") -def create_another_datamodule(tmp_path_factory, get_diffdock_heterodata): - dir_heterodata, suffix_heterodata, names, model = get_diffdock_heterodata - if model == DiffDockModel.score: - return _create_datamodule_score_model_impl(tmp_path_factory, - dir_heterodata, - suffix_heterodata, - names) - elif model == DiffDockModel.confidence: - return _create_datamodule_confidence_model_impl(tmp_path_factory, - dir_heterodata, - suffix_heterodata, - names) - - -class ModelTestDiffDock(L.LightningModule): - def __init__(self) -> None: - super().__init__() - self._model = torch.nn.Linear(3, 3) - self._samples = { split : [] for split in Split } - - - def forward(self, x): - return self._model(x["ligand"].pos) - - def training_step(self, batch): - self._samples[Split.train].append(batch.name) - loss = self(batch).sum() - return loss - - def validation_step(self, batch, batch_index): - self._samples[Split.val].append(batch.name) - return torch.zeros(1) - - def test_step(self, batch, batch_index): - self._samples[Split.test].append(batch.name) - - def predict_step(self, batch, batch_index): - self._samples[Split.test].append(batch.name) - return torch.zeros(1) - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=2e-4) - return optimizer - - -@pytest.fixture(scope="function") -def create_trainer_and_model(): - trainer = L.Trainer(max_epochs=1, accelerator="gpu", - devices=1, val_check_interval=1) - model = ModelTestDiffDock() - return trainer, model diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py b/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py similarity index 73% rename from sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py rename to sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py index 15a922ba47..5c66173fc2 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/datamodule.py +++ b/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py @@ -13,16 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from enum import Enum, auto import glob import random +from enum import Enum, auto from typing import Any, Dict, Iterable, List, Optional, Union + import lightning as L import webdataset as wds -from bionemo.contrib.data.molecule.diffdock.utils import ( - pickles_to_tars - ) +from bionemo.core.data.utils import pickles_to_tars class Split(Enum): @@ -32,7 +31,6 @@ class Split(Enum): class WDSModule(L.LightningDataModule): - """lightning data module for using webdataset tar files to setup dataset and dataloader. This data module takes a dictionary: Split -> tar file directory. In its setup() function, it creates the webdataset object @@ -40,18 +38,18 @@ class WDSModule(L.LightningDataModule): train/val/test_dataloader(), it creates the WebLoader object chaining up the `pipeline_prebatch_wld` workflow""" - def __init__(self, dirs_tars_wds : Dict[Split, str], n_samples : Dict[Split, - int], - suffix_keys_wds : Iterable[str], global_batch_size : int, - prefix_tars_wds : str = "wdshards", - pipeline_wds : Optional[Dict[Split, - Union[Iterable[Iterable[Any]], - Iterable[Any]]]] = - None, pipeline_prebatch_wld : Optional[Dict[Split, - Union[Iterable[Iterable[Any]], - Iterable[Any]]]] - = None, seed_rng_shfl : int = 0, - kwargs_dl : Optional[Dict[Split, Dict[str, str]]] = None): + def __init__( + self, + dirs_tars_wds: Dict[Split, str], + n_samples: Dict[Split, int], + suffix_keys_wds: Iterable[str], + global_batch_size: int, + prefix_tars_wds: str = "wdshards", + pipeline_wds: Optional[Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]]] = None, + pipeline_prebatch_wld: Optional[Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]]] = None, + seed_rng_shfl: int = 0, + kwargs_dl: Optional[Dict[Split, Dict[str, str]]] = None, + ): """constructor Args: @@ -105,12 +103,11 @@ def __init__(self, dirs_tars_wds : Dict[Split, str], n_samples : Dict[Split, keys_subset = self._dirs_tars_wds.keys() if n_samples.keys() != keys_subset: - raise RuntimeError(f"Input n_samples has different keys than " - f"dirs_tars_wds: {n_samples.keys()} vs " - f"{keys_subset}" - ) + raise RuntimeError( + f"Input n_samples has different keys than " f"dirs_tars_wds: {n_samples.keys()} vs " f"{keys_subset}" + ) - self._n_samples= n_samples + self._n_samples = n_samples self._global_batch_size = global_batch_size self._suffix_keys_wds = suffix_keys_wds @@ -124,8 +121,7 @@ def __init__(self, dirs_tars_wds : Dict[Split, str], n_samples : Dict[Split, self._kwargs_dl = kwargs_dl # to be created later in setup - self._dataset = dict() - + self._dataset = {} def prepare_data(self) -> None: """This is called only by the main process by the Lightning workflow. Do @@ -136,7 +132,7 @@ def prepare_data(self) -> None: """ pass - def _setup_wds(self, split : Split) -> wds.WebDataset: + def _setup_wds(self, split: Split) -> wds.WebDataset: """setup webdataset and webloader. This is called by setup() Args: @@ -145,30 +141,22 @@ def _setup_wds(self, split : Split) -> wds.WebDataset: Returns: WebDataset """ - if not split in self._dirs_tars_wds.keys(): - raise RuntimeError(f"_setup_wds() is called with {split} " - f"split that doesn't have the input tar dir") + if split not in self._dirs_tars_wds.keys(): + raise RuntimeError(f"_setup_wds() is called with {split} " f"split that doesn't have the input tar dir") is_train = split == Split.train - urls = sorted(glob.glob( - f"{self._dirs_tars_wds[split]}/{self._prefix_tars_wds}-*.tar") - ) + urls = sorted(glob.glob(f"{self._dirs_tars_wds[split]}/{self._prefix_tars_wds}-*.tar")) dataset = ( - wds.WebDataset(urls, shardshuffle=is_train, - nodesplitter=wds.split_by_node, - seed=self._seed_rng_shfl) + wds.WebDataset(urls, shardshuffle=is_train, nodesplitter=wds.split_by_node, seed=self._seed_rng_shfl) .decode() .extract_keys(f"*.{self._suffix_keys_wds}") - ) - if (self._pipeline_wds is not None and - self._pipeline_wds[split] is not None): - if isinstance(self._pipeline_wds[split], - Iterable): + ) + if self._pipeline_wds is not None and self._pipeline_wds[split] is not None: + if isinstance(self._pipeline_wds[split], Iterable): dataset = dataset.compose(*self._pipeline_wds[split]) else: dataset = dataset.compose(self._pipeline_wds[split]) if is_train: - dataset = dataset.shuffle(size=16, - rng=random.Random(self._seed_rng_shfl)) + dataset = dataset.shuffle(size=16, rng=random.Random(self._seed_rng_shfl)) return dataset def setup(self, stage: str) -> None: @@ -190,10 +178,9 @@ def setup(self, stage: str) -> None: elif stage == "predict": self._dataset[Split.test] = self._setup_wds(Split.test) else: - raise NotImplementedError(f"Data setup with stage = {stage} "\ - f"is not implmented") + raise NotImplementedError(f"Data setup with stage = {stage} " f"is not implmented") - def _setup_dataloader(self, split : Split) -> wds.WebLoader: + def _setup_dataloader(self, split: Split) -> wds.WebLoader: """setup the dataloader for the input dataset split Args: @@ -203,49 +190,41 @@ def _setup_dataloader(self, split : Split) -> wds.WebLoader: """ if self._dataset[split] is None: - raise RuntimeError(f"_setup_dataloader() is called with {split} " - f"split without setting up the corresp. dataset") + raise RuntimeError( + f"_setup_dataloader() is called with {split} " f"split without setting up the corresp. dataset" + ) dataset = self._dataset[split] n_samples = self._n_samples[split] - n_batches = ((n_samples + self._global_batch_size - 1) - // self._global_batch_size) + n_batches = (n_samples + self._global_batch_size - 1) // self._global_batch_size kwargs = self._kwargs_dl[split] if self._kwargs_dl is not None else None - loader = wds.WebLoader(dataset, batch_size=None, - **(kwargs if kwargs is not None else {}) - ).shuffle(5000, rng=random.Random(self._seed_rng_shfl)) + loader = wds.WebLoader(dataset, batch_size=None, **(kwargs if kwargs is not None else {})).shuffle( + 5000, rng=random.Random(self._seed_rng_shfl) + ) - if (self._pipeline_prebatch_wld is not None and - self._pipeline_prebatch_wld[split] is not None): + if self._pipeline_prebatch_wld is not None and self._pipeline_prebatch_wld[split] is not None: if isinstance(self._pipeline_prebatch_wld[split], Iterable): - loader = loader.compose( - *self._pipeline_prebatch_wld[split]) + loader = loader.compose(*self._pipeline_prebatch_wld[split]) else: - loader = loader.compose( - self._pipeline_prebatch_wld[split]) + loader = loader.compose(self._pipeline_prebatch_wld[split]) loader = loader.with_epoch(n_batches) return loader - def train_dataloader(self) -> wds.WebLoader: return self._setup_dataloader(Split.train) - def val_dataloader(self) -> wds.WebLoader: return self._setup_dataloader(Split.val) - def test_dataloader(self) -> wds.WebLoader: return self._setup_dataloader(Split.test) - def predict_dataloader(self) -> wds.WebLoader: return self._setup_dataloader(Split.test) class PickledDataWDS(WDSModule): - """lightning APIs to process pickled data into webdataset tar files and setup dataset and dataloader. This data module takes a directory of pickled data files, data filename prefixes for train/val/test splits, data filename @@ -258,9 +237,16 @@ class PickledDataWDS(WDSModule): `pipeline_wds` workflow. In its train/val/test_dataloader(), it creates the WebLoader object chaining up the `pipeline_prebatch_wld` workflow""" - def __init__(self, dir_pickled : str, suffix_pickled : str, names_subset : - Dict[Split, List[str]], prefix_dir_tars_wds : str, *args, - n_tars_wds : Optional[int] = None, **kwargs): + def __init__( + self, + dir_pickled: str, + suffix_pickled: str, + names_subset: Dict[Split, List[str]], + prefix_dir_tars_wds: str, + *args, + n_tars_wds: Optional[int] = None, + **kwargs, + ): """constructor Args: @@ -284,18 +270,12 @@ def __init__(self, dir_pickled : str, suffix_pickled : str, names_subset : """ super().__init__( - { - split : f"{prefix_dir_tars_wds}{str(split).split('.')[-1]}" - for split in names_subset.keys() - }, - { - split : len(names_subset[split]) for split in - names_subset.keys() - }, + {split: f"{prefix_dir_tars_wds}{str(split).split('.')[-1]}" for split in names_subset.keys()}, + {split: len(names_subset[split]) for split in names_subset.keys()}, suffix_pickled, *args, - **kwargs - ) + **kwargs, + ) self._dir_pickled = dir_pickled self._suffix_pickled = suffix_pickled @@ -318,9 +298,11 @@ def prepare_data(self) -> None: """ for split in self._names_subset.keys(): # create wds shards (tar files) for train set - pickles_to_tars(self._dir_pickled, - self._suffix_pickled, - self._names_subset[split], - self._dirs_tars_wds[split], - self._prefix_tars_wds, - min_num_shards=self._n_tars_wds) + pickles_to_tars( + self._dir_pickled, + self._suffix_pickled, + self._names_subset[split], + self._dirs_tars_wds[split], + self._prefix_tars_wds, + min_num_shards=self._n_tars_wds, + ) diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/utils.py b/sub-packages/bionemo-core/src/bionemo/core/data/utils.py new file mode 100644 index 0000000000..ab4ed8d3cb --- /dev/null +++ b/sub-packages/bionemo-core/src/bionemo/core/data/utils.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pickle +from typing import Any, Callable, Dict, List, Optional + +import webdataset as wds +from nemo.utils import logging + + +def pickles_to_tars( + dir_input: str, + input_suffix: str, + input_prefix_subset: List[str], + dir_output: str, + output_prefix: str, + func_output_data: Callable[[str, str, Any], Dict[str, Any]] = lambda prefix, suffix, data: { + "__key__": prefix, + suffix: pickle.dumps(data), + }, + min_num_shards: Optional[int] = None, +) -> None: + """Convert a subset of pickle files from a directory to Webdataset tar files + Input path and name pattern: + f"{dir_input}/{input_prefix_subset}.{input_suffix}" + Output path and name pattern: + f"{dir_output}/{output_prefix}-%06d.tar" + + The webdataset tar archive is specified by the dictionary: + { + "__key__" : sample_filename_preifx, + sample_filename_suffix_1 : data_1, + sample_filename_suffix_2 : data_2, + ... + } + so that parsing the tar archive is equivalent of reading + {sample_filename_preifx}.{sample_filename_suffix_1} etc. + + Here, the assumption is that there is only one sample data file, whose name + prefix is given in each of the elements of `input_prefix_subset` and whose + name suffix is given by `input_suffix`. Per the webdataset file format + specification, the `sample_filename_preifx` can't contain dots '.' so this + function removes it for the user by calling .replace(".", "-") on the + elements of `input_prefix_subset` + + Args: + dir_input (str): Input directory + input_suffix (str): Input pickle file name suffix + input_prefix_subset (List[str]): Input subset of pickle files' prefix + dir_output (str): Output directory + output_prefix (str): Output tar file name prefix + func_output_data (Callable[[str, str, Any], Dict[str, Any]]) : function + that maps the name prefix, name suffix and data object to a + webdataset tar archive dictionary. Refer to the webdataset github + repo for the archive file format specification. + min_num_shards (int) : create at least this number of tar files. + WebDataset has bugs when reading small number of tar files in a + multi-node lightening + DDP setting so this option can be used to + guarantee the tar file counts + + Returns: None + + """ + os.makedirs(dir_output, exist_ok=True) + wd_subset_pattern = os.path.join(dir_output, f"{output_prefix}-%06d.tar") + maxsize = 1e8 + # Due to a Webdataset bug, number of shards should be >= number of workers + # (num. of gpus * num. of workers per gpu) + # TODO: this algorithm is not accurate enough because it doesn't take into + # account the block structure so I have to multiply the total_size with a + # small prefactor to purposely underestimate the size so that it ends up + # creating more tar files than min_num_shards + if min_num_shards is not None and min_num_shards > 1: + total_size = 0 + for name in input_prefix_subset: + try: + total_size += os.stat(os.path.join(dir_input, f"{name}.{input_suffix}")).st_size + except Exception: + continue + maxsize = min(total_size * 0.6 // min_num_shards, maxsize) + with wds.ShardWriter(wd_subset_pattern, encoder=False, maxsize=maxsize, compress=False, mode=0o777) as sink: + for name in input_prefix_subset: + try: + data = pickle.load(open(os.path.join(dir_input, f"{name}.{input_suffix}"), "rb")) + # the prefix name shouldn't contain any "." per webdataset's + # specification + sample = func_output_data(name.replace(".", "-"), input_suffix, data) + except ModuleNotFoundError as e: + logging.error(f"Dependency for parsing input pickle data not " f"found: {e}") + raise e + except Exception as e: + logging.error(f"Failed to write {name} into tar files due to error {e}") + raise e + + sink.write(sample) diff --git a/sub-packages/bionemo-diffdock/LICENSE b/sub-packages/bionemo-diffdock/LICENSE new file mode 100644 index 0000000000..d645695673 --- /dev/null +++ b/sub-packages/bionemo-diffdock/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/sub-packages/bionemo-diffdock/README.md b/sub-packages/bionemo-diffdock/README.md new file mode 100644 index 0000000000..938640398b --- /dev/null +++ b/sub-packages/bionemo-diffdock/README.md @@ -0,0 +1,6 @@ +# bionemo-diffdock + + +```bash +pip install -e . +``` diff --git a/sub-packages/bionemo-diffdock/_requirements-test.txt b/sub-packages/bionemo-diffdock/_requirements-test.txt new file mode 100644 index 0000000000..47d98580a4 --- /dev/null +++ b/sub-packages/bionemo-diffdock/_requirements-test.txt @@ -0,0 +1 @@ +-e ../bionemo-testing diff --git a/sub-packages/bionemo-diffdock/_requirements.txt b/sub-packages/bionemo-diffdock/_requirements.txt new file mode 100644 index 0000000000..22c9b13a61 --- /dev/null +++ b/sub-packages/bionemo-diffdock/_requirements.txt @@ -0,0 +1 @@ +-e ../bionemo-core diff --git a/sub-packages/bionemo-diffdock/pyproject.toml b/sub-packages/bionemo-diffdock/pyproject.toml new file mode 100644 index 0000000000..29057f478c --- /dev/null +++ b/sub-packages/bionemo-diffdock/pyproject.toml @@ -0,0 +1,26 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "bionemo-diffdock" +version = "2.0.0" +license.file = "LICENSE" +readme = "README.md" +description = "BioNeMo DiffDock" +authors = [ + { name = "John St. John", email = "jstjohn@nvidia.com" }, + { name = "Malcolm Greaves", email = "mgreaves@nvidia.com" }, + { name = "Dejun Lin", email = "dejun.lin@gmail.com" }, +] +dynamic = ["dependencies"] + +[tool.setuptools.dynamic] +dependencies = {file = ["requirements.txt"]} +# TODO: how to specify bionemo-{feature packages} & bionemo-core ??? + +[tool.setuptools.packages.find] +where = ["src"] +include=["bionemo.*"] +namespaces = true +exclude = ["test*."] diff --git a/sub-packages/bionemo-diffdock/requirements.txt b/sub-packages/bionemo-diffdock/requirements.txt new file mode 100644 index 0000000000..ec114b4f89 --- /dev/null +++ b/sub-packages/bionemo-diffdock/requirements.txt @@ -0,0 +1,2 @@ +numpy==1.26.4 +scipy==1.12.0 diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/data.py similarity index 65% rename from sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py rename to sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/data.py index 3fc21b769d..13116253d4 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/data/molecule/diffdock/utils.py +++ b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/data.py @@ -12,118 +12,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import os -import pickle -import random import math -from typing import ( - Any, Dict, Callable, Generator, List, Optional,Union, Iterable - ) +import random +from typing import Any, Callable, Generator, Iterable, List, Union -from omegaconf.listconfig import ListConfig -from nemo.utils import logging +import numpy as np import torch +from nemo.utils import logging +from omegaconf.listconfig import ListConfig from torch_geometric.data import HeteroData from torch_geometric.data.batch import Batch from torch_geometric.loader.dataloader import Collater -import numpy as np - -import webdataset as wds - - -def pickles_to_tars( - dir_input: str, - input_suffix: str, - input_prefix_subset: List[str], - dir_output: str, - output_prefix: str, - func_output_data: Callable[[str, str, Any], Dict[str, Any]] = - lambda prefix, suffix, data: { "__key__": prefix, - suffix: pickle.dumps(data) }, - min_num_shards: Optional[int] = None, -) -> None: - """Convert a subset of pickle files from a directory to Webdataset tar files - Input path and name pattern: - f"{dir_input}/{input_prefix_subset}.{input_suffix}" - Output path and name pattern: - f"{dir_output}/{output_prefix}-%06d.tar" - - The webdataset tar archive is specified by the dictionary: - { - "__key__" : sample_filename_preifx, - sample_filename_suffix_1 : data_1, - sample_filename_suffix_2 : data_2, - ... - } - so that parsing the tar archive is equivalent of reading - {sample_filename_preifx}.{sample_filename_suffix_1} etc. - - Here, the assumption is that there is only one sample data file, whose name - prefix is given in each of the elements of `input_prefix_subset` and whose - name suffix is given by `input_suffix`. Per the webdataset file format - specification, the `sample_filename_preifx` can't contain dots '.' so this - function removes it for the user by calling .replace(".", "-") on the - elements of `input_prefix_subset` - - Args: - dir_input (str): Input directory - input_suffix (str): Input pickle file name suffix - input_prefix_subset (List[str]): Input subset of pickle files' prefix - dir_output (str): Output directory - output_prefix (str): Output tar file name prefix - func_output_data (Callable[[str, str, Any], Dict[str, Any]]) : function - that maps the name prefix, name suffix and data object to a - webdataset tar archive dictionary. Refer to the webdataset github - repo for the archive file format specification. - min_num_shards (int) : create at least this number of tar files. - WebDataset has bugs when reading small number of tar files in a - multi-node lightening + DDP setting so this option can be used to - guarantee the tar file counts - - Returns: None - - """ - os.makedirs(dir_output, exist_ok=True) - wd_subset_pattern = os.path.join(dir_output, f"{output_prefix}-%06d.tar") - maxsize = 1e8 - # Due to a Webdataset bug, number of shards should be >= number of workers - # (num. of gpus * num. of workers per gpu) - # TODO: this algorithm is not accurate enough because it doesn't take into - # account the block structure so I have to multiply the total_size with a - # small prefactor to purposely underestimate the size so that it ends up - # creating more tar files than min_num_shards - if min_num_shards is not None and min_num_shards > 1: - total_size = 0 - for name in input_prefix_subset: - try: - total_size += os.stat(os.path.join(dir_input, f"{name}.{input_suffix}")).st_size - except Exception: - continue - maxsize = min(total_size * 0.6 // min_num_shards, maxsize) - with wds.ShardWriter(wd_subset_pattern, encoder=False, maxsize=maxsize, compress=False, mode=0o777) as sink: - for name in input_prefix_subset: - try: - data = pickle.load(open(os.path.join(dir_input, f"{name}.{input_suffix}"), "rb")) - # the prefix name shouldn't contain any "." per webdataset's - # specification - sample = func_output_data(name.replace(".", "-"), - input_suffix, data) - except ModuleNotFoundError as e: - logging.error(f"Dependency for parsing input pickle data not "\ - f"found: {e}") - raise e - except Exception as e: - logging.error(f"Failed to write {name} into tar files due to error {e}") - raise e - - sink.write(sample) def num_cross_edge_upper_bound_estimate(n1, n2, n3, n4): - terms = [[4.92, 'ligand_ligand'], - [0.0118, 'receptor_receptor'], - [0.0401, 'ligand', 'receptor_receptor']] + terms = [[4.92, "ligand_ligand"], [0.0118, "receptor_receptor"], [0.0401, "ligand", "receptor_receptor"]] scale = 1.03 tmpdict = {"ligand": n1, "ligand_ligand": n2, "receptor": n3, "receptor_receptor": n4} num_edges = 0.0 @@ -156,6 +59,7 @@ def estimate_memory_usage(data, num_cross_edges, use_bias=True): else: return total_memory + def estimate_size(g): n1, n2, n3, n4 = ( g["ligand"].num_nodes, @@ -168,8 +72,7 @@ def estimate_size(g): # the empirical formula here is from the polynomial fitting # the scaling constant is to help remove the outliers above the upper bound estimation. n5 = num_cross_edge_upper_bound_estimate(n1, n2, n3, n4) - total_memory = estimate_memory_usage(g, n5, - use_bias=False) + total_memory = estimate_memory_usage(g, n5, use_bias=False) return total_memory @@ -180,8 +83,8 @@ def __init__( self, max_total_size: int, size_fn: Callable[[HeteroData], int], - collate_fn: Callable[[List[Any]], Any] = Collater(dataset=None, follow_batch=None, exclude_keys=None), - no_single_sample : bool = True + collate_fn: Callable[[List[Any]], Any] = Collater(dataset=[], follow_batch=None, exclude_keys=None), + no_single_sample: bool = True, ): self.max_total_size = max_total_size self.size_fn = size_fn @@ -189,7 +92,7 @@ def __init__( self.cached_sizes = {} self.no_single_sample = no_single_sample - def __call__(self, data: Batch) -> Generator[Batch, None, None]: + def __call__(self, data: Batch) -> Generator[Union[Batch, List[HeteroData]], None, None]: batch_size = 0 batch = [] @@ -228,7 +131,7 @@ def __init__( samples_per_complex: int, balance: bool, all_atoms: bool, - seed : int = 0 + seed: int = 0, ): """constructor @@ -283,8 +186,7 @@ def __call__(self, data: Iterable) -> Generator[HeteroData, None, None]: else: sample = random.randint(0, self.samples_per_complex - 1) complex_graph["ligand"].pos = torch.from_numpy(positions[sample]) - ids = (rmsds[sample] < - self.rmsd_classification_cutoff).astype(int) + ids = (rmsds[sample] < self.rmsd_classification_cutoff).astype(int) complex_graph.y = torch.tensor(ids).float().unsqueeze(0) if isinstance(self.rmsd_classification_cutoff, ListConfig): complex_graph.y_binned = torch.tensor( diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/diffusion.py similarity index 85% rename from sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py rename to sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/diffusion.py index 79b42a75f9..4c3e361ec1 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/diffusion.py +++ b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/diffusion.py @@ -1,3 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary # @@ -19,14 +35,14 @@ from torch import nn from torch_geometric.data.hetero_data import HeteroData -from bionemo.contrib.model.molecule.diffdock.utils import so3, torus -from bionemo.contrib.model.molecule.diffdock.utils.geometry import axis_angle_to_matrix, rigid_transform_Kabsch_3D_torch -from bionemo.contrib.model.molecule.diffdock.utils.torsion import modify_conformer_torsion_angles +from bionemo.diffdock.utils import so3, torus +from bionemo.diffdock.utils.geometry import axis_angle_to_matrix, rigid_transform_Kabsch_3D_torch +from bionemo.diffdock.utils.torsion import modify_conformer_torsion_angles -def t_to_sigma(tr_sigma_min, tr_sigma_max, rot_sigma_min, - rot_sigma_max, tor_sigma_min, tor_sigma_max, - t_tr, t_rot, t_tor): +def t_to_sigma( + tr_sigma_min, tr_sigma_max, rot_sigma_min, rot_sigma_max, tor_sigma_min, tor_sigma_max, t_tr, t_rot, t_tor +): tr_sigma = tr_sigma_min ** (1 - t_tr) * tr_sigma_max**t_tr rot_sigma = rot_sigma_min ** (1 - t_rot) * rot_sigma_max**t_rot tor_sigma = tor_sigma_min ** (1 - t_tor) * tor_sigma_max**t_tor @@ -163,17 +179,19 @@ class GenerateNoise: copy_ref_pos (bool): whether or not make a copy of the input ligand position """ - def __init__(self, t_to_sigma: Callable[[float, float, float], Tuple[float, - float, - float]], - no_torsion: bool, all_atom: bool, copy_ref_pos: bool = False): + def __init__( + self, + t_to_sigma: Callable[[float, float, float], Tuple[float, float, float]], + no_torsion: bool, + all_atom: bool, + copy_ref_pos: bool = False, + ): self.t_to_sigma = t_to_sigma self.no_torsion = no_torsion self.all_atom = all_atom self._copy_ref_pos = copy_ref_pos - def __call__(self, source : Generator[HeteroData, None, None]) \ - -> Generator[HeteroData, None, None]: + def __call__(self, source: Generator[HeteroData, None, None]) -> Generator[HeteroData, None, None]: for (data,) in source: if self._copy_ref_pos: data["ligand"].aligned_pos = deepcopy(data["ligand"].pos) diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/geometry.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/geometry.py similarity index 86% rename from sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/geometry.py rename to sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/geometry.py index 648045d261..547b17c1c7 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/geometry.py +++ b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/geometry.py @@ -1,3 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary # diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/so3.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/so3.py similarity index 90% rename from sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/so3.py rename to sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/so3.py index 7fa0da0a3e..651828ba81 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/so3.py +++ b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/so3.py @@ -1,3 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary # diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/torsion.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torsion.py similarity index 85% rename from sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/torsion.py rename to sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torsion.py index b4c7e3692f..003e028241 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/torsion.py +++ b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torsion.py @@ -1,3 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary # diff --git a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/torus.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torus.py similarity index 84% rename from sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/torus.py rename to sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torus.py index 977cff937d..9a7966bbd4 100644 --- a/sub-packages/bionemo-contrib/src/bionemo/contrib/model/molecule/diffdock/utils/torus.py +++ b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torus.py @@ -1,3 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary # diff --git a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py new file mode 100644 index 0000000000..671c9c2882 --- /dev/null +++ b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py @@ -0,0 +1,232 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from enum import Enum, auto +from functools import partial +from typing import Any, Iterable + +import lightning as L +import pytest +import torch +from torch_geometric.loader.dataloader import Collater +from webdataset.filters import batched + +from bionemo.core.data.datamodule import PickledDataWDS, Split +from bionemo.diffdock.utils.data import SelectPoseAndLabelData, SizeAwareBatching, estimate_size +from bionemo.diffdock.utils.diffusion import GenerateNoise, t_to_sigma + + +@pytest.fixture(scope="module") +def get_path(request): + dir_test = os.path.dirname(request.module.__file__) + dir_data = f"{dir_test}/test_data" + return dir_test, dir_data + + +class DiffDockModel(Enum): + score = auto() + confidence = auto() + + +@pytest.fixture(scope="module", params=list(DiffDockModel)) +def get_diffdock_heterodata(get_path, request): + _, dir_data = get_path + model = request.param + name_model = str(model).split(".")[-1] + dir_heterodata = f"{dir_data}/molecule/diffdock/pyg_heterodata_pickled/{name_model}_model" + suffix_heterodata = "heterodata.pyd" + names = { + Split.train: [ + "6t88", + "6vs3", + "6wtn", + "6yqv", + "7amc", + "7bmi", + "7cuo", + "7d5c", + "7din", + "7fha", + "7jnb", + "7k0v", + "7kb1", + "7km8", + "7l7c", + "7lcu", + "7msr", + "7my1", + "7n6f", + "7np6", + ], + Split.val: ["7nr6", "7oeo", "7oli", "7oso", "7p5t", "7q5i", "7qhl", "7rh3", "7rzl", "7sgv"], + Split.test: ["7sne", "7t2i", "7tbu", "7tsf", "7umv", "7up3", "7uq3", "7wpw", "7xek", "7xij"], + } + return (dir_heterodata, suffix_heterodata, names, model) + + +def no_op_gen(it: Iterable[Any]): + yield from it + + +def _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, suffix_heterodata, names): + prefix_dir_tars_wds = tmp_path_factory.mktemp("diffdock_score_model_tars_wds").as_posix() + tr_sigma_min, tr_sigma_max = (0.1, 19) + rot_sigma_min, rot_sigma_max = (0.03, 1.55) + tor_sigma_min, tor_sigma_max = (0.0314, 3.14) + is_all_atom = False + no_torsion = False + sigma_t = partial( + t_to_sigma, tr_sigma_min, tr_sigma_max, rot_sigma_min, rot_sigma_max, tor_sigma_min, tor_sigma_max + ) + # webdataset pipeline + generateNoise = { + Split.train: [GenerateNoise(sigma_t, no_torsion, is_all_atom, copy_ref_pos=False), no_op_gen], + Split.val: GenerateNoise(sigma_t, no_torsion, is_all_atom, copy_ref_pos=True), + Split.test: GenerateNoise(sigma_t, no_torsion, is_all_atom, copy_ref_pos=False), + } + local_batch_size = 2 + global_batch_size = 2 + size_cuda_mem = 0.85 * torch.cuda.get_device_properties("cuda:0").total_memory / 2**20 + batch_pyg = batched(local_batch_size, collation_fn=Collater(dataset=[], follow_batch=None, exclude_keys=None)) + # WebLoader pipeline + pipelines_wdl_batch = { + Split.train: SizeAwareBatching(max_total_size=size_cuda_mem, size_fn=estimate_size, no_single_sample=True), + Split.val: [batch_pyg, no_op_gen], + Split.test: batch_pyg, + } + n_tars_wds = 4 + seed_rng_shfl = 822782392 + kwargs_dl = { + Split.train: {"num_workers": 2}, + Split.val: {"num_workers": 2}, + Split.test: {"num_workers": 2}, + } + data_module = PickledDataWDS( + dir_heterodata, + suffix_heterodata, + names, + prefix_dir_tars_wds, + global_batch_size, + n_tars_wds=n_tars_wds, + prefix_tars_wds="heterographs", + pipeline_wds=generateNoise, + pipeline_prebatch_wld=pipelines_wdl_batch, + seed_rng_shfl=seed_rng_shfl, + kwargs_dl=kwargs_dl, + ) + return data_module, prefix_dir_tars_wds + + +def _create_datamodule_confidence_model_impl(tmp_path_factory, dir_heterodata, suffix_heterodata, names): + prefix_dir_tars_wds = tmp_path_factory.mktemp("diffdock_confidence_model_tars_wds").as_posix() + # webdataset pipeline + rmsd_classification_cutoff = 2.0 + samples_per_complex = 7 + balance = False + is_all_atom = True + seed_rng_shfl = 822782392 + select_pose = SelectPoseAndLabelData( + rmsd_classification_cutoff, samples_per_complex, balance, is_all_atom, seed=seed_rng_shfl + ) + pipeline_wds = { + Split.train: select_pose, + Split.val: select_pose, + Split.test: select_pose, + } + local_batch_size = 2 + global_batch_size = 2 + batch_pyg = batched(local_batch_size, collation_fn=Collater(dataset=[], follow_batch=None, exclude_keys=None)) + # WebLoader pipeline + pipelines_wdl_batch = { + Split.train: batch_pyg, + Split.val: batch_pyg, + Split.test: batch_pyg, + } + n_tars_wds = 4 + kwargs_dl = { + Split.train: {"num_workers": 2}, + Split.val: {"num_workers": 2}, + Split.test: {"num_workers": 2}, + } + data_module = PickledDataWDS( + dir_heterodata, + suffix_heterodata, + names, + prefix_dir_tars_wds, + global_batch_size, + n_tars_wds=n_tars_wds, + prefix_tars_wds="heterographs", + pipeline_wds=pipeline_wds, + pipeline_prebatch_wld=pipelines_wdl_batch, + seed_rng_shfl=seed_rng_shfl, + kwargs_dl=kwargs_dl, + ) + return data_module, prefix_dir_tars_wds + + +@pytest.fixture(scope="module") +def create_datamodule(tmp_path_factory, get_diffdock_heterodata): + dir_heterodata, suffix_heterodata, names, model = get_diffdock_heterodata + if model == DiffDockModel.score: + return _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, suffix_heterodata, names) + elif model == DiffDockModel.confidence: + return _create_datamodule_confidence_model_impl(tmp_path_factory, dir_heterodata, suffix_heterodata, names) + + +@pytest.fixture(scope="module") +def create_another_datamodule(tmp_path_factory, get_diffdock_heterodata): + dir_heterodata, suffix_heterodata, names, model = get_diffdock_heterodata + if model == DiffDockModel.score: + return _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, suffix_heterodata, names) + elif model == DiffDockModel.confidence: + return _create_datamodule_confidence_model_impl(tmp_path_factory, dir_heterodata, suffix_heterodata, names) + + +class ModelTestDiffDock(L.LightningModule): + def __init__(self) -> None: + super().__init__() + self._model = torch.nn.Linear(3, 3) + self._samples = {split: [] for split in Split} + + def forward(self, x): + return self._model(x["ligand"].pos) + + def training_step(self, batch): + self._samples[Split.train].append(batch.name) + loss = self(batch).sum() + return loss + + def validation_step(self, batch, batch_index): + self._samples[Split.val].append(batch.name) + return torch.zeros(1) + + def test_step(self, batch, batch_index): + self._samples[Split.test].append(batch.name) + + def predict_step(self, batch, batch_index): + self._samples[Split.test].append(batch.name) + return torch.zeros(1) + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=2e-4) + return optimizer + + +@pytest.fixture(scope="function") +def create_trainer_and_model(): + trainer = L.Trainer(max_epochs=1, accelerator="gpu", devices=1, val_check_interval=1) + model = ModelTestDiffDock() + return trainer, model diff --git a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/test_diffdock_datamodule.py similarity index 62% rename from sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py rename to sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/test_diffdock_datamodule.py index b7d4cb56d2..4c52943634 100644 --- a/sub-packages/bionemo-contrib/tests/bionemo/contrib/data/test_diffdock_datamodule.py +++ b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/test_diffdock_datamodule.py @@ -13,52 +13,50 @@ # See the License for the specific language governing permissions and # limitations under the License. -from enum import Enum, auto import glob -import pytest -import torch +from enum import Enum, auto import lightning +import pytest +import torch from torch_geometric.data import Batch, HeteroData -from bionemo.contrib.data.molecule.diffdock.datamodule import Split +from bionemo.core.data.datamodule import Split -@pytest.mark.parametrize("split", [s for s in Split]) +@pytest.mark.parametrize("split", list(Split)) def test_datamodule_init(split, get_diffdock_heterodata, create_datamodule): - name_split = str(split).split('.')[1] + name_split = str(split).split(".")[1] (_, _, names, model) = get_diffdock_heterodata data_module, prefix_dir_tars_wds = create_datamodule - assert data_module._n_samples[split] == len(names[split]),\ - f"Wrong {split}-set size for {model} model: "\ - f"expected {len(names[split])} "\ + assert data_module._n_samples[split] == len(names[split]), ( + f"Wrong {split}-set size for {model} model: " + f"expected {len(names[split])} " f"but got {data_module._n_samples[split]}" - assert data_module._dirs_tars_wds[split] ==\ - f"{prefix_dir_tars_wds}{name_split}",\ - f"Wrong tar files directory for {model} model: "\ - f"expected {prefix_dir_tars_wds}{split} "\ + ) + assert data_module._dirs_tars_wds[split] == f"{prefix_dir_tars_wds}{name_split}", ( + f"Wrong tar files directory for {model} model: " + f"expected {prefix_dir_tars_wds}{split} " f"but got {data_module._dirs_tars_wds[split]}" + ) -@pytest.mark.parametrize("split", [s for s in Split]) +@pytest.mark.parametrize("split", list(Split)) def test_datamodule_prepare_data(split, create_datamodule): data_module, _ = create_datamodule # LightningDataModule.prepare_data() is supposed to be called from the main # process in a Lightning-managed multi-process context so we can call it in # a single process data_module.prepare_data() - files_tars = sorted(glob.glob( - f"{data_module._dirs_tars_wds[split]}/"\ - f"{data_module._prefix_tars_wds}-*.tar")) - assert len(files_tars) >= data_module._n_tars_wds,\ - f"Wrong num of {split}-set tar files: "\ - f"expected {data_module._n_tars_wds} "\ - f"got {len(files_tars)}" + files_tars = sorted(glob.glob(f"{data_module._dirs_tars_wds[split]}/" f"{data_module._prefix_tars_wds}-*.tar")) + assert len(files_tars) >= data_module._n_tars_wds, ( + f"Wrong num of {split}-set tar files: " f"expected {data_module._n_tars_wds} " f"got {len(files_tars)}" + ) -@pytest.mark.parametrize("split", [s for s in Split]) +@pytest.mark.parametrize("split", list(Split)) def test_datamodule_setup_dataset(split, create_datamodule, create_another_datamodule): - data_modules= [create_datamodule[0], create_another_datamodule[0]] + data_modules = [create_datamodule[0], create_another_datamodule[0]] lists_complex_name = [] lists_pos_ligand = [] for m in data_modules: @@ -71,38 +69,40 @@ def test_datamodule_setup_dataset(split, create_datamodule, create_another_datam names = [] pos_ligand = [] for sample in m._dataset[split]: - assert isinstance(sample, HeteroData),\ - "Sample yield from dataset is not PyG HeteroData" + assert isinstance(sample, HeteroData), "Sample yield from dataset is not PyG HeteroData" names.append(sample.name) pos_ligand.append(sample["ligand"].pos) lists_complex_name.append(names) lists_pos_ligand.append(pos_ligand) - assert len(lists_complex_name[0]) > 0,\ - "No names in {split} dataset" - assert lists_complex_name[0] == lists_complex_name[1],\ - f"Inconsistent sample name in {split}-set from data module instances: "\ - f"{lists_complex_name[0]} \n\nvs.\n\n"\ + assert len(lists_complex_name[0]) > 0, "No names in {split} dataset" + assert lists_complex_name[0] == lists_complex_name[1], ( + f"Inconsistent sample name in {split}-set from data module instances: " + f"{lists_complex_name[0]} \n\nvs.\n\n" f"{lists_complex_name[1]}" + ) assert len(lists_pos_ligand[0]) > 0, "No ligand position found in dataset" - assert len(lists_pos_ligand[0]) == len(lists_pos_ligand[1]),\ - f"Inconsistent number of ligand position in {split}-set from data "\ - f"module instances: {len(lists_pos_ligand[0])} \n\nvs.\n\n"\ + assert len(lists_pos_ligand[0]) == len(lists_pos_ligand[1]), ( + f"Inconsistent number of ligand position in {split}-set from data " + f"module instances: {len(lists_pos_ligand[0])} \n\nvs.\n\n" f"{len(lists_pos_ligand[1])}" + ) for i in range(len(lists_pos_ligand[0])): pos_0 = lists_pos_ligand[0][i] pos_1 = lists_pos_ligand[1][i] - torch.testing.assert_close(pos_0, pos_1, - msg=lambda m : - f"Inconsistent ligand position in the " - f"{i}'th sample/batch of {split}-set " - f"between two data module instances:\n\n{m}") + torch.testing.assert_close( + pos_0, + pos_1, + msg=lambda m: f"Inconsistent ligand position in the " + f"{i}'th sample/batch of {split}-set " + f"between two data module instances:\n\n{m}", + ) -@pytest.mark.parametrize("split", [s for s in Split]) +@pytest.mark.parametrize("split", list(Split)) def test_datamodule_setup_dataloader(split, create_datamodule, create_another_datamodule): - data_modules= [create_datamodule[0], create_another_datamodule[0]] + data_modules = [create_datamodule[0], create_another_datamodule[0]] lists_complex_name = [] lists_pos_ligand = [] for m in data_modules: @@ -126,36 +126,36 @@ def test_datamodule_setup_dataloader(split, create_datamodule, create_another_da assert loader is not None, "dataloader not instantated" for samples in loader: # PyG's HeteroDataBatch is Batch inherited from HeteroData - assert isinstance(samples, Batch),\ - f"Sample object is not PyG Batch" - assert isinstance(samples, HeteroData),\ - f"Sample object is not PyG HeteroData" + assert isinstance(samples, Batch), "Sample object is not PyG Batch" + assert isinstance(samples, HeteroData), "Sample object is not PyG HeteroData" names.append(samples.name) pos_ligand.append(samples["ligand"].pos) lists_complex_name.append(names) lists_pos_ligand.append(pos_ligand) - assert len(lists_complex_name[0]) > 0,\ - "No names in {split} dataloader" - assert lists_complex_name[0] == lists_complex_name[1],\ - f"Inconsistent sample name in {split}-set from data module instances: "\ - f"{lists_complex_name[0]} \n\nvs.\n\n"\ + assert len(lists_complex_name[0]) > 0, "No names in {split} dataloader" + assert lists_complex_name[0] == lists_complex_name[1], ( + f"Inconsistent sample name in {split}-set from data module instances: " + f"{lists_complex_name[0]} \n\nvs.\n\n" f"{lists_complex_name[1]}" + ) - assert len(lists_pos_ligand[0]) > 0,\ - "No ligand position found in dataloader" - assert len(lists_pos_ligand[0]) == len(lists_pos_ligand[1]),\ - f"Inconsistent number of ligand position in {split}-set from data "\ - f"module instances: {len(lists_pos_ligand[0])} \n\nvs.\n\n"\ + assert len(lists_pos_ligand[0]) > 0, "No ligand position found in dataloader" + assert len(lists_pos_ligand[0]) == len(lists_pos_ligand[1]), ( + f"Inconsistent number of ligand position in {split}-set from data " + f"module instances: {len(lists_pos_ligand[0])} \n\nvs.\n\n" f"{len(lists_pos_ligand[1])}" + ) for i in range(len(lists_pos_ligand[0])): pos_0 = lists_pos_ligand[0][i] pos_1 = lists_pos_ligand[1][i] - torch.testing.assert_close(pos_0, pos_1, - msg=lambda m : - f"Inconsistent ligand position in the " - f"{i}'th sample/batch of {split}-set " - f"between two data module instances:\n\n{m}") + torch.testing.assert_close( + pos_0, + pos_1, + msg=lambda m: f"Inconsistent ligand position in the " + f"{i}'th sample/batch of {split}-set " + f"between two data module instances:\n\n{m}", + ) class Stage(Enum): @@ -165,11 +165,9 @@ class Stage(Enum): predict = auto() -@pytest.mark.parametrize("stage", [s for s in Stage]) -def test_datamodule_in_lightning(stage, create_datamodule, - create_another_datamodule, - create_trainer_and_model): - data_modules= [create_datamodule[0], create_another_datamodule[0]] +@pytest.mark.parametrize("stage", list(Stage)) +def test_datamodule_in_lightning(stage, create_datamodule, create_another_datamodule, create_trainer_and_model): + data_modules = [create_datamodule[0], create_another_datamodule[0]] trainer, model = create_trainer_and_model # get the list of samples from the loader lightning.seed_everything(2823828) @@ -186,8 +184,7 @@ def test_datamodule_in_lightning(stage, create_datamodule, name_stage = str(stage).split(".")[-1] data_modules[0].setup(name_stage) # get the list of samples from the workflow - get_dataloader = getattr(data_modules[0], - f"{str(split).split('.')[-1]}_dataloader") + get_dataloader = getattr(data_modules[0], f"{str(split).split('.')[-1]}_dataloader") loader = get_dataloader() samples = [] for sample in loader: From b622a2cab02a585b9846dc35b897cf5b4a3d45f2 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 13 Aug 2024 22:19:38 +0000 Subject: [PATCH 36/70] Enhancement: rename WDSModule -> WebDataModule --- .../bionemo-core/src/bionemo/core/data/datamodule.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py b/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py index 5c66173fc2..cadff82a36 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py +++ b/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py @@ -30,7 +30,7 @@ class Split(Enum): test = auto() -class WDSModule(L.LightningDataModule): +class WebDataModule(L.LightningDataModule): """lightning data module for using webdataset tar files to setup dataset and dataloader. This data module takes a dictionary: Split -> tar file directory. In its setup() function, it creates the webdataset object @@ -224,7 +224,7 @@ def predict_dataloader(self) -> wds.WebLoader: return self._setup_dataloader(Split.test) -class PickledDataWDS(WDSModule): +class PickledDataWDS(WebDataModule): """lightning APIs to process pickled data into webdataset tar files and setup dataset and dataloader. This data module takes a directory of pickled data files, data filename prefixes for train/val/test splits, data filename @@ -260,12 +260,12 @@ def __init__( webdataset tar files. The actual directories storing the train, val and test sets will be suffixed with "train", "val" and "test" respectively. - *args: arguments passed to the parent WDSModule + *args: arguments passed to the parent WebDataModule Kwargs: n_tars_wds (int): attempt to create at least this number of webdataset shards - **kwargs: arguments passed to the parent WDSModule + **kwargs: arguments passed to the parent WebDataModule """ From 1ffde9597fc7ec931efd056b384e9565ebd10180 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 13 Aug 2024 23:43:56 +0000 Subject: [PATCH 37/70] Enhancement: actually support Iterable[str] as suffix_keys_wds --- .../src/bionemo/core/data/datamodule.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py b/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py index cadff82a36..56728755c7 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py +++ b/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py @@ -16,7 +16,7 @@ import glob import random from enum import Enum, auto -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Union, get_args import lightning as L import webdataset as wds @@ -42,7 +42,7 @@ def __init__( self, dirs_tars_wds: Dict[Split, str], n_samples: Dict[Split, int], - suffix_keys_wds: Iterable[str], + suffix_keys_wds: Union[str, Iterable[str]], global_batch_size: int, prefix_tars_wds: str = "wdshards", pipeline_wds: Optional[Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]]] = None, @@ -57,10 +57,10 @@ def __init__( directory that contains the webdataset tar files for each split n_samples (Dict[Split, int]): input dictionary: Split -> number of data samples for each split - suffix_keys_wds (Iterable): a set of keys each corresponding to a - data object in the webdataset tar file dictionary. The data - objects of these keys will be extracted and tupled for each - sample in the tar files + suffix_keys_wds (Union[str, Iterable[str]]): a set of keys each + corresponding to a data object in the webdataset tar file + dictionary. The data objects of these keys will be extracted and + tupled for each sample in the tar files global_batch_size (int): size of batch summing across nodes in Data Distributed Parallel, i.e., local_batch_size * n_nodes. NOTE: this data module doesn't rely on the input `global_batch_size` @@ -110,6 +110,12 @@ def __init__( self._n_samples = n_samples self._global_batch_size = global_batch_size + + + if not isinstance(suffix_keys_wds, + get_args(Union[str, Iterable[str]])): + raise TypeError("suffix_keys_wds can only be str or Iterable[str]") + self._suffix_keys_wds = suffix_keys_wds self._prefix_tars_wds = prefix_tars_wds @@ -148,8 +154,13 @@ def _setup_wds(self, split: Split) -> wds.WebDataset: dataset = ( wds.WebDataset(urls, shardshuffle=is_train, nodesplitter=wds.split_by_node, seed=self._seed_rng_shfl) .decode() - .extract_keys(f"*.{self._suffix_keys_wds}") ) + if isinstance(self._suffix_keys_wds, str): + dataset = dataset.extract_keys(f"*.{self._suffix_keys_wds}") + else: + dataset = dataset.extract_keys(*[f"*.{key}" for key in + self._suffix_keys_wds]) + if self._pipeline_wds is not None and self._pipeline_wds[split] is not None: if isinstance(self._pipeline_wds[split], Iterable): dataset = dataset.compose(*self._pipeline_wds[split]) From 594823056fd471df7568aaaf710b84881c0353f7 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Wed, 14 Aug 2024 02:21:58 +0000 Subject: [PATCH 38/70] Test: WebDataModule with mock-up dataset and model --- .../tests/bionemo/core/data/conftest.py | 121 ++++++++++++++++ .../bionemo/core/data/test_webdatamodule.py | 136 ++++++++++++++++++ 2 files changed, 257 insertions(+) create mode 100644 sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py create mode 100644 sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py diff --git a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py b/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py new file mode 100644 index 0000000000..b34917e63c --- /dev/null +++ b/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from pathlib import Path +from enum import Enum, auto +from functools import partial +from typing import Any, Iterable + +import lightning as L +import pytest +import torch +from webdataset.filters import batched + +from bionemo.core.data.datamodule import WebDataModule, Split + + +@pytest.fixture(scope="module") +def get_path(request): + path_test = Path(request.module.__file__).resolve() + dir_test = path_test.parents[0] + dir_data = path_test.parents[6] / "test_data" / "bionemo" / "core" / "data" / "webdatamodule" + return str(dir_test), str(dir_data) + + +def _create_webdatamodule(dir_tars_wds): + suffix_keys_wds = "tensor.pyd" + local_batch_size = 2 + global_batch_size = 2 + prefix_tars_wds = "tensor" + seed_rng_shfl = 82838392 + + dirs_tars_wds = { split : dir_tars_wds for split in Split } + + n_samples = { split : 10 for split in Split } + + batch = batched(local_batch_size, collation_fn=lambda + list_samples : torch.vstack(list_samples)) + + pipeline_wds = { + split : lambda source : (sample for (sample,) in source) + for split in Split + } + + pipeline_prebatch_wld = { + split : batch for split in Split + } + + kwargs_dl = { + split : {"num_workers": 2} for split in Split + } + + data_module = WebDataModule(dirs_tars_wds, n_samples, suffix_keys_wds, + global_batch_size, + prefix_tars_wds=prefix_tars_wds, + pipeline_wds=pipeline_wds, + pipeline_prebatch_wld=pipeline_prebatch_wld, + seed_rng_shfl=seed_rng_shfl, kwargs_dl=kwargs_dl) + + return data_module, dir_tars_wds + + +@pytest.fixture(scope="module") +def create_webdatamodule(get_path): + _, dir_tars_wds = get_path + return _create_webdatamodule(dir_tars_wds) + + +@pytest.fixture(scope="module") +def create_another_webdatamodule(get_path): + _, dir_tars_wds = get_path + return _create_webdatamodule(dir_tars_wds) + + +class ModelTestWebDataModule(L.LightningModule): + def __init__(self) -> None: + super().__init__() + self._model = torch.nn.Linear(1, 1) + self._samples = {split: [] for split in Split} + + def forward(self, x): + return self._model(x.float()) + + def training_step(self, batch): + self._samples[Split.train].append(batch.name) + loss = self(batch).sum() + return loss + + def validation_step(self, batch, batch_index): + self._samples[Split.val].append(batch.name) + return torch.zeros(1) + + def test_step(self, batch, batch_index): + self._samples[Split.test].append(batch.name) + + def predict_step(self, batch, batch_index): + self._samples[Split.test].append(batch.name) + return torch.zeros(1) + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=2e-4) + return optimizer + + +@pytest.fixture(scope="function") +def create_trainer_and_model(): + trainer = L.Trainer(max_epochs=1, accelerator="gpu", devices=1, val_check_interval=1) + model = ModelTestWebDataModule() + return trainer, model diff --git a/sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py b/sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py new file mode 100644 index 0000000000..8e20ce2f48 --- /dev/null +++ b/sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum, auto + +import pytest + +import torch +import lightning as L + +from bionemo.core.data.datamodule import Split + + +@pytest.mark.parametrize("split", list(Split)) +def test_webdatamodule_init(split, create_webdatamodule): + data_module, prefix_dir_tars_wds = create_webdatamodule + assert data_module._n_samples[split] == 10, ( + f"Wrong {split}-set size: " + f"expected 10 " + f"but got {data_module._n_samples[split]}" + ) + assert data_module._dirs_tars_wds[split] == f"{prefix_dir_tars_wds}", ( + f"Wrong tar files directory: " + f"expected {prefix_dir_tars_wds} " + f"but got {data_module._dirs_tars_wds[split]}" + ) + + +@pytest.mark.parametrize("split", list(Split)) +def test_webdatamodule_setup_dataset(split, create_webdatamodule, + create_another_webdatamodule): + data_modules = [create_webdatamodule[0], create_another_webdatamodule[0]] + lists_tensors = [] + for m in data_modules: + m.prepare_data() + # run through all the possible stages first to setup all the correps. + # dataset objects + m.setup("fit") + m.setup("test") + L.seed_everything(2823828) + tensors= [] + for sample in m._dataset[split]: + assert isinstance(sample, torch.Tensor),\ + "Sample yield from dataset is not tensor" + tensors.append(sample) + lists_tensors.append(tensors) + + assert len(lists_tensors[0]) > 0, "No names in {split} dataset" + torch.testing.assert_close(torch.vstack(lists_tensors[0]), + torch.vstack(lists_tensors[1])) + + +@pytest.mark.parametrize("split", list(Split)) +def test_webdatamodule_setup_dataloader(split, create_webdatamodule, + create_another_webdatamodule): + data_modules = [create_webdatamodule[0], create_another_webdatamodule[0]] + lists_tensors = [] + for m in data_modules: + m.prepare_data() + # run through all the possible stages first to setup all the correps. + # dataset objects + m.setup("fit") + m.setup("test") + L.seed_everything(2823828) + tensors = [] + loader = None + if split == Split.train: + loader = m.train_dataloader() + elif split == Split.val: + loader = m.val_dataloader() + elif split == Split.test: + loader = m.test_dataloader() + else: + raise RuntimeError(f"Test for split {split} not implemented") + assert loader is not None, "dataloader not instantated" + for samples in loader: + # PyG's HeteroDataBatch is Batch inherited from HeteroData + assert isinstance(samples, torch.Tensor),\ + "Sample object is not torch.Tensor" + tensors.append(samples) + lists_tensors.append(tensors) + + assert len(lists_tensors[0]) > 0, "No names in {split} dataloader" + torch.testing.assert_close(torch.vstack(lists_tensors[0]), + torch.vstack(lists_tensors[1])) + + +class Stage(Enum): + fit = auto() + validate = auto() + test = auto() + predict = auto() + + +@pytest.mark.parametrize("stage", list(Stage)) +def test_webdatamodule_in_lightning(stage, create_webdatamodule, + create_another_webdatamodule, + create_trainer_and_model): + data_modules = [create_webdatamodule[0], create_another_webdatamodule[0]] + trainer, model = create_trainer_and_model + # get the list of samples from the loader + L.seed_everything(2823828) + data_modules[0].prepare_data() + split = None + if stage == Stage.fit: + split = Split.train + elif stage == Stage.validate: + split = Split.val + elif stage == Stage.test or stage == Stage.predict: + split = Split.test + else: + raise RuntimeError(f"{stage} stage not implemented") + name_stage = str(stage).split(".")[-1] + data_modules[0].setup(name_stage) + # get the list of samples from the workflow + get_dataloader = getattr(data_modules[0], f"{str(split).split('.')[-1]}_dataloader") + loader = get_dataloader() + samples = [] + for sample in loader: + samples.append(sample.name) + L.seed_everything(2823828) + workflow = getattr(trainer, name_stage) + workflow(model, data_modules[1]) + assert model._samples[split] == samples From 78ed1897aa0703504660bdbae61927e8f39bd8a8 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Wed, 14 Aug 2024 03:17:28 +0000 Subject: [PATCH 39/70] Enhancement: remove hard-coded shuffle from workflow --- .../src/bionemo/core/data/datamodule.py | 6 +--- .../tests/bionemo/core/data/conftest.py | 16 +++++++--- .../tests/bionemo/diffdock/data/conftest.py | 29 ++++++++++--------- 3 files changed, 29 insertions(+), 22 deletions(-) diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py b/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py index 56728755c7..bf426e1d2b 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py +++ b/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py @@ -166,8 +166,6 @@ def _setup_wds(self, split: Split) -> wds.WebDataset: dataset = dataset.compose(*self._pipeline_wds[split]) else: dataset = dataset.compose(self._pipeline_wds[split]) - if is_train: - dataset = dataset.shuffle(size=16, rng=random.Random(self._seed_rng_shfl)) return dataset def setup(self, stage: str) -> None: @@ -208,9 +206,7 @@ def _setup_dataloader(self, split: Split) -> wds.WebLoader: n_samples = self._n_samples[split] n_batches = (n_samples + self._global_batch_size - 1) // self._global_batch_size kwargs = self._kwargs_dl[split] if self._kwargs_dl is not None else None - loader = wds.WebLoader(dataset, batch_size=None, **(kwargs if kwargs is not None else {})).shuffle( - 5000, rng=random.Random(self._seed_rng_shfl) - ) + loader = wds.WebLoader(dataset, batch_size=None, **(kwargs if kwargs is not None else {})) if self._pipeline_prebatch_wld is not None and self._pipeline_prebatch_wld[split] is not None: if isinstance(self._pipeline_prebatch_wld[split], Iterable): diff --git a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py b/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py index b34917e63c..c5ff1b8f4d 100644 --- a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py +++ b/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py @@ -14,6 +14,7 @@ # limitations under the License. import os +import random from pathlib import Path from enum import Enum, auto from functools import partial @@ -22,7 +23,7 @@ import lightning as L import pytest import torch -from webdataset.filters import batched +from webdataset.filters import batched, shuffle from bionemo.core.data.datamodule import WebDataModule, Split @@ -49,13 +50,20 @@ def _create_webdatamodule(dir_tars_wds): batch = batched(local_batch_size, collation_fn=lambda list_samples : torch.vstack(list_samples)) + untuple = lambda source : (sample for (sample,) in source) + pipeline_wds = { - split : lambda source : (sample for (sample,) in source) - for split in Split + Split.train : [untuple, shuffle(n_samples[Split.train], + rng=random.Random(seed_rng_shfl))], + Split.val : untuple, + Split.test : untuple } pipeline_prebatch_wld = { - split : batch for split in Split + Split.train: [shuffle(n_samples[Split.train], + rng=random.Random(seed_rng_shfl)), batch], + Split.val : batch, + Split.test : batch } kwargs_dl = { diff --git a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py index 671c9c2882..8e34632ea8 100644 --- a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py +++ b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py @@ -16,13 +16,14 @@ import os from enum import Enum, auto from functools import partial +import random from typing import Any, Iterable import lightning as L import pytest import torch from torch_geometric.loader.dataloader import Collater -from webdataset.filters import batched +from webdataset.filters import batched, shuffle from bionemo.core.data.datamodule import PickledDataWDS, Split from bionemo.diffdock.utils.data import SelectPoseAndLabelData, SizeAwareBatching, estimate_size @@ -77,10 +78,6 @@ def get_diffdock_heterodata(get_path, request): return (dir_heterodata, suffix_heterodata, names, model) -def no_op_gen(it: Iterable[Any]): - yield from it - - def _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, suffix_heterodata, names): prefix_dir_tars_wds = tmp_path_factory.mktemp("diffdock_score_model_tars_wds").as_posix() tr_sigma_min, tr_sigma_max = (0.1, 19) @@ -91,9 +88,12 @@ def _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, suffix sigma_t = partial( t_to_sigma, tr_sigma_min, tr_sigma_max, rot_sigma_min, rot_sigma_max, tor_sigma_min, tor_sigma_max ) + seed_rng_shfl = 822782392 # webdataset pipeline - generateNoise = { - Split.train: [GenerateNoise(sigma_t, no_torsion, is_all_atom, copy_ref_pos=False), no_op_gen], + pipeline_wds = { + Split.train: [GenerateNoise(sigma_t, no_torsion, is_all_atom, copy_ref_pos=False), + shuffle(len(names[Split.train]), + rng=random.Random(seed_rng_shfl))], Split.val: GenerateNoise(sigma_t, no_torsion, is_all_atom, copy_ref_pos=True), Split.test: GenerateNoise(sigma_t, no_torsion, is_all_atom, copy_ref_pos=False), } @@ -103,12 +103,14 @@ def _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, suffix batch_pyg = batched(local_batch_size, collation_fn=Collater(dataset=[], follow_batch=None, exclude_keys=None)) # WebLoader pipeline pipelines_wdl_batch = { - Split.train: SizeAwareBatching(max_total_size=size_cuda_mem, size_fn=estimate_size, no_single_sample=True), - Split.val: [batch_pyg, no_op_gen], + Split.train: [shuffle(40, rng=random.Random(seed_rng_shfl)), + SizeAwareBatching(max_total_size=size_cuda_mem, + size_fn=estimate_size, + no_single_sample=True)], + Split.val: batch_pyg, Split.test: batch_pyg, } n_tars_wds = 4 - seed_rng_shfl = 822782392 kwargs_dl = { Split.train: {"num_workers": 2}, Split.val: {"num_workers": 2}, @@ -122,7 +124,7 @@ def _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, suffix global_batch_size, n_tars_wds=n_tars_wds, prefix_tars_wds="heterographs", - pipeline_wds=generateNoise, + pipeline_wds=pipeline_wds, pipeline_prebatch_wld=pipelines_wdl_batch, seed_rng_shfl=seed_rng_shfl, kwargs_dl=kwargs_dl, @@ -142,7 +144,8 @@ def _create_datamodule_confidence_model_impl(tmp_path_factory, dir_heterodata, s rmsd_classification_cutoff, samples_per_complex, balance, is_all_atom, seed=seed_rng_shfl ) pipeline_wds = { - Split.train: select_pose, + Split.train: [select_pose, shuffle(len(names[Split.train]), + rng=random.Random(seed_rng_shfl))], Split.val: select_pose, Split.test: select_pose, } @@ -151,7 +154,7 @@ def _create_datamodule_confidence_model_impl(tmp_path_factory, dir_heterodata, s batch_pyg = batched(local_batch_size, collation_fn=Collater(dataset=[], follow_batch=None, exclude_keys=None)) # WebLoader pipeline pipelines_wdl_batch = { - Split.train: batch_pyg, + Split.train: [shuffle(40, rng=random.Random(seed_rng_shfl)), batch_pyg], Split.val: batch_pyg, Split.test: batch_pyg, } From 01bb63e3ab85d8db36ff77601d79ed7afe21c128 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Wed, 14 Aug 2024 03:20:28 +0000 Subject: [PATCH 40/70] Enhancement: rename kwargs_dl -> kwargs_wld --- .../bionemo-core/src/bionemo/core/data/datamodule.py | 8 ++++---- .../bionemo-core/tests/bionemo/core/data/conftest.py | 4 ++-- .../tests/bionemo/diffdock/data/conftest.py | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py b/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py index bf426e1d2b..e8d5f992b2 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py +++ b/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py @@ -48,7 +48,7 @@ def __init__( pipeline_wds: Optional[Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]]] = None, pipeline_prebatch_wld: Optional[Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]]] = None, seed_rng_shfl: int = 0, - kwargs_dl: Optional[Dict[Split, Dict[str, str]]] = None, + kwargs_wld: Optional[Dict[Split, Dict[str, str]]] = None, ): """constructor @@ -91,7 +91,7 @@ def __init__( yield from the WebLoader seed_rng_shfl (int): seed to the random number generators used in data loading time for shuffling - kwargs_dl (Optional[Dict[Split, Dict[str, str]]]): kwargs for data + kwargs_wld (Optional[Dict[Split, Dict[str, str]]]): kwargs for data loader, e.g., num_workers, of each split @@ -124,7 +124,7 @@ def __init__( self._seed_rng_shfl = seed_rng_shfl - self._kwargs_dl = kwargs_dl + self._kwargs_wld = kwargs_wld # to be created later in setup self._dataset = {} @@ -205,7 +205,7 @@ def _setup_dataloader(self, split: Split) -> wds.WebLoader: dataset = self._dataset[split] n_samples = self._n_samples[split] n_batches = (n_samples + self._global_batch_size - 1) // self._global_batch_size - kwargs = self._kwargs_dl[split] if self._kwargs_dl is not None else None + kwargs = self._kwargs_wld[split] if self._kwargs_wld is not None else None loader = wds.WebLoader(dataset, batch_size=None, **(kwargs if kwargs is not None else {})) if self._pipeline_prebatch_wld is not None and self._pipeline_prebatch_wld[split] is not None: diff --git a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py b/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py index c5ff1b8f4d..36831b304f 100644 --- a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py +++ b/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py @@ -66,7 +66,7 @@ def _create_webdatamodule(dir_tars_wds): Split.test : batch } - kwargs_dl = { + kwargs_wld = { split : {"num_workers": 2} for split in Split } @@ -75,7 +75,7 @@ def _create_webdatamodule(dir_tars_wds): prefix_tars_wds=prefix_tars_wds, pipeline_wds=pipeline_wds, pipeline_prebatch_wld=pipeline_prebatch_wld, - seed_rng_shfl=seed_rng_shfl, kwargs_dl=kwargs_dl) + seed_rng_shfl=seed_rng_shfl, kwargs_wld=kwargs_wld) return data_module, dir_tars_wds diff --git a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py index 8e34632ea8..548d33f745 100644 --- a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py +++ b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py @@ -111,7 +111,7 @@ def _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, suffix Split.test: batch_pyg, } n_tars_wds = 4 - kwargs_dl = { + kwargs_wld = { Split.train: {"num_workers": 2}, Split.val: {"num_workers": 2}, Split.test: {"num_workers": 2}, @@ -127,7 +127,7 @@ def _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, suffix pipeline_wds=pipeline_wds, pipeline_prebatch_wld=pipelines_wdl_batch, seed_rng_shfl=seed_rng_shfl, - kwargs_dl=kwargs_dl, + kwargs_wld=kwargs_wld, ) return data_module, prefix_dir_tars_wds @@ -159,7 +159,7 @@ def _create_datamodule_confidence_model_impl(tmp_path_factory, dir_heterodata, s Split.test: batch_pyg, } n_tars_wds = 4 - kwargs_dl = { + kwargs_wld = { Split.train: {"num_workers": 2}, Split.val: {"num_workers": 2}, Split.test: {"num_workers": 2}, @@ -175,7 +175,7 @@ def _create_datamodule_confidence_model_impl(tmp_path_factory, dir_heterodata, s pipeline_wds=pipeline_wds, pipeline_prebatch_wld=pipelines_wdl_batch, seed_rng_shfl=seed_rng_shfl, - kwargs_dl=kwargs_dl, + kwargs_wld=kwargs_wld, ) return data_module, prefix_dir_tars_wds From c34dd0f1d4ab7985c64f88c624567caef1d8fbad Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Wed, 14 Aug 2024 03:33:21 +0000 Subject: [PATCH 41/70] Enhancement: replace hard-coded webdatast kwargs with user input --- .../src/bionemo/core/data/datamodule.py | 19 ++++++++++--------- .../tests/bionemo/core/data/conftest.py | 11 ++++++++++- .../tests/bionemo/diffdock/data/conftest.py | 17 +++++++++++++++-- 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py b/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py index e8d5f992b2..4f916559e7 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py +++ b/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py @@ -47,8 +47,8 @@ def __init__( prefix_tars_wds: str = "wdshards", pipeline_wds: Optional[Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]]] = None, pipeline_prebatch_wld: Optional[Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]]] = None, - seed_rng_shfl: int = 0, - kwargs_wld: Optional[Dict[Split, Dict[str, str]]] = None, + kwargs_wds: Optional[Dict[Split, Dict[str, Any]]] = None, + kwargs_wld: Optional[Dict[Split, Dict[str, Any]]] = None, ): """constructor @@ -89,10 +89,10 @@ def __init__( seuqnence of such iterators. For example, this can be used for batching the samples. NOTE: this is applied before batching is yield from the WebLoader - seed_rng_shfl (int): seed to the random number generators used in - data loading time for shuffling - kwargs_wld (Optional[Dict[Split, Dict[str, str]]]): kwargs for data - loader, e.g., num_workers, of each split + kwargs_wds (Optional[Dict[Split, Dict[str, Any]]]): kwargs for the + WebDataset.__init__() + kwargs_wld (Optional[Dict[Split, Dict[str, Any]]]): kwargs for the + WebLoader.__init__(), e.g., num_workers, of each split """ @@ -122,10 +122,10 @@ def __init__( self._pipeline_wds = pipeline_wds self._pipeline_prebatch_wld = pipeline_prebatch_wld - self._seed_rng_shfl = seed_rng_shfl - self._kwargs_wld = kwargs_wld + self._kwargs_wds = kwargs_wds + # to be created later in setup self._dataset = {} @@ -151,8 +151,9 @@ def _setup_wds(self, split: Split) -> wds.WebDataset: raise RuntimeError(f"_setup_wds() is called with {split} " f"split that doesn't have the input tar dir") is_train = split == Split.train urls = sorted(glob.glob(f"{self._dirs_tars_wds[split]}/{self._prefix_tars_wds}-*.tar")) + kwargs = self._kwargs_wds[split] if self._kwargs_wds is not None else None dataset = ( - wds.WebDataset(urls, shardshuffle=is_train, nodesplitter=wds.split_by_node, seed=self._seed_rng_shfl) + wds.WebDataset(urls, **(kwargs if kwargs is not None else {})) .decode() ) if isinstance(self._suffix_keys_wds, str): diff --git a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py b/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py index 36831b304f..79e7d54db1 100644 --- a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py +++ b/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py @@ -23,6 +23,7 @@ import lightning as L import pytest import torch +import webdataset as wds from webdataset.filters import batched, shuffle from bionemo.core.data.datamodule import WebDataModule, Split @@ -66,6 +67,13 @@ def _create_webdatamodule(dir_tars_wds): Split.test : batch } + kwargs_wds = { + split : {'shardshuffle' : split == Split.train, + 'nodesplitter' : wds.split_by_node, + 'seed' : seed_rng_shfl} + for split in Split + } + kwargs_wld = { split : {"num_workers": 2} for split in Split } @@ -75,7 +83,8 @@ def _create_webdatamodule(dir_tars_wds): prefix_tars_wds=prefix_tars_wds, pipeline_wds=pipeline_wds, pipeline_prebatch_wld=pipeline_prebatch_wld, - seed_rng_shfl=seed_rng_shfl, kwargs_wld=kwargs_wld) + kwargs_wds=kwargs_wds, + kwargs_wld=kwargs_wld) return data_module, dir_tars_wds diff --git a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py index 548d33f745..729a43079e 100644 --- a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py +++ b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py @@ -23,6 +23,7 @@ import pytest import torch from torch_geometric.loader.dataloader import Collater +import webdataset as wds from webdataset.filters import batched, shuffle from bionemo.core.data.datamodule import PickledDataWDS, Split @@ -111,6 +112,12 @@ def _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, suffix Split.test: batch_pyg, } n_tars_wds = 4 + kwargs_wds = { + split : {'shardshuffle' : split == Split.train, + 'nodesplitter' : wds.split_by_node, + 'seed' : seed_rng_shfl} + for split in Split + } kwargs_wld = { Split.train: {"num_workers": 2}, Split.val: {"num_workers": 2}, @@ -126,7 +133,7 @@ def _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, suffix prefix_tars_wds="heterographs", pipeline_wds=pipeline_wds, pipeline_prebatch_wld=pipelines_wdl_batch, - seed_rng_shfl=seed_rng_shfl, + kwargs_wds=kwargs_wds, kwargs_wld=kwargs_wld, ) return data_module, prefix_dir_tars_wds @@ -159,6 +166,12 @@ def _create_datamodule_confidence_model_impl(tmp_path_factory, dir_heterodata, s Split.test: batch_pyg, } n_tars_wds = 4 + kwargs_wds = { + split : {'shardshuffle' : split == Split.train, + 'nodesplitter' : wds.split_by_node, + 'seed' : seed_rng_shfl} + for split in Split + } kwargs_wld = { Split.train: {"num_workers": 2}, Split.val: {"num_workers": 2}, @@ -174,7 +187,7 @@ def _create_datamodule_confidence_model_impl(tmp_path_factory, dir_heterodata, s prefix_tars_wds="heterographs", pipeline_wds=pipeline_wds, pipeline_prebatch_wld=pipelines_wdl_batch, - seed_rng_shfl=seed_rng_shfl, + kwargs_wds=kwargs_wds, kwargs_wld=kwargs_wld, ) return data_module, prefix_dir_tars_wds From 40044fced02604f56e9d93ca6c98f0de17a76e7c Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Wed, 14 Aug 2024 03:42:24 +0000 Subject: [PATCH 42/70] Test: move data into bionemo2_root/test_data --- .../tests/bionemo/core/data/conftest.py | 6 +----- .../tests/bionemo/diffdock/data/conftest.py | 13 +++++++------ 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py b/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py index 79e7d54db1..50ce1e641f 100644 --- a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py +++ b/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py @@ -13,12 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import random from pathlib import Path -from enum import Enum, auto -from functools import partial -from typing import Any, Iterable import lightning as L import pytest @@ -33,7 +29,7 @@ def get_path(request): path_test = Path(request.module.__file__).resolve() dir_test = path_test.parents[0] - dir_data = path_test.parents[6] / "test_data" / "bionemo" / "core" / "data" / "webdatamodule" + dir_data = path_test.parents[6] / "test_data" / "bionemo-core" / "data" / "webdatamodule" return str(dir_test), str(dir_data) diff --git a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py index 729a43079e..5e2fcb2cf0 100644 --- a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py +++ b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os +from pathlib import Path from enum import Enum, auto from functools import partial import random -from typing import Any, Iterable import lightning as L import pytest @@ -33,9 +32,11 @@ @pytest.fixture(scope="module") def get_path(request): - dir_test = os.path.dirname(request.module.__file__) - dir_data = f"{dir_test}/test_data" - return dir_test, dir_data + path_test = Path(request.module.__file__).resolve() + dir_test = path_test.parents[0] + dir_data = path_test.parents[6] / "test_data" / \ + "bionemo-diffdock" / "data" / "pyg_heterodata_pickled" + return str(dir_test), str(dir_data) class DiffDockModel(Enum): @@ -48,7 +49,7 @@ def get_diffdock_heterodata(get_path, request): _, dir_data = get_path model = request.param name_model = str(model).split(".")[-1] - dir_heterodata = f"{dir_data}/molecule/diffdock/pyg_heterodata_pickled/{name_model}_model" + dir_heterodata = f"{dir_data}/{name_model}_model" suffix_heterodata = "heterodata.pyd" names = { Split.train: [ From 6f0d4b207b47953d6184d451fbee531249159814 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Wed, 14 Aug 2024 11:59:55 -0700 Subject: [PATCH 43/70] Doc: docstring and readme.md the README.md is updated using ``` pydoc-markdown -m datamodule -I src/bionemo/core/data >> README.md ``` --- sub-packages/bionemo-core/README.md | 342 ++++++++++++++++++ .../src/bionemo/core/data/datamodule.py | 218 +++++++++-- 2 files changed, 534 insertions(+), 26 deletions(-) diff --git a/sub-packages/bionemo-core/README.md b/sub-packages/bionemo-core/README.md index ad93dbee9e..1157f39e9b 100644 --- a/sub-packages/bionemo-core/README.md +++ b/sub-packages/bionemo-core/README.md @@ -4,3 +4,345 @@ ```bash pip install -e . ``` +## WebDataModule + +```python +class WebDataModule(L.LightningDataModule) +``` + +A LightningDataModule for using webdataset tar files to setup dataset and +dataloader. This data module takes a dictionary: Split -> tar file +directory. In its setup() function, it creates the webdataset object +chaining up the input `pipeline_wds` workflow. In its +train/val/test_dataloader(), it creates the WebLoader object chaining up the +`pipeline_prebatch_wld` workflow + +Examples +-------- + +1. create the data module with input directory to webdataset tar files. +Depending on which of the downstream Lightning.Trainer methods are called, +e.g., `Trainer.fit()`, `Trainer.validate()`, `Trainer.test()` or +`Trainer.predict()`, only a subset of the train, val and test splits need to +be specified in the various input options to the data module: + +- `Trainer.fit()` requires the `train` and `val` splits +- `Trainer.validate()` requires the `val` split +- `Trainer.test()` requires the `test` splits +- `Trainer.predict()` requires the `test` splits + +Here is an example of constructing the data module for `Trainer.fit()`: +``` +>>> from bionemo.core.data.datamodule import Split, WebDataModule +>>> +>>> tar_file_prefix = "shards" +>>> +>>> dirs_of_tar_files = { +>>> Split.train: "/path/to/train/split/tars", +>>> Split.val: "/path/to/val/split/tars", +>>> } +>>> +>>> n_samples { +>>> Split.train: 1000, +>>> Split.val: 100, +>>> } +>>> +>>> # this is the string to retrieve the corresponding data object from the +>>> # webdataset file (see +>>> # https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format +>>> # for details) +>>> suffix_keys_wds = "tensor.pyd" +>>> +>>> # see the API doc for the definition of global_batch_size +>>> global_batch_size = 16 +>>> +>>> seed = 27193781 +>>> +>>> # Specify the routines to process the samples in the WebDataset object. +>>> # The routine is a generator of an Iterable of generators that are chained +>>> # together by nested function calling. The following is equivalent of +>>> # defining a overall generator of `shuffle(untuple(...))` which +>>> # untuples the samples and shuffles them. See webdataset's Documentation +>>> # for details. +>>> # NOTE: the `untuple` is almost always necessary due to the webdataset's +>>> # file parsing rule. +>>> +>>> untuple = lambda source : (sample for (sample,) in source) +>>> +>>> from webdatast import shuffle +>>> pipeline_wds = { +>>> Split.train : [untuple, shuffle(n_samples[Split.train], +>>> rng=random.Random(seed_rng_shfl))], +>>> Split.val: untuple +>>> } +>>> +>>> # Similarly the user can optionally define the processing routine on the +>>> # WebLoader (the dataloader of webdataset). +>>> # NOTE: these routines by default take unbatched sample as input so the +>>> # user can customize their batching routines here +>>> +>>> batch = batched(local_batch_size, collation_fn=lambda + list_samples : torch.vstack(list_samples)) +>>> pipeline_prebatch_wld = { + Split.train: [shuffle(n_samples[Split.train], + rng=random.Random(seed_rng_shfl)), batch], + Split.val : batch, + Split.test : batch + } +>>> +>>> # the user can optionally specify the kwargs for WebDataset and +>>> # WebLoader +>>> +>>> kwargs_wds = { +>>> split : {'shardshuffle' : split == Split.train, +>>> 'nodesplitter' : wds.split_by_node, +>>> 'seed' : seed_rng_shfl} +>>> for split in Split +>>> } +>>> +>>> kwargs_wld = { +>>> split : {"num_workers": 2} for split in Split +>>> } +>>> +>>> # construct the data module +>>> data_module = WebDataModule(dirs_of_tar_files, n_samples, suffix_keys_wds, + global_batch_size, + prefix_tars_wds=tar_file_prefix, + pipeline_wds=pipeline_wds, + pipeline_prebatch_wld=pipeline_prebatch_wld, + kwargs_wds=kwargs_wds, + kwargs_wld=kwargs_wld) +``` + + + +#### \_\_init\_\_ + +```python +def __init__( + dirs_tars_wds: Dict[Split, str], + n_samples: Dict[Split, int], + suffix_keys_wds: Union[str, Iterable[str]], + global_batch_size: int, + prefix_tars_wds: str = "wdshards", + pipeline_wds: Optional[Dict[Split, Union[Iterable[Iterable[Any]], + Iterable[Any]]]] = None, + pipeline_prebatch_wld: Optional[Dict[Split, + Union[Iterable[Iterable[Any]], + Iterable[Any]]]] = None, + kwargs_wds: Optional[Dict[Split, Dict[str, Any]]] = None, + kwargs_wld: Optional[Dict[Split, Dict[str, Any]]] = None) +``` + +constructor + +**Arguments**: + +- `dirs_tars_wds` _Dict[Split, str]_ - input dictionary: Split -> tar file + directory that contains the webdataset tar files for each split +- `n_samples` _Dict[Split, int]_ - input dictionary: Split -> number of + data samples for each split +- `suffix_keys_wds` _Union[str, Iterable[str]]_ - a set of keys each + corresponding to a data object in the webdataset tar file + dictionary. The data objects of these keys will be extracted and + tupled for each sample in the tar files +- `global_batch_size` _int_ - size of batch summing across nodes in Data + Distributed Parallel, i.e., local_batch_size * n_nodes. NOTE: + this data module doesn't rely on the input `global_batch_size` + for batching the samples. The batching is supposed to be done as + a part of the input `pipeline_prebatch_wld`. `global_batch_size` + is only used to compute a (pseudo-) epoch length for the data + loader so that the loader yield approximately n_samples // + global_batch_size batches + Kwargs: +- `prefix_tars_wds` _str_ - name prefix of the input webdataset tar + files. The input tar files are globbed by + "{dirs_tars_wds[split]}/{prefix_tars_wds}-*.tar" + pipeline_wds (Optional[Dict[Split, Union[Iterable[Iterable[Any]], +- `Iterable[Any]]]])` - a dictionary of webdatast composable, i.e., + functor that maps a iterator to another iterator that + transforms the data sample yield from the dataset object, for + different splits, or an iterable to such a sequence of such + iterators. For example, this can be used to transform the + sample in the worker before sending it to the main process of + the dataloader + pipeline_prebatch_wld (Optional[Dict[Split, + Union[Iterable[Iterable[Any]], Iterable[Any]]]]): a dictionary + of webloader composable, i.e., functor that maps a iterator to + another iterator that transforms the data sample yield from the + WebLoader object, for different splits, or an iterable to a + seuqnence of such iterators. For example, this can be used for + batching the samples. NOTE: this is applied before batching is + yield from the WebLoader +- `kwargs_wds` _Optional[Dict[Split, Dict[str, Any]]]_ - kwargs for the + WebDataset.__init__() +- `kwargs_wld` _Optional[Dict[Split, Dict[str, Any]]]_ - kwargs for the + WebLoader.__init__(), e.g., num_workers, of each split + + + +#### prepare\_data + +```python +def prepare_data() -> None +``` + +This is called only by the main process by the Lightning workflow. Do +not rely on this data module object's state update here as there is no +way to communicate the state update to other subprocesses. + +Returns: None + + + +#### setup + +```python +def setup(stage: str) -> None +``` + +This is called on all Lightning-managed nodes in a multi-node +training session + + +**Arguments**: + +- `stage` _str_ - "fit", "test" or "predict" +- `Returns` - None + +## PickledDataWDS + +```python +class PickledDataWDS(WebDataModule) +``` + +A LightningDataModule to process pickled data into webdataset tar files +and setup dataset and dataloader. This inherits the webdataset setup from +its parent module `WebDataModule`. This data module takes a directory of +pickled data files, data filename prefixes for train/val/test splits, data +filename suffixes and prepare webdataset tar files by globbing the specific +pickle data files `{dir_pickles}/{name_subset[split]}.{suffix_pickles}` and +outputing to webdataset tar file with the dict structure: +``` + {"__key__" : name.replace(".", "-"), + suffix_pickles : pickled.dumps(data) } +``` +NOTE: this assumes only one pickled file is processed for each sample. In +its setup() function, it creates the webdataset object chaining up the input +`pipeline_wds` workflow. In its train/val/test_dataloader(), it creates the +WebLoader object chaining up the `pipeline_prebatch_wld` workflow. + +Examples +-------- + +1. create the data module with a directory of pickle files and the file name +prefix thereof for different splits to used by `Lightning.Trainer.fit()` + +``` +>>> from bionemo.core.data.datamodule import Split, PickledDataWDS + +>>> dir_pickles = "/path/to/my/pickles/dir" + +>>> # the following will use `sample1.mydata.pt` and `sample2.mydata.pt` as the +>>> # training dataset and `sample4.mydata.pt` and `sample5.mydata.pt` as the +>>> # validation dataset + +>>> suffix_pickles = "mydata.pt" + +>>> names_subset = { +>>> Split.train: [sample1, sample2], +>>> Split.val: [sample4, sample5], +>>> } + +>>> # the following setting will attempt to create at least 5 tar files in +>>> # `/path/to/output/tars/dir/myshards-00000{0-5}.tar` + +>>> n_tars_wds = 5 +>>> prefix_tars_wds = "myshards" +>>> output_dir_tar_files = "/path/to/output/tars/dir" + +>>> # see the `WebDataModule` API doc for the definition of global_batch_size +>>> global_batch_size = 16 + +>>> # user can optionally customize the data processing routines and kwargs used +>>> # in the WebDataset and WebLoader (see the examples in `WebDataModule`) + +>>> pipeline_wds = { Split.train: ... } + +>>> pipeline_prebatch_wld = { Split.train: ... } + +>>> kwargs_wds = { Split.train: ..., Split.val: ... } + +>>> kwargs_wld = { Split.train: ..., Split.val: ... } + +>>> # create the data module +>>> data_module = PickledDataWDS( +>>> dir_pickles, +>>> suffix_pickles, +>>> names_subset, +>>> output_dir_tar_files, +>>> global_batch_size, # `WebDataModule` args +>>> n_tars_wds=n_tars_wds, +>>> prefix_tars_wds=prefix_tars_wds, # `WebDataModule` kwargs +>>> pipeline_wds=pipeline_wds, # `WebDataModule` kwargs +>>> pipeline_prebatch_wld=pipelines_wdl_batch, # `WebDataModule` kwargs +>>> kwargs_wds=kwargs_wds, # `WebDataModule` kwargs +>>> kwargs_wld=kwargs_wld, # `WebDataModule` kwargs +>>> ) + +``` + + + +#### \_\_init\_\_ + +```python +def __init__(dir_pickles: str, + suffix_pickles: str, + names_subset: Dict[Split, List[str]], + prefix_dir_tars_wds: str, + *args, + n_tars_wds: Optional[int] = None, + **kwargs) +``` + +constructor + +**Arguments**: + +- `dir_pickles` _str_ - input directory of pickled data files +- `suffix_pickles` _str_ - filename suffix of the input data in + dir_pickles. This is also used as the key mapped to the + tarballed pickled object in the webdataset +- `names_subset` _Dict[Split, List[str]]_ - list of filename prefix of + the data samples to be loaded in the dataset and dataloader for + each of the split +- `prefix_dir_tars_wds` _str_ - directory name prefix to store the output + webdataset tar files. The actual directories storing the train, val + and test sets will be suffixed with "train", "val" and "test" + respectively. +- `*args` - arguments passed to the parent WebDataModule + + Kwargs: +- `n_tars_wds` _int_ - attempt to create at least this number of + webdataset shards +- `**kwargs` - arguments passed to the parent WebDataModule + + + +#### prepare\_data + +```python +def prepare_data() -> None +``` + +This is called only by the main process by the Lightning workflow. Do +not rely on this data module object's state update here as there is no +way to communicate the state update to other subprocesses. The nesting +`pickles_to_tars` function goes through the data name prefixes in the +different splits, read the corresponding pickled file and output a +webdataset tar archive with the dict structure: {"__key__" : +name.replace(".", "-"), suffix_pickles : pickled.dumps(data) }. + +Returns: None + diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py b/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py index 4f916559e7..b8b82fafef 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py +++ b/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py @@ -31,12 +31,111 @@ class Split(Enum): class WebDataModule(L.LightningDataModule): - """lightning data module for using webdataset tar files to setup dataset and + """A LightningDataModule for using webdataset tar files to setup dataset and dataloader. This data module takes a dictionary: Split -> tar file directory. In its setup() function, it creates the webdataset object chaining up the input `pipeline_wds` workflow. In its train/val/test_dataloader(), it creates the WebLoader object chaining up the - `pipeline_prebatch_wld` workflow""" + `pipeline_prebatch_wld` workflow + + Examples + -------- + + 1. create the data module with input directory to webdataset tar files. + Depending on which of the downstream Lightning.Trainer methods are called, + e.g., `Trainer.fit()`, `Trainer.validate()`, `Trainer.test()` or + `Trainer.predict()`, only a subset of the train, val and test splits need to + be specified in the various input options to the data module: + + - `Trainer.fit()` requires the `train` and `val` splits + - `Trainer.validate()` requires the `val` split + - `Trainer.test()` requires the `test` splits + - `Trainer.predict()` requires the `test` splits + + Here is an example of constructing the data module for `Trainer.fit()`: + ``` + >>> from bionemo.core.data.datamodule import Split, WebDataModule + >>> + >>> tar_file_prefix = "shards" + >>> + >>> dirs_of_tar_files = { + >>> Split.train: "/path/to/train/split/tars", + >>> Split.val: "/path/to/val/split/tars", + >>> } + >>> + >>> n_samples { + >>> Split.train: 1000, + >>> Split.val: 100, + >>> } + >>> + >>> # this is the string to retrieve the corresponding data object from the + >>> # webdataset file (see + >>> # https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format + >>> # for details) + >>> suffix_keys_wds = "tensor.pyd" + >>> + >>> # see the API doc for the definition of global_batch_size + >>> global_batch_size = 16 + >>> + >>> seed = 27193781 + >>> + >>> # Specify the routines to process the samples in the WebDataset object. + >>> # The routine is a generator of an Iterable of generators that are chained + >>> # together by nested function calling. The following is equivalent of + >>> # defining a overall generator of `shuffle(untuple(...))` which + >>> # untuples the samples and shuffles them. See webdataset's Documentation + >>> # for details. + >>> # NOTE: the `untuple` is almost always necessary due to the webdataset's + >>> # file parsing rule. + >>> + >>> untuple = lambda source : (sample for (sample,) in source) + >>> + >>> from webdatast import shuffle + >>> pipeline_wds = { + >>> Split.train : [untuple, shuffle(n_samples[Split.train], + >>> rng=random.Random(seed_rng_shfl))], + >>> Split.val: untuple + >>> } + >>> + >>> # Similarly the user can optionally define the processing routine on the + >>> # WebLoader (the dataloader of webdataset). + >>> # NOTE: these routines by default take unbatched sample as input so the + >>> # user can customize their batching routines here + >>> + >>> batch = batched(local_batch_size, collation_fn=lambda + list_samples : torch.vstack(list_samples)) + >>> pipeline_prebatch_wld = { + Split.train: [shuffle(n_samples[Split.train], + rng=random.Random(seed_rng_shfl)), batch], + Split.val : batch, + Split.test : batch + } + >>> + >>> # the user can optionally specify the kwargs for WebDataset and + >>> # WebLoader + >>> + >>> kwargs_wds = { + >>> split : {'shardshuffle' : split == Split.train, + >>> 'nodesplitter' : wds.split_by_node, + >>> 'seed' : seed_rng_shfl} + >>> for split in Split + >>> } + >>> + >>> kwargs_wld = { + >>> split : {"num_workers": 2} for split in Split + >>> } + >>> + >>> # construct the data module + >>> data_module = WebDataModule(dirs_of_tar_files, n_samples, suffix_keys_wds, + global_batch_size, + prefix_tars_wds=tar_file_prefix, + pipeline_wds=pipeline_wds, + pipeline_prebatch_wld=pipeline_prebatch_wld, + kwargs_wds=kwargs_wds, + kwargs_wld=kwargs_wld) + ``` + + """ def __init__( self, @@ -233,22 +332,88 @@ def predict_dataloader(self) -> wds.WebLoader: class PickledDataWDS(WebDataModule): - """lightning APIs to process pickled data into webdataset tar files and - setup dataset and dataloader. This data module takes a directory of pickled - data files, data filename prefixes for train/val/test splits, data filename - suffixes and prepare webdataset tar files by globbing the specific pickeld - data files {dir_pickled}/{name_subset[split]}.{suffix_pickled} and outputing - to webdataset tar file with the dict structure: {"__key__" : - name.replace(".", "-"), suffix_pickled : pickled.dumps(data) }. NOTE: this - assumes only one pickled file is processed for each sample. In its setup() - function, it creates the webdataset object chaining up the input + """A LightningDataModule to process pickled data into webdataset tar files + and setup dataset and dataloader. This inherits the webdataset setup from + its parent module `WebDataModule`. This data module takes a directory of + pickled data files, data filename prefixes for train/val/test splits, data + filename suffixes and prepare webdataset tar files by globbing the specific + pickle data files `{dir_pickles}/{name_subset[split]}.{suffix_pickles}` and + outputing to webdataset tar file with the dict structure: + ``` + {"__key__" : name.replace(".", "-"), + suffix_pickles : pickled.dumps(data) } + ``` + NOTE: this assumes only one pickled file is processed for each sample. In + its setup() function, it creates the webdataset object chaining up the input `pipeline_wds` workflow. In its train/val/test_dataloader(), it creates the - WebLoader object chaining up the `pipeline_prebatch_wld` workflow""" + WebLoader object chaining up the `pipeline_prebatch_wld` workflow. + + Examples + -------- + + 1. create the data module with a directory of pickle files and the file name + prefix thereof for different splits to used by `Lightning.Trainer.fit()` + + ``` + >>> from bionemo.core.data.datamodule import Split, PickledDataWDS + + >>> dir_pickles = "/path/to/my/pickles/dir" + + >>> # the following will use `sample1.mydata.pt` and `sample2.mydata.pt` as the + >>> # training dataset and `sample4.mydata.pt` and `sample5.mydata.pt` as the + >>> # validation dataset + + >>> suffix_pickles = "mydata.pt" + + >>> names_subset = { + >>> Split.train: [sample1, sample2], + >>> Split.val: [sample4, sample5], + >>> } + + >>> # the following setting will attempt to create at least 5 tar files in + >>> # `/path/to/output/tars/dir/myshards-00000{0-5}.tar` + + >>> n_tars_wds = 5 + >>> prefix_tars_wds = "myshards" + >>> output_dir_tar_files = "/path/to/output/tars/dir" + + >>> # see the `WebDataModule` API doc for the definition of global_batch_size + >>> global_batch_size = 16 + + >>> # user can optionally customize the data processing routines and kwargs used + >>> # in the WebDataset and WebLoader (see the examples in `WebDataModule`) + + >>> pipeline_wds = { Split.train: ... } + + >>> pipeline_prebatch_wld = { Split.train: ... } + + >>> kwargs_wds = { Split.train: ..., Split.val: ... } + + >>> kwargs_wld = { Split.train: ..., Split.val: ... } + + >>> # create the data module + >>> data_module = PickledDataWDS( + >>> dir_pickles, + >>> suffix_pickles, + >>> names_subset, + >>> output_dir_tar_files, + >>> global_batch_size, # `WebDataModule` args + >>> n_tars_wds=n_tars_wds, + >>> prefix_tars_wds=prefix_tars_wds, # `WebDataModule` kwargs + >>> pipeline_wds=pipeline_wds, # `WebDataModule` kwargs + >>> pipeline_prebatch_wld=pipelines_wdl_batch, # `WebDataModule` kwargs + >>> kwargs_wds=kwargs_wds, # `WebDataModule` kwargs + >>> kwargs_wld=kwargs_wld, # `WebDataModule` kwargs + >>> ) + + ``` + + """ def __init__( self, - dir_pickled: str, - suffix_pickled: str, + dir_pickles: str, + suffix_pickles: str, names_subset: Dict[Split, List[str]], prefix_dir_tars_wds: str, *args, @@ -258,12 +423,13 @@ def __init__( """constructor Args: - dir_pickled (str): input directory of pickled data files - suffix_pickled (str): filename suffix of the input data in - dir_pickled. This is also used as the key mapped to the + dir_pickles (str): input directory of pickled data files + suffix_pickles (str): filename suffix of the input data in + dir_pickles. This is also used as the key mapped to the tarballed pickled object in the webdataset - names_subset (Dict[Split, List[str]]): list of complex names to be - included in each of the split + names_subset (Dict[Split, List[str]]): list of filename prefix of + the data samples to be loaded in the dataset and dataloader for + each of the split prefix_dir_tars_wds (str): directory name prefix to store the output webdataset tar files. The actual directories storing the train, val and test sets will be suffixed with "train", "val" and "test" @@ -280,13 +446,13 @@ def __init__( super().__init__( {split: f"{prefix_dir_tars_wds}{str(split).split('.')[-1]}" for split in names_subset.keys()}, {split: len(names_subset[split]) for split in names_subset.keys()}, - suffix_pickled, + suffix_pickles, *args, **kwargs, ) - self._dir_pickled = dir_pickled - self._suffix_pickled = suffix_pickled + self._dir_pickles = dir_pickles + self._suffix_pickles = suffix_pickles self._prefix_dir_tars_wds = prefix_dir_tars_wds self._names_subset = names_subset @@ -296,19 +462,19 @@ def __init__( def prepare_data(self) -> None: """This is called only by the main process by the Lightning workflow. Do not rely on this data module object's state update here as there is no - way to communicate the state update to other subprocesses. The + way to communicate the state update to other subprocesses. The nesting `pickles_to_tars` function goes through the data name prefixes in the different splits, read the corresponding pickled file and output a webdataset tar archive with the dict structure: {"__key__" : - name.replace(".", "-"), suffix_pickled : pickled.dumps(data) }. + name.replace(".", "-"), suffix_pickles : pickled.dumps(data) }. Returns: None """ for split in self._names_subset.keys(): # create wds shards (tar files) for train set pickles_to_tars( - self._dir_pickled, - self._suffix_pickled, + self._dir_pickles, + self._suffix_pickles, self._names_subset[split], self._dirs_tars_wds[split], self._prefix_tars_wds, From 19c3dcd8377aabc3a2acdaf617e66bad9fa81c7f Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Wed, 14 Aug 2024 23:29:39 +0000 Subject: [PATCH 44/70] Test: PickledDataWDS init and setup() --- .../tests/bionemo/core/data/conftest.py | 80 ++++++++++++++++++- .../bionemo/core/data/test_webdatamodule.py | 39 +++++++++ 2 files changed, 115 insertions(+), 4 deletions(-) diff --git a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py b/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py index 50ce1e641f..f92dfbad71 100644 --- a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py +++ b/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py @@ -22,14 +22,14 @@ import webdataset as wds from webdataset.filters import batched, shuffle -from bionemo.core.data.datamodule import WebDataModule, Split +from bionemo.core.data.datamodule import WebDataModule, Split, PickledDataWDS @pytest.fixture(scope="module") def get_path(request): path_test = Path(request.module.__file__).resolve() dir_test = path_test.parents[0] - dir_data = path_test.parents[6] / "test_data" / "bionemo-core" / "data" / "webdatamodule" + dir_data = path_test.parents[6] / "test_data" / "bionemo-core" / "data" return str(dir_test), str(dir_data) @@ -87,13 +87,15 @@ def _create_webdatamodule(dir_tars_wds): @pytest.fixture(scope="module") def create_webdatamodule(get_path): - _, dir_tars_wds = get_path + _, dir_data = get_path + dir_tars_wds = f"{dir_data}/webdatamodule" return _create_webdatamodule(dir_tars_wds) @pytest.fixture(scope="module") def create_another_webdatamodule(get_path): - _, dir_tars_wds = get_path + _, dir_data = get_path + dir_tars_wds = f"{dir_data}/webdatamodule" return _create_webdatamodule(dir_tars_wds) @@ -132,3 +134,73 @@ def create_trainer_and_model(): trainer = L.Trainer(max_epochs=1, accelerator="gpu", devices=1, val_check_interval=1) model = ModelTestWebDataModule() return trainer, model + + +def _create_pickleddatawds(tmp_path_factory, dir_pickles): + suffix_keys_wds = "tensor.pyd" + local_batch_size = 2 + global_batch_size = 2 + prefix_tars_wds = "tensor" + seed_rng_shfl = 82838392 + n_tars_wds = 3 + + prefix_dir_tars_wds = tmp_path_factory.mktemp("pickleddatawds_tars_wds").as_posix() + + names = { split : [ f"sample-{i:04d}" for i in + range(10) ] for split in Split + } + + n_samples = { split : 10 for split in Split } + + batch = batched(local_batch_size, collation_fn=lambda + list_samples : torch.vstack(list_samples)) + + untuple = lambda source : (sample for (sample,) in source) + + pipeline_wds = { + Split.train : [untuple, shuffle(n_samples[Split.train], + rng=random.Random(seed_rng_shfl))], + Split.val : untuple, + Split.test : untuple + } + + pipeline_prebatch_wld = { + Split.train: [shuffle(n_samples[Split.train], + rng=random.Random(seed_rng_shfl)), batch], + Split.val : batch, + Split.test : batch + } + + kwargs_wds = { + split : {'shardshuffle' : split == Split.train, + 'nodesplitter' : wds.split_by_node, + 'seed' : seed_rng_shfl} + for split in Split + } + + kwargs_wld = { + split : {"num_workers": 2} for split in Split + } + + data_module = PickledDataWDS(dir_pickles, suffix_keys_wds, names, + prefix_dir_tars_wds, global_batch_size, + n_tars_wds=n_tars_wds, + prefix_tars_wds=prefix_tars_wds, + pipeline_wds=pipeline_wds, + pipeline_prebatch_wld=pipeline_prebatch_wld, + kwargs_wds=kwargs_wds, kwargs_wld=kwargs_wld) + + return data_module, prefix_dir_tars_wds + +@pytest.fixture(scope="module") +def create_pickleddatawds(tmp_path_factory, get_path): + _, dir_data = get_path + dir_pickles = f"{dir_data}/pickleddatawds" + return _create_pickleddatawds(tmp_path_factory, dir_pickles) + + +@pytest.fixture(scope="module") +def create_another_pickleddatawds(tmp_path_factory, get_path): + _, dir_data = get_path + dir_pickles = f"{dir_data}/pickleddatawds" + return _create_pickleddatawds(tmp_path_factory, dir_pickles) diff --git a/sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py b/sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py index 8e20ce2f48..4beeacc9c0 100644 --- a/sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py +++ b/sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py @@ -134,3 +134,42 @@ def test_webdatamodule_in_lightning(stage, create_webdatamodule, workflow = getattr(trainer, name_stage) workflow(model, data_modules[1]) assert model._samples[split] == samples + + +@pytest.mark.parametrize("split", list(Split)) +def test_pickleddatawds_init(split, create_pickleddatawds): + data_module, prefix_dir_tars_wds = create_pickleddatawds + assert data_module._n_samples[split] == 10, ( + f"Wrong {split}-set size: " + f"expected 10 " + f"but got {data_module._n_samples[split]}" + ) + name_split = str(split).split(".")[-1] + assert data_module._dirs_tars_wds[split] == f"{prefix_dir_tars_wds}{name_split}", ( + f"Wrong tar files directory: " + f"expected {prefix_dir_tars_wds}{name_split} " + f"but got {data_module._dirs_tars_wds[split]}" + ) + +@pytest.mark.parametrize("split", list(Split)) +def test_pickleddatawds_setup_dataset(split, create_pickleddatawds, + create_another_pickleddatawds): + data_modules = [create_pickleddatawds[0], create_another_pickleddatawds[0]] + lists_tensors = [] + for m in data_modules: + m.prepare_data() + # run through all the possible stages first to setup all the correps. + # dataset objects + m.setup("fit") + m.setup("test") + L.seed_everything(2823828) + tensors= [] + for sample in m._dataset[split]: + assert isinstance(sample, torch.Tensor),\ + "Sample yield from dataset is not tensor" + tensors.append(sample) + lists_tensors.append(tensors) + + assert len(lists_tensors[0]) > 0, "No names in {split} dataset" + torch.testing.assert_close(torch.vstack(lists_tensors[0]), + torch.vstack(lists_tensors[1])) From 2e6e8db768969abfb02a43558b4fedcfdf812fcb Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Thu, 15 Aug 2024 18:21:00 +0000 Subject: [PATCH 45/70] Doc: update class description --- sub-packages/bionemo-core/README.md | 8 ++++---- .../bionemo-core/src/bionemo/core/data/datamodule.py | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sub-packages/bionemo-core/README.md b/sub-packages/bionemo-core/README.md index 1157f39e9b..110d648032 100644 --- a/sub-packages/bionemo-core/README.md +++ b/sub-packages/bionemo-core/README.md @@ -11,10 +11,10 @@ class WebDataModule(L.LightningDataModule) ``` A LightningDataModule for using webdataset tar files to setup dataset and -dataloader. This data module takes a dictionary: Split -> tar file -directory. In its setup() function, it creates the webdataset object -chaining up the input `pipeline_wds` workflow. In its -train/val/test_dataloader(), it creates the WebLoader object chaining up the +dataloader. This data module takes as input a dictionary: Split -> tar file +directory and vaiours webdataset config settings. In its setup() function, it +creates the webdataset object chaining up the input `pipeline_wds` workflow. In +its train/val/test_dataloader(), it creates the WebLoader object chaining up the `pipeline_prebatch_wld` workflow Examples diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py b/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py index b8b82fafef..5c1b6a4f6f 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py +++ b/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py @@ -32,11 +32,11 @@ class Split(Enum): class WebDataModule(L.LightningDataModule): """A LightningDataModule for using webdataset tar files to setup dataset and - dataloader. This data module takes a dictionary: Split -> tar file - directory. In its setup() function, it creates the webdataset object - chaining up the input `pipeline_wds` workflow. In its - train/val/test_dataloader(), it creates the WebLoader object chaining up the - `pipeline_prebatch_wld` workflow + dataloader. This data module takes as input a dictionary: Split -> tar file + directory and vaiours webdataset config settings. In its setup() function, + it creates the webdataset object chaining up the input `pipeline_wds` + workflow. In its train/val/test_dataloader(), it creates the WebLoader + object chaining up the `pipeline_prebatch_wld` workflow Examples -------- From 9693aba1be59ba4705ea52f09636cfd8bafe7457 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Thu, 15 Aug 2024 18:44:32 +0000 Subject: [PATCH 46/70] Test: re-org test data location --- sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py | 2 +- .../bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py b/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py index f92dfbad71..484c9803e1 100644 --- a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py +++ b/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py @@ -29,7 +29,7 @@ def get_path(request): path_test = Path(request.module.__file__).resolve() dir_test = path_test.parents[0] - dir_data = path_test.parents[6] / "test_data" / "bionemo-core" / "data" + dir_data = path_test.parents[6] / "test_data" / "datamodule" return str(dir_test), str(dir_data) diff --git a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py index 5e2fcb2cf0..6afc7aecab 100644 --- a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py +++ b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py @@ -35,7 +35,7 @@ def get_path(request): path_test = Path(request.module.__file__).resolve() dir_test = path_test.parents[0] dir_data = path_test.parents[6] / "test_data" / \ - "bionemo-diffdock" / "data" / "pyg_heterodata_pickled" + "diffdock" / "pyg_heterodata_pickled" return str(dir_test), str(dir_data) From 71391a8e565d150546a24ee80ab06a68ec12c223 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Thu, 15 Aug 2024 20:20:24 +0000 Subject: [PATCH 47/70] Test: generate the test data for webdatamodule ... instead of using downloaded ones because the data size is too small to be worth maintaining --- .../tests/bionemo/core/data/conftest.py | 78 +++++++++++-------- 1 file changed, 44 insertions(+), 34 deletions(-) diff --git a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py b/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py index 484c9803e1..275285969e 100644 --- a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py +++ b/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py @@ -13,8 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import pickle import random -from pathlib import Path import lightning as L import pytest @@ -23,26 +24,42 @@ from webdataset.filters import batched, shuffle from bionemo.core.data.datamodule import WebDataModule, Split, PickledDataWDS +from bionemo.core.data.utils import pickles_to_tars @pytest.fixture(scope="module") -def get_path(request): - path_test = Path(request.module.__file__).resolve() - dir_test = path_test.parents[0] - dir_data = path_test.parents[6] / "test_data" / "datamodule" - return str(dir_test), str(dir_data) - - -def _create_webdatamodule(dir_tars_wds): - suffix_keys_wds = "tensor.pyd" +def gen_test_data(tmp_path_factory): + dir_pickles = tmp_path_factory.mktemp("pickleddatawds").as_posix() + dir_tars = tmp_path_factory.mktemp("webdatamodule").as_posix() + prefix_sample = "sample" + suffix_sample = "tensor.pyd" + prefix_tar = "tensor" + n_samples = 10 + os.makedirs(dir_pickles, exist_ok=True) + prefix_subset = [] + # generate the pickles + for i in range(n_samples): + prefix = f"{prefix_sample}-{i:04}" + prefix_subset.append(prefix) + t = torch.tensor(i, dtype=torch.int32) + pickle.dump(t, open(f"{dir_pickles}/{prefix}.{suffix_sample}", "wb")) + # generate the tars + pickles_to_tars(dir_pickles, suffix_sample, prefix_subset, dir_tars, + prefix_tar, min_num_shards=3) + return (dir_pickles, dir_tars, prefix_sample, suffix_sample, prefix_tar, + n_samples) + + +def _create_webdatamodule(gen_test_data): + (_, dir_tars_wds, _, suffix_keys_wds, prefix_tars_wds, + n_samples_in_tar) = gen_test_data local_batch_size = 2 global_batch_size = 2 - prefix_tars_wds = "tensor" seed_rng_shfl = 82838392 dirs_tars_wds = { split : dir_tars_wds for split in Split } - n_samples = { split : 10 for split in Split } + n_samples = { split : n_samples_in_tar for split in Split } batch = batched(local_batch_size, collation_fn=lambda list_samples : torch.vstack(list_samples)) @@ -86,17 +103,13 @@ def _create_webdatamodule(dir_tars_wds): @pytest.fixture(scope="module") -def create_webdatamodule(get_path): - _, dir_data = get_path - dir_tars_wds = f"{dir_data}/webdatamodule" - return _create_webdatamodule(dir_tars_wds) +def create_webdatamodule(gen_test_data): + return _create_webdatamodule(gen_test_data) @pytest.fixture(scope="module") -def create_another_webdatamodule(get_path): - _, dir_data = get_path - dir_tars_wds = f"{dir_data}/webdatamodule" - return _create_webdatamodule(dir_tars_wds) +def create_another_webdatamodule(gen_test_data): + return _create_webdatamodule(gen_test_data) class ModelTestWebDataModule(L.LightningModule): @@ -136,21 +149,21 @@ def create_trainer_and_model(): return trainer, model -def _create_pickleddatawds(tmp_path_factory, dir_pickles): - suffix_keys_wds = "tensor.pyd" +def _create_pickleddatawds(tmp_path_factory, gen_test_data): + (dir_pickles, _, prefix_sample, suffix_keys_wds, prefix_tars_wds, + n_samples_in_tar) = gen_test_data local_batch_size = 2 global_batch_size = 2 - prefix_tars_wds = "tensor" seed_rng_shfl = 82838392 n_tars_wds = 3 prefix_dir_tars_wds = tmp_path_factory.mktemp("pickleddatawds_tars_wds").as_posix() - names = { split : [ f"sample-{i:04d}" for i in - range(10) ] for split in Split + names = { split : [ f"{prefix_sample}-{i:04d}" for i in + range(n_samples_in_tar) ] for split in Split } - n_samples = { split : 10 for split in Split } + n_samples = { split : n_samples_in_tar for split in Split } batch = batched(local_batch_size, collation_fn=lambda list_samples : torch.vstack(list_samples)) @@ -192,15 +205,12 @@ def _create_pickleddatawds(tmp_path_factory, dir_pickles): return data_module, prefix_dir_tars_wds + @pytest.fixture(scope="module") -def create_pickleddatawds(tmp_path_factory, get_path): - _, dir_data = get_path - dir_pickles = f"{dir_data}/pickleddatawds" - return _create_pickleddatawds(tmp_path_factory, dir_pickles) +def create_pickleddatawds(tmp_path_factory, gen_test_data): + return _create_pickleddatawds(tmp_path_factory, gen_test_data) @pytest.fixture(scope="module") -def create_another_pickleddatawds(tmp_path_factory, get_path): - _, dir_data = get_path - dir_pickles = f"{dir_data}/pickleddatawds" - return _create_pickleddatawds(tmp_path_factory, dir_pickles) +def create_another_pickleddatawds(tmp_path_factory, gen_test_data): + return _create_pickleddatawds(tmp_path_factory, gen_test_data) From eff50934b3952bae47fd847e536380d82857c6e4 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Thu, 15 Aug 2024 23:03:09 +0000 Subject: [PATCH 48/70] Enhancement: remove unused functions --- .../src/bionemo/diffdock/utils/diffusion.py | 74 ------------------- .../src/bionemo/diffdock/utils/torsion.py | 62 ---------------- 2 files changed, 136 deletions(-) diff --git a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/diffusion.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/diffusion.py index 4c3e361ec1..cb91a67fd7 100644 --- a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/diffusion.py +++ b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/diffusion.py @@ -31,8 +31,6 @@ import numpy as np import torch -import torch.nn.functional as F -from torch import nn from torch_geometric.data.hetero_data import HeteroData from bionemo.diffdock.utils import so3, torus @@ -73,78 +71,6 @@ def modify_conformer(data, tr_update, rot_update, torsion_updates): return data -def sinusoidal_embedding(timesteps, embedding_dim, max_positions=10000): - """from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py""" - assert len(timesteps.shape) == 1 - half_dim = embedding_dim // 2 - emb = math.log(max_positions) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = F.pad(emb, (0, 1), mode="constant") - assert emb.shape == (timesteps.shape[0], embedding_dim) - return emb - - -class GaussianFourierProjection(nn.Module): - """Gaussian Fourier embeddings for noise levels. - from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/models/layerspp.py#L32 - """ - - def __init__(self, embedding_size=256, scale=1.0): - super().__init__() - self.W = nn.Parameter(torch.randn(embedding_size // 2) * scale, requires_grad=False) - - def forward(self, x): - x_proj = x[:, None] * self.W[None, :] * 2 * np.pi - emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) - return emb - - -def get_timestep_embedding(embedding_type, embedding_dim, embedding_scale=10000): - if embedding_type == "sinusoidal": - - def emb_func(x): - return sinusoidal_embedding(embedding_scale * x, embedding_dim) - - elif embedding_type == "fourier": - emb_func = GaussianFourierProjection(embedding_size=embedding_dim, scale=embedding_scale) - else: - raise NotImplementedError - return emb_func - - -class timestep_embedding(nn.Module): - def __init__(self, embedding_type, embedding_dim, embedding_scale=10000): - super(timestep_embedding, self).__init__() - self.embedding_type = embedding_type - self.embedding_dim = embedding_dim - self.embedding_scale = embedding_scale - self.emb_func = get_timestep_embedding(embedding_type, embedding_dim, embedding_scale) - - def forward(self, *args, **kwargs): - return self.emb_func(*args, **kwargs) - - def __getstate__(self): - return { - "embedding_type": self.embedding_type, - "embedding_dim": self.embedding_dim, - "embedding_scale": self.embedding_scale, - } - - def __setstate__(self, d): - super(timestep_embedding, self).__init__() - self.embedding_type = d["embedding_type"] - self.embedding_dim = d["embedding_dim"] - self.embedding_scale = d["embedding_scale"] - self.emb_func = get_timestep_embedding(**d) - - -def get_t_schedule(denoising_inference_steps): - return np.linspace(1, 0, denoising_inference_steps + 1)[:-1] - - def set_time(complex_graphs, t_tr, t_rot, t_tor, batchsize, all_atoms, device): complex_graphs["ligand"].node_t = { "tr": t_tr * torch.ones(complex_graphs["ligand"].num_nodes).to(device), diff --git a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torsion.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torsion.py index 003e028241..a927d938e1 100644 --- a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torsion.py +++ b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torsion.py @@ -26,12 +26,9 @@ import copy -import networkx as nx import numpy as np import torch from scipy.spatial.transform import Rotation as R -from torch_geometric.data import Data -from torch_geometric.utils import to_networkx """ @@ -39,39 +36,6 @@ """ -def get_transformation_mask(pyg_data): - G = to_networkx(pyg_data.to_homogeneous(), to_undirected=False) - to_rotate = [] - edges = pyg_data["ligand", "ligand"].edge_index.T.numpy() - for i in range(0, edges.shape[0], 2): - assert edges[i, 0] == edges[i + 1, 1] - - G2 = G.to_undirected() - G2.remove_edge(*edges[i]) - if not nx.is_connected(G2): - l = list(sorted(nx.connected_components(G2), key=len)[0]) - if len(l) > 1: - if edges[i, 0] in l: - to_rotate.append([]) - to_rotate.append(l) - else: - to_rotate.append(l) - to_rotate.append([]) - continue - to_rotate.append([]) - to_rotate.append([]) - - mask_edges = np.asarray([0 if len(l) == 0 else 1 for l in to_rotate], dtype=bool) - mask_rotate = np.zeros((np.sum(mask_edges), len(G.nodes())), dtype=bool) - idx = 0 - for i in range(len(G.edges())): - if mask_edges[i]: - mask_rotate[idx][np.asarray(to_rotate[i], dtype=int)] = True - idx += 1 - - return mask_edges, mask_rotate - - def modify_conformer_torsion_angles(pos, edge_index, mask_rotate, torsion_updates, as_numpy=False): pos = copy.deepcopy(pos) if type(pos) != np.ndarray: @@ -97,29 +61,3 @@ def modify_conformer_torsion_angles(pos, edge_index, mask_rotate, torsion_update return pos -def perturb_batch(data, torsion_updates, split=False, return_updates=False): - if type(data) is Data: - return modify_conformer_torsion_angles( - data.pos, data.edge_index.T[data.edge_mask], data.mask_rotate, torsion_updates - ) - pos_new = [] if split else copy.deepcopy(data.pos) - edges_of_interest = data.edge_index.T[data.edge_mask] - idx_node = 0 - idx_edges = 0 - torsion_update_list = [] - for i, mask_rotate in enumerate(data.mask_rotate): - pos = data.pos[idx_node : idx_node + mask_rotate.shape[1]] - edges = edges_of_interest[idx_edges : idx_edges + mask_rotate.shape[0]] - idx_node - torsion_update = torsion_updates[idx_edges : idx_edges + mask_rotate.shape[0]] - torsion_update_list.append(torsion_update) - pos_new_ = modify_conformer_torsion_angles(pos, edges, mask_rotate, torsion_update) - if split: - pos_new.append(pos_new_) - else: - pos_new[idx_node : idx_node + mask_rotate.shape[1]] = pos_new_ - - idx_node += mask_rotate.shape[1] - idx_edges += mask_rotate.shape[0] - if return_updates: - return pos_new, torsion_update_list - return pos_new From c8c396bc36e79abdfd9978e84ac9f07fd1e85cc1 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Thu, 15 Aug 2024 23:07:30 +0000 Subject: [PATCH 49/70] BugFix: remove license --- sub-packages/bionemo-diffdock/LICENSE | 202 -------------------------- 1 file changed, 202 deletions(-) delete mode 100644 sub-packages/bionemo-diffdock/LICENSE diff --git a/sub-packages/bionemo-diffdock/LICENSE b/sub-packages/bionemo-diffdock/LICENSE deleted file mode 100644 index d645695673..0000000000 --- a/sub-packages/bionemo-diffdock/LICENSE +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. From 458611dbf32c2c96d51bed40fc92e10c84a62f25 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Thu, 15 Aug 2024 23:43:31 +0000 Subject: [PATCH 50/70] Enhancement: use list comprehension --- .../tests/bionemo/core/data/test_webdatamodule.py | 4 +--- .../tests/bionemo/diffdock/data/test_diffdock_datamodule.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py b/sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py index 4beeacc9c0..d0ed0b260e 100644 --- a/sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py +++ b/sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py @@ -127,9 +127,7 @@ def test_webdatamodule_in_lightning(stage, create_webdatamodule, # get the list of samples from the workflow get_dataloader = getattr(data_modules[0], f"{str(split).split('.')[-1]}_dataloader") loader = get_dataloader() - samples = [] - for sample in loader: - samples.append(sample.name) + samples = [ sample.name for sample in loader ] L.seed_everything(2823828) workflow = getattr(trainer, name_stage) workflow(model, data_modules[1]) diff --git a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/test_diffdock_datamodule.py b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/test_diffdock_datamodule.py index 4c52943634..7f1bfb94cd 100644 --- a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/test_diffdock_datamodule.py +++ b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/test_diffdock_datamodule.py @@ -186,9 +186,7 @@ def test_datamodule_in_lightning(stage, create_datamodule, create_another_datamo # get the list of samples from the workflow get_dataloader = getattr(data_modules[0], f"{str(split).split('.')[-1]}_dataloader") loader = get_dataloader() - samples = [] - for sample in loader: - samples.append(sample.name) + samples = [ sample.name for sample in loader ] lightning.seed_everything(2823828) workflow = getattr(trainer, name_stage) workflow(model, data_modules[1]) From e6f9c3fa329bb339653a5f4ba21904e9e06f48f9 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Fri, 16 Aug 2024 00:09:14 +0000 Subject: [PATCH 51/70] License: clean up old licenses --- .../src/bionemo/core/data/__init__.py | 20 ++++++++----------- .../bionemo-core/src/bionemo/core/data/api.py | 20 ++++++++----------- .../src/bionemo/core/data/datamodule.py | 20 ++++++++----------- .../src/bionemo/core/data/utils.py | 20 ++++++++----------- .../tests/bionemo/core/data/conftest.py | 20 ++++++++----------- .../bionemo/core/data/test_webdatamodule.py | 20 ++++++++----------- .../src/bionemo/diffdock/utils/data.py | 20 ++++++++----------- .../src/bionemo/diffdock/utils/diffusion.py | 17 ---------------- .../src/bionemo/diffdock/utils/geometry.py | 16 --------------- .../src/bionemo/diffdock/utils/so3.py | 16 --------------- .../src/bionemo/diffdock/utils/torsion.py | 16 --------------- .../src/bionemo/diffdock/utils/torus.py | 16 --------------- .../tests/bionemo/diffdock/data/conftest.py | 20 ++++++++----------- .../diffdock/data/test_diffdock_datamodule.py | 20 ++++++++----------- 14 files changed, 72 insertions(+), 189 deletions(-) diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/__init__.py b/sub-packages/bionemo-core/src/bionemo/core/data/__init__.py index 25e6abfbc5..f7d0fc1f13 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/__init__.py +++ b/sub-packages/bionemo-core/src/bionemo/core/data/__init__.py @@ -1,14 +1,10 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/api.py b/sub-packages/bionemo-core/src/bionemo/core/data/api.py index 6dae5df001..38ea91be64 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/api.py +++ b/sub-packages/bionemo-core/src/bionemo/core/data/api.py @@ -1,17 +1,13 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + from typing import Sequence diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py b/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py index 5c1b6a4f6f..26a703eba1 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py +++ b/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py @@ -1,17 +1,13 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + import glob import random diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/utils.py b/sub-packages/bionemo-core/src/bionemo/core/data/utils.py index ab4ed8d3cb..f8898c3e02 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/utils.py +++ b/sub-packages/bionemo-core/src/bionemo/core/data/utils.py @@ -1,17 +1,13 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + import os import pickle diff --git a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py b/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py index 275285969e..2e2f8c0896 100644 --- a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py +++ b/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py @@ -1,17 +1,13 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + import os import pickle diff --git a/sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py b/sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py index d0ed0b260e..8285771f3f 100644 --- a/sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py +++ b/sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py @@ -1,17 +1,13 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + from enum import Enum, auto diff --git a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/data.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/data.py index 13116253d4..da904e13f9 100644 --- a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/data.py +++ b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/data.py @@ -1,17 +1,13 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + import math import random from typing import Any, Callable, Generator, Iterable, List, Union diff --git a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/diffusion.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/diffusion.py index cb91a67fd7..31ca5228ca 100644 --- a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/diffusion.py +++ b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/diffusion.py @@ -1,19 +1,3 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary # @@ -24,7 +8,6 @@ # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. -import math import random from copy import deepcopy from typing import Callable, Generator, Tuple diff --git a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/geometry.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/geometry.py index 547b17c1c7..648045d261 100644 --- a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/geometry.py +++ b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/geometry.py @@ -1,19 +1,3 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary # diff --git a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/so3.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/so3.py index 651828ba81..7fa0da0a3e 100644 --- a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/so3.py +++ b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/so3.py @@ -1,19 +1,3 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary # diff --git a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torsion.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torsion.py index a927d938e1..3134d26a7e 100644 --- a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torsion.py +++ b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torsion.py @@ -1,19 +1,3 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary # diff --git a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torus.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torus.py index 9a7966bbd4..977cff937d 100644 --- a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torus.py +++ b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torus.py @@ -1,19 +1,3 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary # diff --git a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py index 6afc7aecab..8dc9ab600f 100644 --- a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py +++ b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py @@ -1,17 +1,13 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + from pathlib import Path from enum import Enum, auto diff --git a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/test_diffdock_datamodule.py b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/test_diffdock_datamodule.py index 7f1bfb94cd..95ca4a88fe 100644 --- a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/test_diffdock_datamodule.py +++ b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/test_diffdock_datamodule.py @@ -1,17 +1,13 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + import glob from enum import Enum, auto From c470b21ad6d8b867f849f7775c5eb428a8ebcca6 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Fri, 16 Aug 2024 00:10:40 +0000 Subject: [PATCH 52/70] BugFix: add __init__.py --- .../bionemo-diffdock/src/bionemo/diffdock/__init__.py | 10 ++++++++++ .../src/bionemo/diffdock/utils/__init__.py | 10 ++++++++++ 2 files changed, 20 insertions(+) create mode 100644 sub-packages/bionemo-diffdock/src/bionemo/diffdock/__init__.py create mode 100644 sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/__init__.py diff --git a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/__init__.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/__init__.py new file mode 100644 index 0000000000..f7d0fc1f13 --- /dev/null +++ b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/__init__.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + diff --git a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/__init__.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/__init__.py new file mode 100644 index 0000000000..f7d0fc1f13 --- /dev/null +++ b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/__init__.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + From e1e55518c638bb962e0cc0dfd40c71a3cfac16e9 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Mon, 19 Aug 2024 22:09:54 +0000 Subject: [PATCH 53/70] BugFix: my email --- sub-packages/bionemo-diffdock/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sub-packages/bionemo-diffdock/pyproject.toml b/sub-packages/bionemo-diffdock/pyproject.toml index 29057f478c..cd4c7a70d9 100644 --- a/sub-packages/bionemo-diffdock/pyproject.toml +++ b/sub-packages/bionemo-diffdock/pyproject.toml @@ -11,7 +11,7 @@ description = "BioNeMo DiffDock" authors = [ { name = "John St. John", email = "jstjohn@nvidia.com" }, { name = "Malcolm Greaves", email = "mgreaves@nvidia.com" }, - { name = "Dejun Lin", email = "dejun.lin@gmail.com" }, + { name = "Dejun Lin", email = "dejunl@nvidia.com" }, ] dynamic = ["dependencies"] From 0f46ed8ef4a8865722d4264cd9a0959014e1d78d Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Mon, 19 Aug 2024 23:54:46 +0000 Subject: [PATCH 54/70] Package: move webdatamodule under bionemo-webdatamodule --- sub-packages/bionemo-core/README.md | 342 ----------------- .../src/bionemo/core/data/__init__.py | 20 +- .../bionemo-core/src/bionemo/core/data/api.py | 20 +- sub-packages/bionemo-webdatamodule/LICENSE | 202 ++++++++++ sub-packages/bionemo-webdatamodule/README.md | 354 ++++++++++++++++++ .../bionemo-webdatamodule/pyproject.toml | 29 ++ .../bionemo-webdatamodule/requirements.txt | 1 + .../src/bionemo/webdatamodule/__init__.py | 14 + .../src/bionemo/webdatamodule}/datamodule.py | 2 +- .../src/bionemo/webdatamodule}/utils.py | 0 .../tests/bionemo/webdatamodule/__init__.py | 0 .../tests/bionemo/webdatamodule}/conftest.py | 4 +- .../webdatamodule}/test_webdatamodule.py | 2 +- 13 files changed, 628 insertions(+), 362 deletions(-) create mode 100644 sub-packages/bionemo-webdatamodule/LICENSE create mode 100644 sub-packages/bionemo-webdatamodule/README.md create mode 100644 sub-packages/bionemo-webdatamodule/pyproject.toml create mode 100644 sub-packages/bionemo-webdatamodule/requirements.txt create mode 100644 sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/__init__.py rename sub-packages/{bionemo-core/src/bionemo/core/data => bionemo-webdatamodule/src/bionemo/webdatamodule}/datamodule.py (99%) rename sub-packages/{bionemo-core/src/bionemo/core/data => bionemo-webdatamodule/src/bionemo/webdatamodule}/utils.py (100%) create mode 100644 sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/__init__.py rename sub-packages/{bionemo-core/tests/bionemo/core/data => bionemo-webdatamodule/tests/bionemo/webdatamodule}/conftest.py (98%) rename sub-packages/{bionemo-core/tests/bionemo/core/data => bionemo-webdatamodule/tests/bionemo/webdatamodule}/test_webdatamodule.py (99%) diff --git a/sub-packages/bionemo-core/README.md b/sub-packages/bionemo-core/README.md index 110d648032..ad93dbee9e 100644 --- a/sub-packages/bionemo-core/README.md +++ b/sub-packages/bionemo-core/README.md @@ -4,345 +4,3 @@ ```bash pip install -e . ``` -## WebDataModule - -```python -class WebDataModule(L.LightningDataModule) -``` - -A LightningDataModule for using webdataset tar files to setup dataset and -dataloader. This data module takes as input a dictionary: Split -> tar file -directory and vaiours webdataset config settings. In its setup() function, it -creates the webdataset object chaining up the input `pipeline_wds` workflow. In -its train/val/test_dataloader(), it creates the WebLoader object chaining up the -`pipeline_prebatch_wld` workflow - -Examples --------- - -1. create the data module with input directory to webdataset tar files. -Depending on which of the downstream Lightning.Trainer methods are called, -e.g., `Trainer.fit()`, `Trainer.validate()`, `Trainer.test()` or -`Trainer.predict()`, only a subset of the train, val and test splits need to -be specified in the various input options to the data module: - -- `Trainer.fit()` requires the `train` and `val` splits -- `Trainer.validate()` requires the `val` split -- `Trainer.test()` requires the `test` splits -- `Trainer.predict()` requires the `test` splits - -Here is an example of constructing the data module for `Trainer.fit()`: -``` ->>> from bionemo.core.data.datamodule import Split, WebDataModule ->>> ->>> tar_file_prefix = "shards" ->>> ->>> dirs_of_tar_files = { ->>> Split.train: "/path/to/train/split/tars", ->>> Split.val: "/path/to/val/split/tars", ->>> } ->>> ->>> n_samples { ->>> Split.train: 1000, ->>> Split.val: 100, ->>> } ->>> ->>> # this is the string to retrieve the corresponding data object from the ->>> # webdataset file (see ->>> # https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format ->>> # for details) ->>> suffix_keys_wds = "tensor.pyd" ->>> ->>> # see the API doc for the definition of global_batch_size ->>> global_batch_size = 16 ->>> ->>> seed = 27193781 ->>> ->>> # Specify the routines to process the samples in the WebDataset object. ->>> # The routine is a generator of an Iterable of generators that are chained ->>> # together by nested function calling. The following is equivalent of ->>> # defining a overall generator of `shuffle(untuple(...))` which ->>> # untuples the samples and shuffles them. See webdataset's Documentation ->>> # for details. ->>> # NOTE: the `untuple` is almost always necessary due to the webdataset's ->>> # file parsing rule. ->>> ->>> untuple = lambda source : (sample for (sample,) in source) ->>> ->>> from webdatast import shuffle ->>> pipeline_wds = { ->>> Split.train : [untuple, shuffle(n_samples[Split.train], ->>> rng=random.Random(seed_rng_shfl))], ->>> Split.val: untuple ->>> } ->>> ->>> # Similarly the user can optionally define the processing routine on the ->>> # WebLoader (the dataloader of webdataset). ->>> # NOTE: these routines by default take unbatched sample as input so the ->>> # user can customize their batching routines here ->>> ->>> batch = batched(local_batch_size, collation_fn=lambda - list_samples : torch.vstack(list_samples)) ->>> pipeline_prebatch_wld = { - Split.train: [shuffle(n_samples[Split.train], - rng=random.Random(seed_rng_shfl)), batch], - Split.val : batch, - Split.test : batch - } ->>> ->>> # the user can optionally specify the kwargs for WebDataset and ->>> # WebLoader ->>> ->>> kwargs_wds = { ->>> split : {'shardshuffle' : split == Split.train, ->>> 'nodesplitter' : wds.split_by_node, ->>> 'seed' : seed_rng_shfl} ->>> for split in Split ->>> } ->>> ->>> kwargs_wld = { ->>> split : {"num_workers": 2} for split in Split ->>> } ->>> ->>> # construct the data module ->>> data_module = WebDataModule(dirs_of_tar_files, n_samples, suffix_keys_wds, - global_batch_size, - prefix_tars_wds=tar_file_prefix, - pipeline_wds=pipeline_wds, - pipeline_prebatch_wld=pipeline_prebatch_wld, - kwargs_wds=kwargs_wds, - kwargs_wld=kwargs_wld) -``` - - - -#### \_\_init\_\_ - -```python -def __init__( - dirs_tars_wds: Dict[Split, str], - n_samples: Dict[Split, int], - suffix_keys_wds: Union[str, Iterable[str]], - global_batch_size: int, - prefix_tars_wds: str = "wdshards", - pipeline_wds: Optional[Dict[Split, Union[Iterable[Iterable[Any]], - Iterable[Any]]]] = None, - pipeline_prebatch_wld: Optional[Dict[Split, - Union[Iterable[Iterable[Any]], - Iterable[Any]]]] = None, - kwargs_wds: Optional[Dict[Split, Dict[str, Any]]] = None, - kwargs_wld: Optional[Dict[Split, Dict[str, Any]]] = None) -``` - -constructor - -**Arguments**: - -- `dirs_tars_wds` _Dict[Split, str]_ - input dictionary: Split -> tar file - directory that contains the webdataset tar files for each split -- `n_samples` _Dict[Split, int]_ - input dictionary: Split -> number of - data samples for each split -- `suffix_keys_wds` _Union[str, Iterable[str]]_ - a set of keys each - corresponding to a data object in the webdataset tar file - dictionary. The data objects of these keys will be extracted and - tupled for each sample in the tar files -- `global_batch_size` _int_ - size of batch summing across nodes in Data - Distributed Parallel, i.e., local_batch_size * n_nodes. NOTE: - this data module doesn't rely on the input `global_batch_size` - for batching the samples. The batching is supposed to be done as - a part of the input `pipeline_prebatch_wld`. `global_batch_size` - is only used to compute a (pseudo-) epoch length for the data - loader so that the loader yield approximately n_samples // - global_batch_size batches - Kwargs: -- `prefix_tars_wds` _str_ - name prefix of the input webdataset tar - files. The input tar files are globbed by - "{dirs_tars_wds[split]}/{prefix_tars_wds}-*.tar" - pipeline_wds (Optional[Dict[Split, Union[Iterable[Iterable[Any]], -- `Iterable[Any]]]])` - a dictionary of webdatast composable, i.e., - functor that maps a iterator to another iterator that - transforms the data sample yield from the dataset object, for - different splits, or an iterable to such a sequence of such - iterators. For example, this can be used to transform the - sample in the worker before sending it to the main process of - the dataloader - pipeline_prebatch_wld (Optional[Dict[Split, - Union[Iterable[Iterable[Any]], Iterable[Any]]]]): a dictionary - of webloader composable, i.e., functor that maps a iterator to - another iterator that transforms the data sample yield from the - WebLoader object, for different splits, or an iterable to a - seuqnence of such iterators. For example, this can be used for - batching the samples. NOTE: this is applied before batching is - yield from the WebLoader -- `kwargs_wds` _Optional[Dict[Split, Dict[str, Any]]]_ - kwargs for the - WebDataset.__init__() -- `kwargs_wld` _Optional[Dict[Split, Dict[str, Any]]]_ - kwargs for the - WebLoader.__init__(), e.g., num_workers, of each split - - - -#### prepare\_data - -```python -def prepare_data() -> None -``` - -This is called only by the main process by the Lightning workflow. Do -not rely on this data module object's state update here as there is no -way to communicate the state update to other subprocesses. - -Returns: None - - - -#### setup - -```python -def setup(stage: str) -> None -``` - -This is called on all Lightning-managed nodes in a multi-node -training session - - -**Arguments**: - -- `stage` _str_ - "fit", "test" or "predict" -- `Returns` - None - -## PickledDataWDS - -```python -class PickledDataWDS(WebDataModule) -``` - -A LightningDataModule to process pickled data into webdataset tar files -and setup dataset and dataloader. This inherits the webdataset setup from -its parent module `WebDataModule`. This data module takes a directory of -pickled data files, data filename prefixes for train/val/test splits, data -filename suffixes and prepare webdataset tar files by globbing the specific -pickle data files `{dir_pickles}/{name_subset[split]}.{suffix_pickles}` and -outputing to webdataset tar file with the dict structure: -``` - {"__key__" : name.replace(".", "-"), - suffix_pickles : pickled.dumps(data) } -``` -NOTE: this assumes only one pickled file is processed for each sample. In -its setup() function, it creates the webdataset object chaining up the input -`pipeline_wds` workflow. In its train/val/test_dataloader(), it creates the -WebLoader object chaining up the `pipeline_prebatch_wld` workflow. - -Examples --------- - -1. create the data module with a directory of pickle files and the file name -prefix thereof for different splits to used by `Lightning.Trainer.fit()` - -``` ->>> from bionemo.core.data.datamodule import Split, PickledDataWDS - ->>> dir_pickles = "/path/to/my/pickles/dir" - ->>> # the following will use `sample1.mydata.pt` and `sample2.mydata.pt` as the ->>> # training dataset and `sample4.mydata.pt` and `sample5.mydata.pt` as the ->>> # validation dataset - ->>> suffix_pickles = "mydata.pt" - ->>> names_subset = { ->>> Split.train: [sample1, sample2], ->>> Split.val: [sample4, sample5], ->>> } - ->>> # the following setting will attempt to create at least 5 tar files in ->>> # `/path/to/output/tars/dir/myshards-00000{0-5}.tar` - ->>> n_tars_wds = 5 ->>> prefix_tars_wds = "myshards" ->>> output_dir_tar_files = "/path/to/output/tars/dir" - ->>> # see the `WebDataModule` API doc for the definition of global_batch_size ->>> global_batch_size = 16 - ->>> # user can optionally customize the data processing routines and kwargs used ->>> # in the WebDataset and WebLoader (see the examples in `WebDataModule`) - ->>> pipeline_wds = { Split.train: ... } - ->>> pipeline_prebatch_wld = { Split.train: ... } - ->>> kwargs_wds = { Split.train: ..., Split.val: ... } - ->>> kwargs_wld = { Split.train: ..., Split.val: ... } - ->>> # create the data module ->>> data_module = PickledDataWDS( ->>> dir_pickles, ->>> suffix_pickles, ->>> names_subset, ->>> output_dir_tar_files, ->>> global_batch_size, # `WebDataModule` args ->>> n_tars_wds=n_tars_wds, ->>> prefix_tars_wds=prefix_tars_wds, # `WebDataModule` kwargs ->>> pipeline_wds=pipeline_wds, # `WebDataModule` kwargs ->>> pipeline_prebatch_wld=pipelines_wdl_batch, # `WebDataModule` kwargs ->>> kwargs_wds=kwargs_wds, # `WebDataModule` kwargs ->>> kwargs_wld=kwargs_wld, # `WebDataModule` kwargs ->>> ) - -``` - - - -#### \_\_init\_\_ - -```python -def __init__(dir_pickles: str, - suffix_pickles: str, - names_subset: Dict[Split, List[str]], - prefix_dir_tars_wds: str, - *args, - n_tars_wds: Optional[int] = None, - **kwargs) -``` - -constructor - -**Arguments**: - -- `dir_pickles` _str_ - input directory of pickled data files -- `suffix_pickles` _str_ - filename suffix of the input data in - dir_pickles. This is also used as the key mapped to the - tarballed pickled object in the webdataset -- `names_subset` _Dict[Split, List[str]]_ - list of filename prefix of - the data samples to be loaded in the dataset and dataloader for - each of the split -- `prefix_dir_tars_wds` _str_ - directory name prefix to store the output - webdataset tar files. The actual directories storing the train, val - and test sets will be suffixed with "train", "val" and "test" - respectively. -- `*args` - arguments passed to the parent WebDataModule - - Kwargs: -- `n_tars_wds` _int_ - attempt to create at least this number of - webdataset shards -- `**kwargs` - arguments passed to the parent WebDataModule - - - -#### prepare\_data - -```python -def prepare_data() -> None -``` - -This is called only by the main process by the Lightning workflow. Do -not rely on this data module object's state update here as there is no -way to communicate the state update to other subprocesses. The nesting -`pickles_to_tars` function goes through the data name prefixes in the -different splits, read the corresponding pickled file and output a -webdataset tar archive with the dict structure: {"__key__" : -name.replace(".", "-"), suffix_pickles : pickled.dumps(data) }. - -Returns: None - diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/__init__.py b/sub-packages/bionemo-core/src/bionemo/core/data/__init__.py index f7d0fc1f13..25e6abfbc5 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/__init__.py +++ b/sub-packages/bionemo-core/src/bionemo/core/data/__init__.py @@ -1,10 +1,14 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# SPDX-License-Identifier: LicenseRef-Apache2 # -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/api.py b/sub-packages/bionemo-core/src/bionemo/core/data/api.py index 38ea91be64..6dae5df001 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/api.py +++ b/sub-packages/bionemo-core/src/bionemo/core/data/api.py @@ -1,13 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# SPDX-License-Identifier: LicenseRef-Apache2 # -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Sequence diff --git a/sub-packages/bionemo-webdatamodule/LICENSE b/sub-packages/bionemo-webdatamodule/LICENSE new file mode 100644 index 0000000000..d645695673 --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/sub-packages/bionemo-webdatamodule/README.md b/sub-packages/bionemo-webdatamodule/README.md new file mode 100644 index 0000000000..2c0cdff631 --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/README.md @@ -0,0 +1,354 @@ +# bionemo-webdatamodule + +To install, execute the following: +```bash +pip install -e . +``` + +To run unit tests, execute: +```bash +pytest -v . +``` + +## WebDataModule + +```python +class WebDataModule(L.LightningDataModule) +``` + +A LightningDataModule for using webdataset tar files to setup dataset and +dataloader. This data module takes as input a dictionary: Split -> tar file +directory and vaiours webdataset config settings. In its setup() function, it +creates the webdataset object chaining up the input `pipeline_wds` workflow. In +its train/val/test_dataloader(), it creates the WebLoader object chaining up the +`pipeline_prebatch_wld` workflow + +Examples +-------- + +1. create the data module with input directory to webdataset tar files. +Depending on which of the downstream Lightning.Trainer methods are called, +e.g., `Trainer.fit()`, `Trainer.validate()`, `Trainer.test()` or +`Trainer.predict()`, only a subset of the train, val and test splits need to +be specified in the various input options to the data module: + +- `Trainer.fit()` requires the `train` and `val` splits +- `Trainer.validate()` requires the `val` split +- `Trainer.test()` requires the `test` splits +- `Trainer.predict()` requires the `test` splits + +Here is an example of constructing the data module for `Trainer.fit()`: +``` +>>> from bionemo.webdatamodule.datamodule import Split, WebDataModule +>>> +>>> tar_file_prefix = "shards" +>>> +>>> dirs_of_tar_files = { +>>> Split.train: "/path/to/train/split/tars", +>>> Split.val: "/path/to/val/split/tars", +>>> } +>>> +>>> n_samples { +>>> Split.train: 1000, +>>> Split.val: 100, +>>> } +>>> +>>> # this is the string to retrieve the corresponding data object from the +>>> # webdataset file (see +>>> # https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format +>>> # for details) +>>> suffix_keys_wds = "tensor.pyd" +>>> +>>> # see the API doc for the definition of global_batch_size +>>> global_batch_size = 16 +>>> +>>> seed = 27193781 +>>> +>>> # Specify the routines to process the samples in the WebDataset object. +>>> # The routine is a generator of an Iterable of generators that are chained +>>> # together by nested function calling. The following is equivalent of +>>> # defining a overall generator of `shuffle(untuple(...))` which +>>> # untuples the samples and shuffles them. See webdataset's Documentation +>>> # for details. +>>> # NOTE: the `untuple` is almost always necessary due to the webdataset's +>>> # file parsing rule. +>>> +>>> untuple = lambda source : (sample for (sample,) in source) +>>> +>>> from webdatast import shuffle +>>> pipeline_wds = { +>>> Split.train : [untuple, shuffle(n_samples[Split.train], +>>> rng=random.Random(seed_rng_shfl))], +>>> Split.val: untuple +>>> } +>>> +>>> # Similarly the user can optionally define the processing routine on the +>>> # WebLoader (the dataloader of webdataset). +>>> # NOTE: these routines by default take unbatched sample as input so the +>>> # user can customize their batching routines here +>>> +>>> batch = batched(local_batch_size, collation_fn=lambda + list_samples : torch.vstack(list_samples)) +>>> pipeline_prebatch_wld = { + Split.train: [shuffle(n_samples[Split.train], + rng=random.Random(seed_rng_shfl)), batch], + Split.val : batch, + Split.test : batch + } +>>> +>>> # the user can optionally specify the kwargs for WebDataset and +>>> # WebLoader +>>> +>>> kwargs_wds = { +>>> split : {'shardshuffle' : split == Split.train, +>>> 'nodesplitter' : wds.split_by_node, +>>> 'seed' : seed_rng_shfl} +>>> for split in Split +>>> } +>>> +>>> kwargs_wld = { +>>> split : {"num_workers": 2} for split in Split +>>> } +>>> +>>> # construct the data module +>>> data_module = WebDataModule(dirs_of_tar_files, n_samples, suffix_keys_wds, + global_batch_size, + prefix_tars_wds=tar_file_prefix, + pipeline_wds=pipeline_wds, + pipeline_prebatch_wld=pipeline_prebatch_wld, + kwargs_wds=kwargs_wds, + kwargs_wld=kwargs_wld) +``` + + + +#### \_\_init\_\_ + +```python +def __init__( + dirs_tars_wds: Dict[Split, str], + n_samples: Dict[Split, int], + suffix_keys_wds: Union[str, Iterable[str]], + global_batch_size: int, + prefix_tars_wds: str = "wdshards", + pipeline_wds: Optional[Dict[Split, Union[Iterable[Iterable[Any]], + Iterable[Any]]]] = None, + pipeline_prebatch_wld: Optional[Dict[Split, + Union[Iterable[Iterable[Any]], + Iterable[Any]]]] = None, + kwargs_wds: Optional[Dict[Split, Dict[str, Any]]] = None, + kwargs_wld: Optional[Dict[Split, Dict[str, Any]]] = None) +``` + +constructor + +**Arguments**: + +- `dirs_tars_wds` _Dict[Split, str]_ - input dictionary: Split -> tar file + directory that contains the webdataset tar files for each split +- `n_samples` _Dict[Split, int]_ - input dictionary: Split -> number of + data samples for each split +- `suffix_keys_wds` _Union[str, Iterable[str]]_ - a set of keys each + corresponding to a data object in the webdataset tar file + dictionary. The data objects of these keys will be extracted and + tupled for each sample in the tar files +- `global_batch_size` _int_ - size of batch summing across nodes in Data + Distributed Parallel, i.e., local_batch_size * n_nodes. NOTE: + this data module doesn't rely on the input `global_batch_size` + for batching the samples. The batching is supposed to be done as + a part of the input `pipeline_prebatch_wld`. `global_batch_size` + is only used to compute a (pseudo-) epoch length for the data + loader so that the loader yield approximately n_samples // + global_batch_size batches + Kwargs: +- `prefix_tars_wds` _str_ - name prefix of the input webdataset tar + files. The input tar files are globbed by + "{dirs_tars_wds[split]}/{prefix_tars_wds}-*.tar" + pipeline_wds (Optional[Dict[Split, Union[Iterable[Iterable[Any]], +- `Iterable[Any]]]])` - a dictionary of webdatast composable, i.e., + functor that maps a iterator to another iterator that + transforms the data sample yield from the dataset object, for + different splits, or an iterable to such a sequence of such + iterators. For example, this can be used to transform the + sample in the worker before sending it to the main process of + the dataloader + pipeline_prebatch_wld (Optional[Dict[Split, + Union[Iterable[Iterable[Any]], Iterable[Any]]]]): a dictionary + of webloader composable, i.e., functor that maps a iterator to + another iterator that transforms the data sample yield from the + WebLoader object, for different splits, or an iterable to a + seuqnence of such iterators. For example, this can be used for + batching the samples. NOTE: this is applied before batching is + yield from the WebLoader +- `kwargs_wds` _Optional[Dict[Split, Dict[str, Any]]]_ - kwargs for the + WebDataset.__init__() +- `kwargs_wld` _Optional[Dict[Split, Dict[str, Any]]]_ - kwargs for the + WebLoader.__init__(), e.g., num_workers, of each split + + + +#### prepare\_data + +```python +def prepare_data() -> None +``` + +This is called only by the main process by the Lightning workflow. Do +not rely on this data module object's state update here as there is no +way to communicate the state update to other subprocesses. + +Returns: None + + + +#### setup + +```python +def setup(stage: str) -> None +``` + +This is called on all Lightning-managed nodes in a multi-node +training session + + +**Arguments**: + +- `stage` _str_ - "fit", "test" or "predict" +- `Returns` - None + +## PickledDataWDS + +```python +class PickledDataWDS(WebDataModule) +``` + +A LightningDataModule to process pickled data into webdataset tar files +and setup dataset and dataloader. This inherits the webdataset setup from +its parent module `WebDataModule`. This data module takes a directory of +pickled data files, data filename prefixes for train/val/test splits, data +filename suffixes and prepare webdataset tar files by globbing the specific +pickle data files `{dir_pickles}/{name_subset[split]}.{suffix_pickles}` and +outputing to webdataset tar file with the dict structure: +``` + {"__key__" : name.replace(".", "-"), + suffix_pickles : pickled.dumps(data) } +``` +NOTE: this assumes only one pickled file is processed for each sample. In +its setup() function, it creates the webdataset object chaining up the input +`pipeline_wds` workflow. In its train/val/test_dataloader(), it creates the +WebLoader object chaining up the `pipeline_prebatch_wld` workflow. + +Examples +-------- + +1. create the data module with a directory of pickle files and the file name +prefix thereof for different splits to used by `Lightning.Trainer.fit()` + +``` +>>> from bionemo.webdatamodule.datamodule import Split, PickledDataWDS + +>>> dir_pickles = "/path/to/my/pickles/dir" + +>>> # the following will use `sample1.mydata.pt` and `sample2.mydata.pt` as the +>>> # training dataset and `sample4.mydata.pt` and `sample5.mydata.pt` as the +>>> # validation dataset + +>>> suffix_pickles = "mydata.pt" + +>>> names_subset = { +>>> Split.train: [sample1, sample2], +>>> Split.val: [sample4, sample5], +>>> } + +>>> # the following setting will attempt to create at least 5 tar files in +>>> # `/path/to/output/tars/dir/myshards-00000{0-5}.tar` + +>>> n_tars_wds = 5 +>>> prefix_tars_wds = "myshards" +>>> output_dir_tar_files = "/path/to/output/tars/dir" + +>>> # see the `WebDataModule` API doc for the definition of global_batch_size +>>> global_batch_size = 16 + +>>> # user can optionally customize the data processing routines and kwargs used +>>> # in the WebDataset and WebLoader (see the examples in `WebDataModule`) + +>>> pipeline_wds = { Split.train: ... } + +>>> pipeline_prebatch_wld = { Split.train: ... } + +>>> kwargs_wds = { Split.train: ..., Split.val: ... } + +>>> kwargs_wld = { Split.train: ..., Split.val: ... } + +>>> # create the data module +>>> data_module = PickledDataWDS( +>>> dir_pickles, +>>> suffix_pickles, +>>> names_subset, +>>> output_dir_tar_files, +>>> global_batch_size, # `WebDataModule` args +>>> n_tars_wds=n_tars_wds, +>>> prefix_tars_wds=prefix_tars_wds, # `WebDataModule` kwargs +>>> pipeline_wds=pipeline_wds, # `WebDataModule` kwargs +>>> pipeline_prebatch_wld=pipelines_wdl_batch, # `WebDataModule` kwargs +>>> kwargs_wds=kwargs_wds, # `WebDataModule` kwargs +>>> kwargs_wld=kwargs_wld, # `WebDataModule` kwargs +>>> ) + +``` + + + +#### \_\_init\_\_ + +```python +def __init__(dir_pickles: str, + suffix_pickles: str, + names_subset: Dict[Split, List[str]], + prefix_dir_tars_wds: str, + *args, + n_tars_wds: Optional[int] = None, + **kwargs) +``` + +constructor + +**Arguments**: + +- `dir_pickles` _str_ - input directory of pickled data files +- `suffix_pickles` _str_ - filename suffix of the input data in + dir_pickles. This is also used as the key mapped to the + tarballed pickled object in the webdataset +- `names_subset` _Dict[Split, List[str]]_ - list of filename prefix of + the data samples to be loaded in the dataset and dataloader for + each of the split +- `prefix_dir_tars_wds` _str_ - directory name prefix to store the output + webdataset tar files. The actual directories storing the train, val + and test sets will be suffixed with "train", "val" and "test" + respectively. +- `*args` - arguments passed to the parent WebDataModule + + Kwargs: +- `n_tars_wds` _int_ - attempt to create at least this number of + webdataset shards +- `**kwargs` - arguments passed to the parent WebDataModule + + + +#### prepare\_data + +```python +def prepare_data() -> None +``` + +This is called only by the main process by the Lightning workflow. Do +not rely on this data module object's state update here as there is no +way to communicate the state update to other subprocesses. The nesting +`pickles_to_tars` function goes through the data name prefixes in the +different splits, read the corresponding pickled file and output a +webdataset tar archive with the dict structure: {"__key__" : +name.replace(".", "-"), suffix_pickles : pickled.dumps(data) }. + +Returns: None + diff --git a/sub-packages/bionemo-webdatamodule/pyproject.toml b/sub-packages/bionemo-webdatamodule/pyproject.toml new file mode 100644 index 0000000000..1f3be3632d --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/pyproject.toml @@ -0,0 +1,29 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +# For guidance, see: https://packaging.python.org/en/latest/guides/writing-pyproject-toml/ +[project] +name = "bionemo-webdatamodule" +version = "0.0.1" +authors = [ + { name = "Dejun Lin", email = "dejunl@nvidia.com" }, +] +description = "" +readme = "README.md" +requires-python = ">=3.10" +keywords = [] +license = {file = "LICENSE"} +classifiers = [ + "Programming Language :: Python :: 3.10", + "Private :: Do Not Upload", +] +dynamic = ["dependencies"] + +[project.optional-dependencies] +test = [ + "pytest", +] + +[tool.setuptools.dynamic] +dependencies = {file = ["requirements.txt"]} diff --git a/sub-packages/bionemo-webdatamodule/requirements.txt b/sub-packages/bionemo-webdatamodule/requirements.txt new file mode 100644 index 0000000000..24ef528b0d --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/requirements.txt @@ -0,0 +1 @@ +webdataset==0.2.96 diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/__init__.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/__init__.py new file mode 100644 index 0000000000..25e6abfbc5 --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py similarity index 99% rename from sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py rename to sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py index 26a703eba1..b2af47652d 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/datamodule.py +++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py @@ -17,7 +17,7 @@ import lightning as L import webdataset as wds -from bionemo.core.data.utils import pickles_to_tars +from bionemo.webdatamodule.utils import pickles_to_tars class Split(Enum): diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/utils.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py similarity index 100% rename from sub-packages/bionemo-core/src/bionemo/core/data/utils.py rename to sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/__init__.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py similarity index 98% rename from sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py rename to sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py index 2e2f8c0896..1045b2b999 100644 --- a/sub-packages/bionemo-core/tests/bionemo/core/data/conftest.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py @@ -19,8 +19,8 @@ import webdataset as wds from webdataset.filters import batched, shuffle -from bionemo.core.data.datamodule import WebDataModule, Split, PickledDataWDS -from bionemo.core.data.utils import pickles_to_tars +from bionemo.webdatamodule.datamodule import WebDataModule, Split, PickledDataWDS +from bionemo.webdatamodule.utils import pickles_to_tars @pytest.fixture(scope="module") diff --git a/sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py similarity index 99% rename from sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py rename to sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py index 8285771f3f..0093810808 100644 --- a/sub-packages/bionemo-core/tests/bionemo/core/data/test_webdatamodule.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py @@ -16,7 +16,7 @@ import torch import lightning as L -from bionemo.core.data.datamodule import Split +from bionemo.webdatamodule.datamodule import Split @pytest.mark.parametrize("split", list(Split)) From b8f2bf59ed4337ce8cf8e147bbe53faaa9df7e31 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Mon, 19 Aug 2024 23:59:13 +0000 Subject: [PATCH 55/70] Package: remove bionemo-diffdock --- sub-packages/bionemo-diffdock/README.md | 6 - .../bionemo-diffdock/_requirements-test.txt | 1 - .../bionemo-diffdock/_requirements.txt | 1 - sub-packages/bionemo-diffdock/pyproject.toml | 26 -- .../bionemo-diffdock/requirements.txt | 2 - .../src/bionemo/diffdock/__init__.py | 10 - .../src/bionemo/diffdock/utils/__init__.py | 10 - .../src/bionemo/diffdock/utils/data.py | 221 ---------------- .../src/bionemo/diffdock/utils/diffusion.py | 137 ---------- .../src/bionemo/diffdock/utils/geometry.py | 125 --------- .../src/bionemo/diffdock/utils/so3.py | 173 ------------- .../src/bionemo/diffdock/utils/torsion.py | 47 ---- .../src/bionemo/diffdock/utils/torus.py | 103 -------- .../tests/bionemo/diffdock/data/conftest.py | 245 ------------------ .../diffdock/data/test_diffdock_datamodule.py | 189 -------------- 15 files changed, 1296 deletions(-) delete mode 100644 sub-packages/bionemo-diffdock/README.md delete mode 100644 sub-packages/bionemo-diffdock/_requirements-test.txt delete mode 100644 sub-packages/bionemo-diffdock/_requirements.txt delete mode 100644 sub-packages/bionemo-diffdock/pyproject.toml delete mode 100644 sub-packages/bionemo-diffdock/requirements.txt delete mode 100644 sub-packages/bionemo-diffdock/src/bionemo/diffdock/__init__.py delete mode 100644 sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/__init__.py delete mode 100644 sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/data.py delete mode 100644 sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/diffusion.py delete mode 100644 sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/geometry.py delete mode 100644 sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/so3.py delete mode 100644 sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torsion.py delete mode 100644 sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torus.py delete mode 100644 sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py delete mode 100644 sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/test_diffdock_datamodule.py diff --git a/sub-packages/bionemo-diffdock/README.md b/sub-packages/bionemo-diffdock/README.md deleted file mode 100644 index 938640398b..0000000000 --- a/sub-packages/bionemo-diffdock/README.md +++ /dev/null @@ -1,6 +0,0 @@ -# bionemo-diffdock - - -```bash -pip install -e . -``` diff --git a/sub-packages/bionemo-diffdock/_requirements-test.txt b/sub-packages/bionemo-diffdock/_requirements-test.txt deleted file mode 100644 index 47d98580a4..0000000000 --- a/sub-packages/bionemo-diffdock/_requirements-test.txt +++ /dev/null @@ -1 +0,0 @@ --e ../bionemo-testing diff --git a/sub-packages/bionemo-diffdock/_requirements.txt b/sub-packages/bionemo-diffdock/_requirements.txt deleted file mode 100644 index 22c9b13a61..0000000000 --- a/sub-packages/bionemo-diffdock/_requirements.txt +++ /dev/null @@ -1 +0,0 @@ --e ../bionemo-core diff --git a/sub-packages/bionemo-diffdock/pyproject.toml b/sub-packages/bionemo-diffdock/pyproject.toml deleted file mode 100644 index cd4c7a70d9..0000000000 --- a/sub-packages/bionemo-diffdock/pyproject.toml +++ /dev/null @@ -1,26 +0,0 @@ -[build-system] -requires = ["setuptools", "wheel"] -build-backend = "setuptools.build_meta" - -[project] -name = "bionemo-diffdock" -version = "2.0.0" -license.file = "LICENSE" -readme = "README.md" -description = "BioNeMo DiffDock" -authors = [ - { name = "John St. John", email = "jstjohn@nvidia.com" }, - { name = "Malcolm Greaves", email = "mgreaves@nvidia.com" }, - { name = "Dejun Lin", email = "dejunl@nvidia.com" }, -] -dynamic = ["dependencies"] - -[tool.setuptools.dynamic] -dependencies = {file = ["requirements.txt"]} -# TODO: how to specify bionemo-{feature packages} & bionemo-core ??? - -[tool.setuptools.packages.find] -where = ["src"] -include=["bionemo.*"] -namespaces = true -exclude = ["test*."] diff --git a/sub-packages/bionemo-diffdock/requirements.txt b/sub-packages/bionemo-diffdock/requirements.txt deleted file mode 100644 index ec114b4f89..0000000000 --- a/sub-packages/bionemo-diffdock/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -numpy==1.26.4 -scipy==1.12.0 diff --git a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/__init__.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/__init__.py deleted file mode 100644 index f7d0fc1f13..0000000000 --- a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - diff --git a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/__init__.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/__init__.py deleted file mode 100644 index f7d0fc1f13..0000000000 --- a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - diff --git a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/data.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/data.py deleted file mode 100644 index da904e13f9..0000000000 --- a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/data.py +++ /dev/null @@ -1,221 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - -import math -import random -from typing import Any, Callable, Generator, Iterable, List, Union - -import numpy as np -import torch -from nemo.utils import logging -from omegaconf.listconfig import ListConfig -from torch_geometric.data import HeteroData -from torch_geometric.data.batch import Batch -from torch_geometric.loader.dataloader import Collater - - -def num_cross_edge_upper_bound_estimate(n1, n2, n3, n4): - terms = [[4.92, "ligand_ligand"], [0.0118, "receptor_receptor"], [0.0401, "ligand", "receptor_receptor"]] - scale = 1.03 - tmpdict = {"ligand": n1, "ligand_ligand": n2, "receptor": n3, "receptor_receptor": n4} - num_edges = 0.0 - for term in terms: - tmp = term[0] - for k in term[1:]: - tmp *= tmpdict[k] - num_edges += tmp - num_edges *= scale - return num_edges - - -def estimate_memory_usage(data, num_cross_edges, use_bias=True): - # bias is from the memory of model, so when estimate the upper bound for size aware batch sampler, we don't need this - coeff_ligand_num_nodes = 2.9 - coeff_ligand_num_edges = 0.0 - coeff_receptor_num_nodes = 0.0 - coeff_receptor_num_edges = 0.11 - coeff_num_cross_edges = 0.25 - total_memory = ( - coeff_ligand_num_nodes * data["ligand"].num_nodes - + coeff_ligand_num_edges * data["ligand", "ligand"].num_edges - + coeff_receptor_num_nodes * data["receptor"].num_nodes - + coeff_receptor_num_edges * data["receptor", "receptor"].num_edges - + coeff_num_cross_edges * num_cross_edges - ) - if use_bias: - bias = 430.5 - return total_memory + bias - else: - return total_memory - - -def estimate_size(g): - n1, n2, n3, n4 = ( - g["ligand"].num_nodes, - g["ligand", "ligand"].num_edges, - g["receptor"].num_nodes, - g["receptor", "receptor"].num_edges, - ) - # estimate the upper bound of the number of cross edges - # the number of cross edges roughly increases w.r.t. the diffusion step t (sampled from uniform(0,1)) - # the empirical formula here is from the polynomial fitting - # the scaling constant is to help remove the outliers above the upper bound estimation. - n5 = num_cross_edge_upper_bound_estimate(n1, n2, n3, n4) - total_memory = estimate_memory_usage(g, n5, use_bias=False) - return total_memory - - -class SizeAwareBatching: - """A WebDataset composable to do batching based on sample size""" - - def __init__( - self, - max_total_size: int, - size_fn: Callable[[HeteroData], int], - collate_fn: Callable[[List[Any]], Any] = Collater(dataset=[], follow_batch=None, exclude_keys=None), - no_single_sample: bool = True, - ): - self.max_total_size = max_total_size - self.size_fn = size_fn - self.collate_fn = collate_fn - self.cached_sizes = {} - self.no_single_sample = no_single_sample - - def __call__(self, data: Batch) -> Generator[Union[Batch, List[HeteroData]], None, None]: - batch_size = 0 - batch = [] - - for sample in data: - if sample.name not in self.cached_sizes: - self.cached_sizes[sample.name] = self.size_fn(sample) - sample_size = self.cached_sizes[sample.name] - if sample_size > self.max_total_size: - logging.warning(f"sample {sample.name} has size larger than max size {self.max_total_size}, skipping") - continue - if (batch_size + sample_size) <= self.max_total_size: - batch.append(sample) - batch_size += sample_size - else: - if self.no_single_sample and len(batch) <= 1: - # memory size requirement is met but there is less than 2 - # samples in the batch so skip - batch = [sample] - batch_size = sample_size - continue - if self.collate_fn is not None: - batch = self.collate_fn(batch) - yield batch - - batch = [sample] - batch_size = sample_size - - -class SelectPoseAndLabelData: - """A WebDataset composable to select one ligand poses from multiple ones and - label confidence model training data by RMSD threshold""" - - def __init__( - self, - rmsd_classification_cutoff: Union[float, ListConfig], - samples_per_complex: int, - balance: bool, - all_atoms: bool, - seed: int = 0, - ): - """constructor - - Args: - rmsd_classification_cutoff (Union[float, ListConfig]): RMSD classification cutoff(s) - samples_per_complex (int): how many inference runs were done per complex - balance (bool): whether to do balance sampling - all_atoms (bool): whether the confidence model is all-atom - seed (int): random number generator seed - - Returns: - - """ - self.rmsd_classification_cutoff = rmsd_classification_cutoff - self.samples_per_complex = samples_per_complex - self.balance = balance - self.all_atoms = all_atoms - self._seed = seed - - def __call__(self, data: Iterable) -> Generator[HeteroData, None, None]: - """Map the input data iterator to another one that label the input data - - Args: - data (Iterable): Input data iterator - - Returns: - - """ - random.seed(self._seed) - for (complex_graph,) in data: - positions, rmsds = complex_graph.ligand_data - - if self.balance: - if isinstance(self.rmsd_classification_cutoff, ListConfig): - raise ValueError("a list for rmsd_classification_cutoff can only be used with balance=False") - # FIXME: should allow random.seed - label = random.randint(0, 1) - success = rmsds < self.rmsd_classification_cutoff - n_success = np.count_nonzero(success) - if label == 0 and n_success != self.samples_per_complex: - # sample negative complex - sample = random.randint(0, self.samples_per_complex - n_success - 1) - lig_pos = positions[~success][sample] - complex_graph["ligand"].pos = torch.from_numpy(lig_pos) - else: - # sample positive complex - if n_success > 0: # if no successful sample returns the matched complex - sample = random.randint(0, n_success - 1) - lig_pos = positions[success][sample] - complex_graph["ligand"].pos = torch.from_numpy(lig_pos) - complex_graph.y = torch.tensor(label).float() - else: - sample = random.randint(0, self.samples_per_complex - 1) - complex_graph["ligand"].pos = torch.from_numpy(positions[sample]) - ids = (rmsds[sample] < self.rmsd_classification_cutoff).astype(int) - complex_graph.y = torch.tensor(ids).float().unsqueeze(0) - if isinstance(self.rmsd_classification_cutoff, ListConfig): - complex_graph.y_binned = torch.tensor( - np.logical_and( - rmsds[sample] < self.rmsd_classification_cutoff + [math.inf], - rmsds[sample] >= [0] + self.rmsd_classification_cutoff, - ), - dtype=torch.float, - ).unsqueeze(0) - complex_graph.y = ( - torch.tensor(rmsds[sample] < self.rmsd_classification_cutoff[0]).unsqueeze(0).float() - ) - complex_graph.rmsd = torch.tensor(rmsds[sample]).unsqueeze(0).float() - - complex_graph["ligand"].node_t = { - "tr": 0 * torch.ones(complex_graph["ligand"].num_nodes), - "rot": 0 * torch.ones(complex_graph["ligand"].num_nodes), - "tor": 0 * torch.ones(complex_graph["ligand"].num_nodes), - } - complex_graph["receptor"].node_t = { - "tr": 0 * torch.ones(complex_graph["receptor"].num_nodes), - "rot": 0 * torch.ones(complex_graph["receptor"].num_nodes), - "tor": 0 * torch.ones(complex_graph["receptor"].num_nodes), - } - if self.all_atoms: - complex_graph["atom"].node_t = { - "tr": 0 * torch.ones(complex_graph["atom"].num_nodes), - "rot": 0 * torch.ones(complex_graph["atom"].num_nodes), - "tor": 0 * torch.ones(complex_graph["atom"].num_nodes), - } - complex_graph.complex_t = { - "tr": 0 * torch.ones(1), - "rot": 0 * torch.ones(1), - "tor": 0 * torch.ones(1), - } - yield complex_graph diff --git a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/diffusion.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/diffusion.py deleted file mode 100644 index 31ca5228ca..0000000000 --- a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/diffusion.py +++ /dev/null @@ -1,137 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - -import random -from copy import deepcopy -from typing import Callable, Generator, Tuple - -import numpy as np -import torch -from torch_geometric.data.hetero_data import HeteroData - -from bionemo.diffdock.utils import so3, torus -from bionemo.diffdock.utils.geometry import axis_angle_to_matrix, rigid_transform_Kabsch_3D_torch -from bionemo.diffdock.utils.torsion import modify_conformer_torsion_angles - - -def t_to_sigma( - tr_sigma_min, tr_sigma_max, rot_sigma_min, rot_sigma_max, tor_sigma_min, tor_sigma_max, t_tr, t_rot, t_tor -): - tr_sigma = tr_sigma_min ** (1 - t_tr) * tr_sigma_max**t_tr - rot_sigma = rot_sigma_min ** (1 - t_rot) * rot_sigma_max**t_rot - tor_sigma = tor_sigma_min ** (1 - t_tor) * tor_sigma_max**t_tor - return tr_sigma, rot_sigma, tor_sigma - - -def modify_conformer(data, tr_update, rot_update, torsion_updates): - lig_center = torch.mean(data["ligand"].pos, dim=0, keepdim=True) - rot_mat = axis_angle_to_matrix(rot_update.squeeze()) - rigid_new_pos = (data["ligand"].pos - lig_center) @ rot_mat.T + tr_update + lig_center - - if torsion_updates is not None: - flexible_new_pos = modify_conformer_torsion_angles( - rigid_new_pos, - data["ligand", "ligand"].edge_index.T[data["ligand"].edge_mask], - ( - data["ligand"].mask_rotate - if isinstance(data["ligand"].mask_rotate, np.ndarray) - else data["ligand"].mask_rotate[0] - ), - torsion_updates, - ).to(rigid_new_pos.device) - R, t = rigid_transform_Kabsch_3D_torch(flexible_new_pos.T, rigid_new_pos.T) - aligned_flexible_pos = flexible_new_pos @ R.T + t.T - data["ligand"].pos = aligned_flexible_pos - else: - data["ligand"].pos = rigid_new_pos - return data - - -def set_time(complex_graphs, t_tr, t_rot, t_tor, batchsize, all_atoms, device): - complex_graphs["ligand"].node_t = { - "tr": t_tr * torch.ones(complex_graphs["ligand"].num_nodes).to(device), - "rot": t_rot * torch.ones(complex_graphs["ligand"].num_nodes).to(device), - "tor": t_tor * torch.ones(complex_graphs["ligand"].num_nodes).to(device), - } - complex_graphs["receptor"].node_t = { - "tr": t_tr * torch.ones(complex_graphs["receptor"].num_nodes).to(device), - "rot": t_rot * torch.ones(complex_graphs["receptor"].num_nodes).to(device), - "tor": t_tor * torch.ones(complex_graphs["receptor"].num_nodes).to(device), - } - complex_graphs.complex_t = { - "tr": t_tr * torch.ones(batchsize).to(device), - "rot": t_rot * torch.ones(batchsize).to(device), - "tor": t_tor * torch.ones(batchsize).to(device), - } - if all_atoms: - complex_graphs["atom"].node_t = { - "tr": t_tr * torch.ones(complex_graphs["atom"].num_nodes).to(device), - "rot": t_rot * torch.ones(complex_graphs["atom"].num_nodes).to(device), - "tor": t_tor * torch.ones(complex_graphs["atom"].num_nodes).to(device), - } - - -class GenerateNoise: - """Apply forward diffusion on the ligand - - Args: - t_to_sigma (Callable): Callable to embed time - no_torsion (bool): if not to perturb ligand torsion degrees - all_atom (bool): all atom or coarse grained/residue for protein - copy_ref_pos (bool): whether or not make a copy of the input ligand position - """ - - def __init__( - self, - t_to_sigma: Callable[[float, float, float], Tuple[float, float, float]], - no_torsion: bool, - all_atom: bool, - copy_ref_pos: bool = False, - ): - self.t_to_sigma = t_to_sigma - self.no_torsion = no_torsion - self.all_atom = all_atom - self._copy_ref_pos = copy_ref_pos - - def __call__(self, source: Generator[HeteroData, None, None]) -> Generator[HeteroData, None, None]: - for (data,) in source: - if self._copy_ref_pos: - data["ligand"].aligned_pos = deepcopy(data["ligand"].pos) - t = np.random.uniform() - t_tr, t_rot, t_tor = t, t, t - yield self.apply_noise(data, t_tr, t_rot, t_tor) - - def apply_noise(self, data, t_tr, t_rot, t_tor, tr_update=None, rot_update=None, torsion_updates=None): - if not torch.is_tensor(data["ligand"].pos): - data["ligand"].pos = random.choice(data["ligand"].pos) - - tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(t_tr, t_rot, t_tor) - set_time(data, t_tr, t_rot, t_tor, 1, self.all_atom, device=None) - - tr_update = torch.normal(mean=0, std=tr_sigma, size=(1, 3)) if tr_update is None else tr_update - rot_update = so3.sample_vec(eps=rot_sigma) if rot_update is None else rot_update - torsion_updates = ( - np.random.normal(loc=0.0, scale=tor_sigma, size=data["ligand"].edge_mask.sum()) - if torsion_updates is None - else torsion_updates - ) - torsion_updates = None if self.no_torsion else torsion_updates - modify_conformer( - data, - tr_update, - torch.from_numpy(rot_update).float(), - None if data["ligand"].edge_mask.sum() == 0 else torsion_updates, - ) - - data.tr_score = -tr_update / tr_sigma**2 - data.rot_score = torch.from_numpy(so3.score_vec(vec=rot_update, eps=rot_sigma)).float().unsqueeze(0) - data.tor_score = None if self.no_torsion else torch.from_numpy(torus.score(torsion_updates, tor_sigma)).float() - data.tor_sigma_edge = None if self.no_torsion else np.ones(data["ligand"].edge_mask.sum()) * tor_sigma - return data diff --git a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/geometry.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/geometry.py deleted file mode 100644 index 648045d261..0000000000 --- a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/geometry.py +++ /dev/null @@ -1,125 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - -import math - -import torch - - -def quaternion_to_matrix(quaternions): - """ - From https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html - Convert rotations given as quaternions to rotation matrices. - - Args: - quaternions: quaternions with real part first, - as tensor of shape (..., 4). - - Returns: - Rotation matrices as tensor of shape (..., 3, 3). - """ - r, i, j, k = torch.unbind(quaternions, -1) - two_s = 2.0 / (quaternions * quaternions).sum(-1) - - o = torch.stack( - ( - 1 - two_s * (j * j + k * k), - two_s * (i * j - k * r), - two_s * (i * k + j * r), - two_s * (i * j + k * r), - 1 - two_s * (i * i + k * k), - two_s * (j * k - i * r), - two_s * (i * k - j * r), - two_s * (j * k + i * r), - 1 - two_s * (i * i + j * j), - ), - -1, - ) - return o.reshape(quaternions.shape[:-1] + (3, 3)) - - -def axis_angle_to_quaternion(axis_angle): - """ - From https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html - Convert rotations given as axis/angle to quaternions. - - Args: - axis_angle: Rotations given as a vector in axis angle form, - as a tensor of shape (..., 3), where the magnitude is - the angle turned anticlockwise in radians around the - vector's direction. - - Returns: - quaternions with real part first, as tensor of shape (..., 4). - """ - angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) - half_angles = 0.5 * angles - eps = 1e-6 - small_angles = angles.abs() < eps - sin_half_angles_over_angles = torch.empty_like(angles) - sin_half_angles_over_angles[~small_angles] = torch.sin(half_angles[~small_angles]) / angles[~small_angles] - # for x small, sin(x/2) is about x/2 - (x/2)^3/6 - # so sin(x/2)/x is about 1/2 - (x*x)/48 - sin_half_angles_over_angles[small_angles] = 0.5 - (angles[small_angles] * angles[small_angles]) / 48 - quaternions = torch.cat([torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1) - return quaternions - - -def axis_angle_to_matrix(axis_angle): - """ - From https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html - Convert rotations given as axis/angle to rotation matrices. - - Args: - axis_angle: Rotations given as a vector in axis angle form, - as a tensor of shape (..., 3), where the magnitude is - the angle turned anticlockwise in radians around the - vector's direction. - - Returns: - Rotation matrices as tensor of shape (..., 3, 3). - """ - return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) - - -def rigid_transform_Kabsch_3D_torch(A, B): - # R = 3x3 rotation matrix, t = 3x1 column vector - # This already takes residue identity into account. - - assert A.shape[1] == B.shape[1] - num_rows, num_cols = A.shape - if num_rows != 3: - raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}") - num_rows, num_cols = B.shape - if num_rows != 3: - raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}") - - # find mean column wise: 3 x 1 - centroid_A = torch.mean(A, axis=1, keepdims=True) - centroid_B = torch.mean(B, axis=1, keepdims=True) - - # subtract mean - Am = A - centroid_A - Bm = B - centroid_B - - H = Am @ Bm.T - - # find rotation - U, S, Vt = torch.linalg.svd(H) - - R = Vt.T @ U.T - # special reflection case - if torch.linalg.det(R) < 0: - SS = torch.diag(torch.tensor([1.0, 1.0, -1.0], device=A.device)) - R = (Vt.T @ SS) @ U.T - assert math.fabs(torch.linalg.det(R) - 1) < 3e-3 # note I had to change this error bound to be higher - - t = -R @ centroid_A + centroid_B - return R, t diff --git a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/so3.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/so3.py deleted file mode 100644 index 7fa0da0a3e..0000000000 --- a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/so3.py +++ /dev/null @@ -1,173 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - - -import os - -import numpy as np -import torch -from scipy.spatial.transform import Rotation - - -package_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) - -MIN_EPS, MAX_EPS, N_EPS = 0.01, 2, 1000 -X_N = 2000 - -omegas = np.linspace(0, np.pi, X_N + 1)[1:] - -# TODO generating these arrays is super slow, we should vectorize this - - -def _compose(r1, r2): # R1 @ R2 but for Euler vecs - return Rotation.from_matrix( - Rotation.from_rotvec(r1).as_matrix() @ Rotation.from_rotvec(r2).as_matrix() - ).as_rotvec() - - -def _expansion(omega, eps, L=2000): # the summation term only - p = 0 - for l in range(L): - p += (2 * l + 1) * np.exp(-l * (l + 1) * eps**2) * np.sin(omega * (l + 1 / 2)) / np.sin(omega / 2) - return p - - -def _expansion_vectorized(omega, eps, L=2000): - l = np.arange(L).reshape((-1, 1)) - omega = omega.reshape((1, -1)) - eps = eps.reshape((1, -1)) - - p1 = (2 * l + 1) * np.exp(-l * (l + 1) * eps**2) - p2 = np.sin(omega * (l + 1 / 2)) / np.sin(omega / 2) - p = np.matmul(p2.T, p1).T - return p - - -def _density(expansion, omega, marginal=True): # if marginal, density over [0, pi], else over SO(3) - if marginal: - return expansion * (1 - np.cos(omega)) / np.pi - else: - return expansion / 8 / np.pi**2 # the constant factor doesn't affect any actual calculations though - - -def _score(exp, omega, eps, L=2000): # score of density over SO(3) - dSigma = 0 - for l in range(L): - hi = np.sin(omega * (l + 1 / 2)) - dhi = (l + 1 / 2) * np.cos(omega * (l + 1 / 2)) - lo = np.sin(omega / 2) - dlo = 1 / 2 * np.cos(omega / 2) - dSigma += (2 * l + 1) * np.exp(-l * (l + 1) * eps**2) * (lo * dhi - hi * dlo) / lo**2 - return dSigma / exp - - -def _score_vectorized(exp, omega, eps, L=2000): # score of density over SO(3) - dSigma = 0 - l = np.arange(L).reshape((-1, 1)) - omega = omega.reshape((1, -1)) - eps = eps.reshape((1, -1)) - - hi = np.sin(omega * (l + 1 / 2)) - dhi = (l + 1 / 2) * np.cos(omega * (l + 1 / 2)) - lo = np.sin(omega / 2) - dlo = 1 / 2 * np.cos(omega / 2) - dSigma1 = (2 * l + 1) * np.exp(-l * (l + 1) * eps**2) - dSigma2 = (lo * dhi - hi * dlo) / lo**2 - dSigma = np.matmul(dSigma2.T, dSigma1).T - return dSigma / exp - - -def _score_small_eps(omega, eps): - # formula for f(omega, eps) in eq (5) https://openreview.net/pdf?id=jHA-yCyBGb - # score = d(log(f(omega, eps^2)) / d omega - # for our range of omegas, this approximation works well for eps up to ~0.7 - # note that for numerical stability it is important to combine - # exp(pi*omega/eps) * exp(-pi**2/eps) into exp(pi*(omega-pi)/eps) - - x = omega.reshape((1, -1)) - a = eps.reshape((-1, 1)) ** 2 - - return ( - -0.5 * x / a - + ( - 1 - + -np.exp(np.pi * (x - np.pi) / a) - + -np.exp(-np.pi * (x + np.pi) / a) - + -(np.pi * (x - 2 * np.pi) / a) * np.exp(np.pi * (x - np.pi) / a) - + np.pi * (x + 2 * np.pi) / a * np.exp(-np.pi * (x + np.pi) / a) - ) - / (x + -(x - 2 * np.pi) * np.exp(np.pi * (x - np.pi) / a) + (x + 2 * np.pi) * np.exp(-np.pi * (x + np.pi) / a)) - - 0.5 * np.cos(x / 2) / np.sin(x / 2) - ) - - -if os.path.exists(os.path.join(package_path, ".so3.npz")): - so3 = np.load(os.path.join(package_path, ".so3.npz")) - _omegas_array = so3["_omegas_array"] - _cdf_vals = so3["_cdf_vals"] - _score_norms = so3["_score_norms"] - _exp_score_norms = so3["_exp_score_norms"] -else: - _eps_array = (10 ** np.linspace(np.log10(MIN_EPS), np.log10(MAX_EPS), N_EPS)).astype(np.float128) - _omegas_array = np.linspace(0, np.pi, X_N + 1)[1:].astype(np.float128) - - _exp_vals = _expansion_vectorized(_omegas_array, _eps_array) - _pdf_vals = _density(_exp_vals, _omegas_array, marginal=True) - _cdf_vals = _pdf_vals.cumsum(1) / X_N * np.pi - _score_norms = np.zeros((N_EPS, X_N)) - _small_eps_idx = _eps_array < 0.5 - _score_norms[_small_eps_idx] = _score_small_eps(_omegas_array, _eps_array[_small_eps_idx]) - _score_norms[~_small_eps_idx] = _score_vectorized( - _exp_vals[~_small_eps_idx], _omegas_array, _eps_array[~_small_eps_idx] - ) - - _exp_score_norms = np.sqrt(np.sum(_score_norms**2 * _pdf_vals, axis=1) / np.sum(_pdf_vals, axis=1) / np.pi) - - _omegas_array = _omegas_array.astype(np.float64) - _cdf_vals = _cdf_vals.astype(np.float64) - _score_norms = _score_norms.astype(np.float64) - _exp_score_norms = _exp_score_norms.astype(np.float64) - - np.savez( - os.path.join(package_path, ".so3.npz"), - _omegas_array=_omegas_array, - _cdf_vals=_cdf_vals, - _score_norms=_score_norms, - _exp_score_norms=_exp_score_norms, - ) - - -def sample(eps): - eps_idx = (np.log10(eps) - np.log10(MIN_EPS)) / (np.log10(MAX_EPS) - np.log10(MIN_EPS)) * N_EPS - eps_idx = np.clip(np.around(eps_idx).astype(int), a_min=0, a_max=N_EPS - 1) - x = np.random.rand() - return np.interp(x, _cdf_vals[eps_idx], _omegas_array) - - -def sample_vec(eps): - x = np.random.randn(3) - x /= np.linalg.norm(x) - return x * sample(eps) - - -def score_vec(eps, vec): - eps_idx = (np.log10(eps) - np.log10(MIN_EPS)) / (np.log10(MAX_EPS) - np.log10(MIN_EPS)) * N_EPS - eps_idx = np.clip(np.around(eps_idx).astype(int), a_min=0, a_max=N_EPS - 1) - - om = np.linalg.norm(vec) - return np.interp(om, _omegas_array, _score_norms[eps_idx]) * vec / om - - -def score_norm(eps): - device = eps.device - eps = eps.cpu().numpy() - eps_idx = (np.log10(eps) - np.log10(MIN_EPS)) / (np.log10(MAX_EPS) - np.log10(MIN_EPS)) * N_EPS - eps_idx = np.clip(np.around(eps_idx).astype(int), a_min=0, a_max=N_EPS - 1) - return torch.from_numpy(_exp_score_norms[eps_idx]).to(device=device, dtype=torch.float) diff --git a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torsion.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torsion.py deleted file mode 100644 index 3134d26a7e..0000000000 --- a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torsion.py +++ /dev/null @@ -1,47 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - -import copy - -import numpy as np -import torch -from scipy.spatial.transform import Rotation as R - - -""" - Preprocessing and computation for torsional updates to conformers -""" - - -def modify_conformer_torsion_angles(pos, edge_index, mask_rotate, torsion_updates, as_numpy=False): - pos = copy.deepcopy(pos) - if type(pos) != np.ndarray: - pos = pos.cpu().numpy() - - for idx_edge, e in enumerate(edge_index.cpu().numpy()): - if torsion_updates[idx_edge] == 0: - continue - u, v = e[0], e[1] - - # check if need to reverse the edge, v should be connected to the part that gets rotated - assert not mask_rotate[idx_edge, u] - assert mask_rotate[idx_edge, v] - - rot_vec = pos[u] - pos[v] # convention: positive rotation if pointing inwards - rot_vec = rot_vec * torsion_updates[idx_edge] / np.linalg.norm(rot_vec) # idx_edge! - rot_mat = R.from_rotvec(rot_vec).as_matrix() - - pos[mask_rotate[idx_edge]] = (pos[mask_rotate[idx_edge]] - pos[v]) @ rot_mat.T + pos[v] - - if not as_numpy: - pos = torch.from_numpy(pos.astype(np.float32)) - return pos - - diff --git a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torus.py b/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torus.py deleted file mode 100644 index 977cff937d..0000000000 --- a/sub-packages/bionemo-diffdock/src/bionemo/diffdock/utils/torus.py +++ /dev/null @@ -1,103 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - -import os - -import numpy as np -import tqdm - - -""" - Preprocessing for the SO(2)/torus sampling and score computations, truncated infinite series are computed and then - cached to memory, therefore the precomputation is only run the first time the repository is run on a machine -""" - -package_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) - - -def p(x, sigma, N=10): - p_ = 0 - for i in tqdm.trange(-N, N + 1): - p_ += np.exp(-((x + 2 * np.pi * i) ** 2) / 2 / sigma**2) - return p_ - - -def grad(x, sigma, N=10): - p_ = 0 - for i in tqdm.trange(-N, N + 1): - p_ += (x + 2 * np.pi * i) / sigma**2 * np.exp(-((x + 2 * np.pi * i) ** 2) / 2 / sigma**2) - return p_ - - -X_MIN, X_N = 1e-5, 5000 # relative to pi -SIGMA_MIN, SIGMA_MAX, SIGMA_N = 3e-3, 2, 5000 # relative to pi - -x = 10 ** np.linspace(np.log10(X_MIN), 0, X_N + 1) * np.pi -sigma = 10 ** np.linspace(np.log10(SIGMA_MIN), np.log10(SIGMA_MAX), SIGMA_N + 1) * np.pi - -if os.path.exists(os.path.join(package_path, ".torus.npz")): - torus = np.load(os.path.join(package_path, ".torus.npz")) - p_ = torus["p_"] - score_ = torus["score_"] -else: - p_ = p(x, sigma[:, None], N=100) - score_ = grad(x, sigma[:, None], N=100) / p_ - - np.savez(os.path.join(package_path, ".torus.npz"), p_=p_, score_=score_) - - -def score(x, sigma): - x = (x + np.pi) % (2 * np.pi) - np.pi - sign = np.sign(x) - x = np.log(np.abs(x) / np.pi) - x = (x - np.log(X_MIN)) / (0 - np.log(X_MIN)) * X_N - x = np.round(np.clip(x, 0, X_N)).astype(int) - sigma = np.log(sigma / np.pi) - sigma = (sigma - np.log(SIGMA_MIN)) / (np.log(SIGMA_MAX) - np.log(SIGMA_MIN)) * SIGMA_N - sigma = np.round(np.clip(sigma, 0, SIGMA_N)).astype(int) - return -sign * score_[sigma, x] - - -def p(x, sigma): - x = (x + np.pi) % (2 * np.pi) - np.pi - x = np.log(np.abs(x) / np.pi) - x = (x - np.log(X_MIN)) / (0 - np.log(X_MIN)) * X_N - x = np.round(np.clip(x, 0, X_N)).astype(int) - sigma = np.log(sigma / np.pi) - sigma = (sigma - np.log(SIGMA_MIN)) / (np.log(SIGMA_MAX) - np.log(SIGMA_MIN)) * SIGMA_N - sigma = np.round(np.clip(sigma, 0, SIGMA_N)).astype(int) - return p_[sigma, x] - - -def sample(sigma, seed=None): - if seed is None: - out = sigma * np.random.randn(*sigma.shape) - else: - rng = np.random.default_rng(seed) - out = sigma * rng.normal(size=sigma.shape) - out = (out + np.pi) % (2 * np.pi) - np.pi - return out - - -class TorusScoreNorm: - _score_norm = None - - def __init__(self, seed=None): - if TorusScoreNorm._score_norm is None: - _score_norm = score( - sample(sigma[None].repeat(10000, 0).flatten(), seed=seed), sigma[None].repeat(10000, 0).flatten() - ).reshape(10000, -1) - TorusScoreNorm._score_norm = (_score_norm**2).mean(0) - - def __call__(self, sigma): - sigma = np.log(sigma / np.pi) - sigma = (sigma - np.log(SIGMA_MIN)) / (np.log(SIGMA_MAX) - np.log(SIGMA_MIN)) * SIGMA_N - sigma = np.round(np.clip(sigma, 0, SIGMA_N)).astype(int) - return TorusScoreNorm._score_norm[sigma] diff --git a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py deleted file mode 100644 index 8dc9ab600f..0000000000 --- a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/conftest.py +++ /dev/null @@ -1,245 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - - -from pathlib import Path -from enum import Enum, auto -from functools import partial -import random - -import lightning as L -import pytest -import torch -from torch_geometric.loader.dataloader import Collater -import webdataset as wds -from webdataset.filters import batched, shuffle - -from bionemo.core.data.datamodule import PickledDataWDS, Split -from bionemo.diffdock.utils.data import SelectPoseAndLabelData, SizeAwareBatching, estimate_size -from bionemo.diffdock.utils.diffusion import GenerateNoise, t_to_sigma - - -@pytest.fixture(scope="module") -def get_path(request): - path_test = Path(request.module.__file__).resolve() - dir_test = path_test.parents[0] - dir_data = path_test.parents[6] / "test_data" / \ - "diffdock" / "pyg_heterodata_pickled" - return str(dir_test), str(dir_data) - - -class DiffDockModel(Enum): - score = auto() - confidence = auto() - - -@pytest.fixture(scope="module", params=list(DiffDockModel)) -def get_diffdock_heterodata(get_path, request): - _, dir_data = get_path - model = request.param - name_model = str(model).split(".")[-1] - dir_heterodata = f"{dir_data}/{name_model}_model" - suffix_heterodata = "heterodata.pyd" - names = { - Split.train: [ - "6t88", - "6vs3", - "6wtn", - "6yqv", - "7amc", - "7bmi", - "7cuo", - "7d5c", - "7din", - "7fha", - "7jnb", - "7k0v", - "7kb1", - "7km8", - "7l7c", - "7lcu", - "7msr", - "7my1", - "7n6f", - "7np6", - ], - Split.val: ["7nr6", "7oeo", "7oli", "7oso", "7p5t", "7q5i", "7qhl", "7rh3", "7rzl", "7sgv"], - Split.test: ["7sne", "7t2i", "7tbu", "7tsf", "7umv", "7up3", "7uq3", "7wpw", "7xek", "7xij"], - } - return (dir_heterodata, suffix_heterodata, names, model) - - -def _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, suffix_heterodata, names): - prefix_dir_tars_wds = tmp_path_factory.mktemp("diffdock_score_model_tars_wds").as_posix() - tr_sigma_min, tr_sigma_max = (0.1, 19) - rot_sigma_min, rot_sigma_max = (0.03, 1.55) - tor_sigma_min, tor_sigma_max = (0.0314, 3.14) - is_all_atom = False - no_torsion = False - sigma_t = partial( - t_to_sigma, tr_sigma_min, tr_sigma_max, rot_sigma_min, rot_sigma_max, tor_sigma_min, tor_sigma_max - ) - seed_rng_shfl = 822782392 - # webdataset pipeline - pipeline_wds = { - Split.train: [GenerateNoise(sigma_t, no_torsion, is_all_atom, copy_ref_pos=False), - shuffle(len(names[Split.train]), - rng=random.Random(seed_rng_shfl))], - Split.val: GenerateNoise(sigma_t, no_torsion, is_all_atom, copy_ref_pos=True), - Split.test: GenerateNoise(sigma_t, no_torsion, is_all_atom, copy_ref_pos=False), - } - local_batch_size = 2 - global_batch_size = 2 - size_cuda_mem = 0.85 * torch.cuda.get_device_properties("cuda:0").total_memory / 2**20 - batch_pyg = batched(local_batch_size, collation_fn=Collater(dataset=[], follow_batch=None, exclude_keys=None)) - # WebLoader pipeline - pipelines_wdl_batch = { - Split.train: [shuffle(40, rng=random.Random(seed_rng_shfl)), - SizeAwareBatching(max_total_size=size_cuda_mem, - size_fn=estimate_size, - no_single_sample=True)], - Split.val: batch_pyg, - Split.test: batch_pyg, - } - n_tars_wds = 4 - kwargs_wds = { - split : {'shardshuffle' : split == Split.train, - 'nodesplitter' : wds.split_by_node, - 'seed' : seed_rng_shfl} - for split in Split - } - kwargs_wld = { - Split.train: {"num_workers": 2}, - Split.val: {"num_workers": 2}, - Split.test: {"num_workers": 2}, - } - data_module = PickledDataWDS( - dir_heterodata, - suffix_heterodata, - names, - prefix_dir_tars_wds, - global_batch_size, - n_tars_wds=n_tars_wds, - prefix_tars_wds="heterographs", - pipeline_wds=pipeline_wds, - pipeline_prebatch_wld=pipelines_wdl_batch, - kwargs_wds=kwargs_wds, - kwargs_wld=kwargs_wld, - ) - return data_module, prefix_dir_tars_wds - - -def _create_datamodule_confidence_model_impl(tmp_path_factory, dir_heterodata, suffix_heterodata, names): - prefix_dir_tars_wds = tmp_path_factory.mktemp("diffdock_confidence_model_tars_wds").as_posix() - # webdataset pipeline - rmsd_classification_cutoff = 2.0 - samples_per_complex = 7 - balance = False - is_all_atom = True - seed_rng_shfl = 822782392 - select_pose = SelectPoseAndLabelData( - rmsd_classification_cutoff, samples_per_complex, balance, is_all_atom, seed=seed_rng_shfl - ) - pipeline_wds = { - Split.train: [select_pose, shuffle(len(names[Split.train]), - rng=random.Random(seed_rng_shfl))], - Split.val: select_pose, - Split.test: select_pose, - } - local_batch_size = 2 - global_batch_size = 2 - batch_pyg = batched(local_batch_size, collation_fn=Collater(dataset=[], follow_batch=None, exclude_keys=None)) - # WebLoader pipeline - pipelines_wdl_batch = { - Split.train: [shuffle(40, rng=random.Random(seed_rng_shfl)), batch_pyg], - Split.val: batch_pyg, - Split.test: batch_pyg, - } - n_tars_wds = 4 - kwargs_wds = { - split : {'shardshuffle' : split == Split.train, - 'nodesplitter' : wds.split_by_node, - 'seed' : seed_rng_shfl} - for split in Split - } - kwargs_wld = { - Split.train: {"num_workers": 2}, - Split.val: {"num_workers": 2}, - Split.test: {"num_workers": 2}, - } - data_module = PickledDataWDS( - dir_heterodata, - suffix_heterodata, - names, - prefix_dir_tars_wds, - global_batch_size, - n_tars_wds=n_tars_wds, - prefix_tars_wds="heterographs", - pipeline_wds=pipeline_wds, - pipeline_prebatch_wld=pipelines_wdl_batch, - kwargs_wds=kwargs_wds, - kwargs_wld=kwargs_wld, - ) - return data_module, prefix_dir_tars_wds - - -@pytest.fixture(scope="module") -def create_datamodule(tmp_path_factory, get_diffdock_heterodata): - dir_heterodata, suffix_heterodata, names, model = get_diffdock_heterodata - if model == DiffDockModel.score: - return _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, suffix_heterodata, names) - elif model == DiffDockModel.confidence: - return _create_datamodule_confidence_model_impl(tmp_path_factory, dir_heterodata, suffix_heterodata, names) - - -@pytest.fixture(scope="module") -def create_another_datamodule(tmp_path_factory, get_diffdock_heterodata): - dir_heterodata, suffix_heterodata, names, model = get_diffdock_heterodata - if model == DiffDockModel.score: - return _create_datamodule_score_model_impl(tmp_path_factory, dir_heterodata, suffix_heterodata, names) - elif model == DiffDockModel.confidence: - return _create_datamodule_confidence_model_impl(tmp_path_factory, dir_heterodata, suffix_heterodata, names) - - -class ModelTestDiffDock(L.LightningModule): - def __init__(self) -> None: - super().__init__() - self._model = torch.nn.Linear(3, 3) - self._samples = {split: [] for split in Split} - - def forward(self, x): - return self._model(x["ligand"].pos) - - def training_step(self, batch): - self._samples[Split.train].append(batch.name) - loss = self(batch).sum() - return loss - - def validation_step(self, batch, batch_index): - self._samples[Split.val].append(batch.name) - return torch.zeros(1) - - def test_step(self, batch, batch_index): - self._samples[Split.test].append(batch.name) - - def predict_step(self, batch, batch_index): - self._samples[Split.test].append(batch.name) - return torch.zeros(1) - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=2e-4) - return optimizer - - -@pytest.fixture(scope="function") -def create_trainer_and_model(): - trainer = L.Trainer(max_epochs=1, accelerator="gpu", devices=1, val_check_interval=1) - model = ModelTestDiffDock() - return trainer, model diff --git a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/test_diffdock_datamodule.py b/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/test_diffdock_datamodule.py deleted file mode 100644 index 95ca4a88fe..0000000000 --- a/sub-packages/bionemo-diffdock/tests/bionemo/diffdock/data/test_diffdock_datamodule.py +++ /dev/null @@ -1,189 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - - -import glob -from enum import Enum, auto - -import lightning -import pytest -import torch -from torch_geometric.data import Batch, HeteroData - -from bionemo.core.data.datamodule import Split - - -@pytest.mark.parametrize("split", list(Split)) -def test_datamodule_init(split, get_diffdock_heterodata, create_datamodule): - name_split = str(split).split(".")[1] - (_, _, names, model) = get_diffdock_heterodata - data_module, prefix_dir_tars_wds = create_datamodule - assert data_module._n_samples[split] == len(names[split]), ( - f"Wrong {split}-set size for {model} model: " - f"expected {len(names[split])} " - f"but got {data_module._n_samples[split]}" - ) - assert data_module._dirs_tars_wds[split] == f"{prefix_dir_tars_wds}{name_split}", ( - f"Wrong tar files directory for {model} model: " - f"expected {prefix_dir_tars_wds}{split} " - f"but got {data_module._dirs_tars_wds[split]}" - ) - - -@pytest.mark.parametrize("split", list(Split)) -def test_datamodule_prepare_data(split, create_datamodule): - data_module, _ = create_datamodule - # LightningDataModule.prepare_data() is supposed to be called from the main - # process in a Lightning-managed multi-process context so we can call it in - # a single process - data_module.prepare_data() - files_tars = sorted(glob.glob(f"{data_module._dirs_tars_wds[split]}/" f"{data_module._prefix_tars_wds}-*.tar")) - assert len(files_tars) >= data_module._n_tars_wds, ( - f"Wrong num of {split}-set tar files: " f"expected {data_module._n_tars_wds} " f"got {len(files_tars)}" - ) - - -@pytest.mark.parametrize("split", list(Split)) -def test_datamodule_setup_dataset(split, create_datamodule, create_another_datamodule): - data_modules = [create_datamodule[0], create_another_datamodule[0]] - lists_complex_name = [] - lists_pos_ligand = [] - for m in data_modules: - m.prepare_data() - # run through all the possible stages first to setup all the correps. - # dataset objects - m.setup("fit") - m.setup("test") - lightning.seed_everything(2823828) - names = [] - pos_ligand = [] - for sample in m._dataset[split]: - assert isinstance(sample, HeteroData), "Sample yield from dataset is not PyG HeteroData" - names.append(sample.name) - pos_ligand.append(sample["ligand"].pos) - lists_complex_name.append(names) - lists_pos_ligand.append(pos_ligand) - - assert len(lists_complex_name[0]) > 0, "No names in {split} dataset" - assert lists_complex_name[0] == lists_complex_name[1], ( - f"Inconsistent sample name in {split}-set from data module instances: " - f"{lists_complex_name[0]} \n\nvs.\n\n" - f"{lists_complex_name[1]}" - ) - - assert len(lists_pos_ligand[0]) > 0, "No ligand position found in dataset" - assert len(lists_pos_ligand[0]) == len(lists_pos_ligand[1]), ( - f"Inconsistent number of ligand position in {split}-set from data " - f"module instances: {len(lists_pos_ligand[0])} \n\nvs.\n\n" - f"{len(lists_pos_ligand[1])}" - ) - for i in range(len(lists_pos_ligand[0])): - pos_0 = lists_pos_ligand[0][i] - pos_1 = lists_pos_ligand[1][i] - torch.testing.assert_close( - pos_0, - pos_1, - msg=lambda m: f"Inconsistent ligand position in the " - f"{i}'th sample/batch of {split}-set " - f"between two data module instances:\n\n{m}", - ) - - -@pytest.mark.parametrize("split", list(Split)) -def test_datamodule_setup_dataloader(split, create_datamodule, create_another_datamodule): - data_modules = [create_datamodule[0], create_another_datamodule[0]] - lists_complex_name = [] - lists_pos_ligand = [] - for m in data_modules: - m.prepare_data() - # run through all the possible stages first to setup all the correps. - # dataset objects - m.setup("fit") - m.setup("test") - lightning.seed_everything(2823828) - names = [] - pos_ligand = [] - loader = None - if split == Split.train: - loader = m.train_dataloader() - elif split == Split.val: - loader = m.val_dataloader() - elif split == Split.test: - loader = m.test_dataloader() - else: - raise RuntimeError(f"Test for split {split} not implemented") - assert loader is not None, "dataloader not instantated" - for samples in loader: - # PyG's HeteroDataBatch is Batch inherited from HeteroData - assert isinstance(samples, Batch), "Sample object is not PyG Batch" - assert isinstance(samples, HeteroData), "Sample object is not PyG HeteroData" - names.append(samples.name) - pos_ligand.append(samples["ligand"].pos) - lists_complex_name.append(names) - lists_pos_ligand.append(pos_ligand) - - assert len(lists_complex_name[0]) > 0, "No names in {split} dataloader" - assert lists_complex_name[0] == lists_complex_name[1], ( - f"Inconsistent sample name in {split}-set from data module instances: " - f"{lists_complex_name[0]} \n\nvs.\n\n" - f"{lists_complex_name[1]}" - ) - - assert len(lists_pos_ligand[0]) > 0, "No ligand position found in dataloader" - assert len(lists_pos_ligand[0]) == len(lists_pos_ligand[1]), ( - f"Inconsistent number of ligand position in {split}-set from data " - f"module instances: {len(lists_pos_ligand[0])} \n\nvs.\n\n" - f"{len(lists_pos_ligand[1])}" - ) - for i in range(len(lists_pos_ligand[0])): - pos_0 = lists_pos_ligand[0][i] - pos_1 = lists_pos_ligand[1][i] - torch.testing.assert_close( - pos_0, - pos_1, - msg=lambda m: f"Inconsistent ligand position in the " - f"{i}'th sample/batch of {split}-set " - f"between two data module instances:\n\n{m}", - ) - - -class Stage(Enum): - fit = auto() - validate = auto() - test = auto() - predict = auto() - - -@pytest.mark.parametrize("stage", list(Stage)) -def test_datamodule_in_lightning(stage, create_datamodule, create_another_datamodule, create_trainer_and_model): - data_modules = [create_datamodule[0], create_another_datamodule[0]] - trainer, model = create_trainer_and_model - # get the list of samples from the loader - lightning.seed_everything(2823828) - data_modules[0].prepare_data() - split = None - if stage == Stage.fit: - split = Split.train - elif stage == Stage.validate: - split = Split.val - elif stage == Stage.test or stage == Stage.predict: - split = Split.test - else: - raise RuntimeError(f"{stage} stage not implemented") - name_stage = str(stage).split(".")[-1] - data_modules[0].setup(name_stage) - # get the list of samples from the workflow - get_dataloader = getattr(data_modules[0], f"{str(split).split('.')[-1]}_dataloader") - loader = get_dataloader() - samples = [ sample.name for sample in loader ] - lightning.seed_everything(2823828) - workflow = getattr(trainer, name_stage) - workflow(model, data_modules[1]) - assert model._samples[split] == samples From 57b6c3a2cc2028e38e6c480ca0dedab896b1ebb7 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 20 Aug 2024 00:27:29 +0000 Subject: [PATCH 56/70] BugFix: change license header --- scripts/license_check.py | 4 +- .../src/bionemo/webdatamodule/datamodule.py | 33 ++-- .../src/bionemo/webdatamodule/utils.py | 19 ++- .../tests/bionemo/webdatamodule/__init__.py | 0 .../tests/bionemo/webdatamodule/conftest.py | 155 +++++++++--------- .../webdatamodule/test_webdatamodule.py | 70 ++++---- 6 files changed, 135 insertions(+), 146 deletions(-) delete mode 100644 sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/__init__.py diff --git a/scripts/license_check.py b/scripts/license_check.py index 631285cce6..ea681f5c9e 100644 --- a/scripts/license_check.py +++ b/scripts/license_check.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple +from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import click @@ -66,7 +66,7 @@ def __str__(self) -> str: return f"{self.pyfile.name} does not have the license header!" -LicenseCheckError = IOError | SyntaxError | HeaderNotFound +LicenseCheckError = Union[IOError, SyntaxError, HeaderNotFound] """Errors that can be encountered during the license check process. Specific errors and their underlying causes: diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py index b2af47652d..80b938e969 100644 --- a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py +++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py @@ -1,16 +1,20 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# SPDX-License-Identifier: LicenseRef-Apache2 # -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import glob -import random from enum import Enum, auto from typing import Any, Dict, Iterable, List, Optional, Union, get_args @@ -206,9 +210,7 @@ def __init__( self._global_batch_size = global_batch_size - - if not isinstance(suffix_keys_wds, - get_args(Union[str, Iterable[str]])): + if not isinstance(suffix_keys_wds, get_args(Union[str, Iterable[str]])): raise TypeError("suffix_keys_wds can only be str or Iterable[str]") self._suffix_keys_wds = suffix_keys_wds @@ -244,18 +246,13 @@ def _setup_wds(self, split: Split) -> wds.WebDataset: """ if split not in self._dirs_tars_wds.keys(): raise RuntimeError(f"_setup_wds() is called with {split} " f"split that doesn't have the input tar dir") - is_train = split == Split.train urls = sorted(glob.glob(f"{self._dirs_tars_wds[split]}/{self._prefix_tars_wds}-*.tar")) kwargs = self._kwargs_wds[split] if self._kwargs_wds is not None else None - dataset = ( - wds.WebDataset(urls, **(kwargs if kwargs is not None else {})) - .decode() - ) + dataset = wds.WebDataset(urls, **(kwargs if kwargs is not None else {})).decode() if isinstance(self._suffix_keys_wds, str): dataset = dataset.extract_keys(f"*.{self._suffix_keys_wds}") else: - dataset = dataset.extract_keys(*[f"*.{key}" for key in - self._suffix_keys_wds]) + dataset = dataset.extract_keys(*[f"*.{key}" for key in self._suffix_keys_wds]) if self._pipeline_wds is not None and self._pipeline_wds[split] is not None: if isinstance(self._pipeline_wds[split], Iterable): diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py index f8898c3e02..88d61bb45e 100644 --- a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py +++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py @@ -1,12 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# SPDX-License-Identifier: LicenseRef-Apache2 # -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/__init__.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py index 1045b2b999..35ff363c9d 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py @@ -1,12 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# SPDX-License-Identifier: LicenseRef-Apache2 # -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os @@ -19,7 +24,7 @@ import webdataset as wds from webdataset.filters import batched, shuffle -from bionemo.webdatamodule.datamodule import WebDataModule, Split, PickledDataWDS +from bionemo.webdatamodule.datamodule import PickledDataWDS, Split, WebDataModule from bionemo.webdatamodule.utils import pickles_to_tars @@ -40,60 +45,54 @@ def gen_test_data(tmp_path_factory): t = torch.tensor(i, dtype=torch.int32) pickle.dump(t, open(f"{dir_pickles}/{prefix}.{suffix_sample}", "wb")) # generate the tars - pickles_to_tars(dir_pickles, suffix_sample, prefix_subset, dir_tars, - prefix_tar, min_num_shards=3) - return (dir_pickles, dir_tars, prefix_sample, suffix_sample, prefix_tar, - n_samples) + pickles_to_tars(dir_pickles, suffix_sample, prefix_subset, dir_tars, prefix_tar, min_num_shards=3) + return (dir_pickles, dir_tars, prefix_sample, suffix_sample, prefix_tar, n_samples) def _create_webdatamodule(gen_test_data): - (_, dir_tars_wds, _, suffix_keys_wds, prefix_tars_wds, - n_samples_in_tar) = gen_test_data + (_, dir_tars_wds, _, suffix_keys_wds, prefix_tars_wds, n_samples_in_tar) = gen_test_data local_batch_size = 2 global_batch_size = 2 seed_rng_shfl = 82838392 - dirs_tars_wds = { split : dir_tars_wds for split in Split } + dirs_tars_wds = {split: dir_tars_wds for split in Split} - n_samples = { split : n_samples_in_tar for split in Split } + n_samples = {split: n_samples_in_tar for split in Split} - batch = batched(local_batch_size, collation_fn=lambda - list_samples : torch.vstack(list_samples)) + batch = batched(local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples)) - untuple = lambda source : (sample for (sample,) in source) + untuple = lambda source: (sample for (sample,) in source) pipeline_wds = { - Split.train : [untuple, shuffle(n_samples[Split.train], - rng=random.Random(seed_rng_shfl))], - Split.val : untuple, - Split.test : untuple - } + Split.train: [untuple, shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl))], + Split.val: untuple, + Split.test: untuple, + } pipeline_prebatch_wld = { - Split.train: [shuffle(n_samples[Split.train], - rng=random.Random(seed_rng_shfl)), batch], - Split.val : batch, - Split.test : batch - } + Split.train: [shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)), batch], + Split.val: batch, + Split.test: batch, + } kwargs_wds = { - split : {'shardshuffle' : split == Split.train, - 'nodesplitter' : wds.split_by_node, - 'seed' : seed_rng_shfl} + split: {"shardshuffle": split == Split.train, "nodesplitter": wds.split_by_node, "seed": seed_rng_shfl} for split in Split - } - - kwargs_wld = { - split : {"num_workers": 2} for split in Split - } - - data_module = WebDataModule(dirs_tars_wds, n_samples, suffix_keys_wds, - global_batch_size, - prefix_tars_wds=prefix_tars_wds, - pipeline_wds=pipeline_wds, - pipeline_prebatch_wld=pipeline_prebatch_wld, - kwargs_wds=kwargs_wds, - kwargs_wld=kwargs_wld) + } + + kwargs_wld = {split: {"num_workers": 2} for split in Split} + + data_module = WebDataModule( + dirs_tars_wds, + n_samples, + suffix_keys_wds, + global_batch_size, + prefix_tars_wds=prefix_tars_wds, + pipeline_wds=pipeline_wds, + pipeline_prebatch_wld=pipeline_prebatch_wld, + kwargs_wds=kwargs_wds, + kwargs_wld=kwargs_wld, + ) return data_module, dir_tars_wds @@ -146,8 +145,7 @@ def create_trainer_and_model(): def _create_pickleddatawds(tmp_path_factory, gen_test_data): - (dir_pickles, _, prefix_sample, suffix_keys_wds, prefix_tars_wds, - n_samples_in_tar) = gen_test_data + (dir_pickles, _, prefix_sample, suffix_keys_wds, prefix_tars_wds, n_samples_in_tar) = gen_test_data local_batch_size = 2 global_batch_size = 2 seed_rng_shfl = 82838392 @@ -155,49 +153,46 @@ def _create_pickleddatawds(tmp_path_factory, gen_test_data): prefix_dir_tars_wds = tmp_path_factory.mktemp("pickleddatawds_tars_wds").as_posix() - names = { split : [ f"{prefix_sample}-{i:04d}" for i in - range(n_samples_in_tar) ] for split in Split - } + names = {split: [f"{prefix_sample}-{i:04d}" for i in range(n_samples_in_tar)] for split in Split} - n_samples = { split : n_samples_in_tar for split in Split } + n_samples = {split: n_samples_in_tar for split in Split} - batch = batched(local_batch_size, collation_fn=lambda - list_samples : torch.vstack(list_samples)) + batch = batched(local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples)) - untuple = lambda source : (sample for (sample,) in source) + untuple = lambda source: (sample for (sample,) in source) pipeline_wds = { - Split.train : [untuple, shuffle(n_samples[Split.train], - rng=random.Random(seed_rng_shfl))], - Split.val : untuple, - Split.test : untuple - } + Split.train: [untuple, shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl))], + Split.val: untuple, + Split.test: untuple, + } pipeline_prebatch_wld = { - Split.train: [shuffle(n_samples[Split.train], - rng=random.Random(seed_rng_shfl)), batch], - Split.val : batch, - Split.test : batch - } + Split.train: [shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)), batch], + Split.val: batch, + Split.test: batch, + } kwargs_wds = { - split : {'shardshuffle' : split == Split.train, - 'nodesplitter' : wds.split_by_node, - 'seed' : seed_rng_shfl} + split: {"shardshuffle": split == Split.train, "nodesplitter": wds.split_by_node, "seed": seed_rng_shfl} for split in Split - } - - kwargs_wld = { - split : {"num_workers": 2} for split in Split - } - - data_module = PickledDataWDS(dir_pickles, suffix_keys_wds, names, - prefix_dir_tars_wds, global_batch_size, - n_tars_wds=n_tars_wds, - prefix_tars_wds=prefix_tars_wds, - pipeline_wds=pipeline_wds, - pipeline_prebatch_wld=pipeline_prebatch_wld, - kwargs_wds=kwargs_wds, kwargs_wld=kwargs_wld) + } + + kwargs_wld = {split: {"num_workers": 2} for split in Split} + + data_module = PickledDataWDS( + dir_pickles, + suffix_keys_wds, + names, + prefix_dir_tars_wds, + global_batch_size, + n_tars_wds=n_tars_wds, + prefix_tars_wds=prefix_tars_wds, + pipeline_wds=pipeline_wds, + pipeline_prebatch_wld=pipeline_prebatch_wld, + kwargs_wds=kwargs_wds, + kwargs_wld=kwargs_wld, + ) return data_module, prefix_dir_tars_wds diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py index 0093810808..bc82cc3202 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py @@ -1,20 +1,24 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# SPDX-License-Identifier: LicenseRef-Apache2 # -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from enum import Enum, auto +import lightning as L import pytest - import torch -import lightning as L from bionemo.webdatamodule.datamodule import Split @@ -23,9 +27,7 @@ def test_webdatamodule_init(split, create_webdatamodule): data_module, prefix_dir_tars_wds = create_webdatamodule assert data_module._n_samples[split] == 10, ( - f"Wrong {split}-set size: " - f"expected 10 " - f"but got {data_module._n_samples[split]}" + f"Wrong {split}-set size: " f"expected 10 " f"but got {data_module._n_samples[split]}" ) assert data_module._dirs_tars_wds[split] == f"{prefix_dir_tars_wds}", ( f"Wrong tar files directory: " @@ -35,8 +37,7 @@ def test_webdatamodule_init(split, create_webdatamodule): @pytest.mark.parametrize("split", list(Split)) -def test_webdatamodule_setup_dataset(split, create_webdatamodule, - create_another_webdatamodule): +def test_webdatamodule_setup_dataset(split, create_webdatamodule, create_another_webdatamodule): data_modules = [create_webdatamodule[0], create_another_webdatamodule[0]] lists_tensors = [] for m in data_modules: @@ -46,21 +47,18 @@ def test_webdatamodule_setup_dataset(split, create_webdatamodule, m.setup("fit") m.setup("test") L.seed_everything(2823828) - tensors= [] + tensors = [] for sample in m._dataset[split]: - assert isinstance(sample, torch.Tensor),\ - "Sample yield from dataset is not tensor" + assert isinstance(sample, torch.Tensor), "Sample yield from dataset is not tensor" tensors.append(sample) lists_tensors.append(tensors) assert len(lists_tensors[0]) > 0, "No names in {split} dataset" - torch.testing.assert_close(torch.vstack(lists_tensors[0]), - torch.vstack(lists_tensors[1])) + torch.testing.assert_close(torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1])) @pytest.mark.parametrize("split", list(Split)) -def test_webdatamodule_setup_dataloader(split, create_webdatamodule, - create_another_webdatamodule): +def test_webdatamodule_setup_dataloader(split, create_webdatamodule, create_another_webdatamodule): data_modules = [create_webdatamodule[0], create_another_webdatamodule[0]] lists_tensors = [] for m in data_modules: @@ -83,14 +81,12 @@ def test_webdatamodule_setup_dataloader(split, create_webdatamodule, assert loader is not None, "dataloader not instantated" for samples in loader: # PyG's HeteroDataBatch is Batch inherited from HeteroData - assert isinstance(samples, torch.Tensor),\ - "Sample object is not torch.Tensor" + assert isinstance(samples, torch.Tensor), "Sample object is not torch.Tensor" tensors.append(samples) lists_tensors.append(tensors) assert len(lists_tensors[0]) > 0, "No names in {split} dataloader" - torch.testing.assert_close(torch.vstack(lists_tensors[0]), - torch.vstack(lists_tensors[1])) + torch.testing.assert_close(torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1])) class Stage(Enum): @@ -101,9 +97,9 @@ class Stage(Enum): @pytest.mark.parametrize("stage", list(Stage)) -def test_webdatamodule_in_lightning(stage, create_webdatamodule, - create_another_webdatamodule, - create_trainer_and_model): +def test_webdatamodule_in_lightning( + stage, create_webdatamodule, create_another_webdatamodule, create_trainer_and_model +): data_modules = [create_webdatamodule[0], create_another_webdatamodule[0]] trainer, model = create_trainer_and_model # get the list of samples from the loader @@ -123,7 +119,7 @@ def test_webdatamodule_in_lightning(stage, create_webdatamodule, # get the list of samples from the workflow get_dataloader = getattr(data_modules[0], f"{str(split).split('.')[-1]}_dataloader") loader = get_dataloader() - samples = [ sample.name for sample in loader ] + samples = [sample.name for sample in loader] L.seed_everything(2823828) workflow = getattr(trainer, name_stage) workflow(model, data_modules[1]) @@ -134,9 +130,7 @@ def test_webdatamodule_in_lightning(stage, create_webdatamodule, def test_pickleddatawds_init(split, create_pickleddatawds): data_module, prefix_dir_tars_wds = create_pickleddatawds assert data_module._n_samples[split] == 10, ( - f"Wrong {split}-set size: " - f"expected 10 " - f"but got {data_module._n_samples[split]}" + f"Wrong {split}-set size: " f"expected 10 " f"but got {data_module._n_samples[split]}" ) name_split = str(split).split(".")[-1] assert data_module._dirs_tars_wds[split] == f"{prefix_dir_tars_wds}{name_split}", ( @@ -145,9 +139,9 @@ def test_pickleddatawds_init(split, create_pickleddatawds): f"but got {data_module._dirs_tars_wds[split]}" ) + @pytest.mark.parametrize("split", list(Split)) -def test_pickleddatawds_setup_dataset(split, create_pickleddatawds, - create_another_pickleddatawds): +def test_pickleddatawds_setup_dataset(split, create_pickleddatawds, create_another_pickleddatawds): data_modules = [create_pickleddatawds[0], create_another_pickleddatawds[0]] lists_tensors = [] for m in data_modules: @@ -157,13 +151,11 @@ def test_pickleddatawds_setup_dataset(split, create_pickleddatawds, m.setup("fit") m.setup("test") L.seed_everything(2823828) - tensors= [] + tensors = [] for sample in m._dataset[split]: - assert isinstance(sample, torch.Tensor),\ - "Sample yield from dataset is not tensor" + assert isinstance(sample, torch.Tensor), "Sample yield from dataset is not tensor" tensors.append(sample) lists_tensors.append(tensors) assert len(lists_tensors[0]) > 0, "No names in {split} dataset" - torch.testing.assert_close(torch.vstack(lists_tensors[0]), - torch.vstack(lists_tensors[1])) + torch.testing.assert_close(torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1])) From 009fa232ea5173749cb7bafed2a37d4ce78048d4 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 20 Aug 2024 00:31:28 +0000 Subject: [PATCH 57/70] BugFix: ignore E731 and fix format --- sub-packages/bionemo-webdatamodule/README.md | 3 +- .../bionemo-webdatamodule/pyproject.toml | 3 + .../src/bionemo/webdatamodule/datamodule.py | 50 +++++++++---- .../src/bionemo/webdatamodule/utils.py | 20 ++++-- .../tests/bionemo/webdatamodule/conftest.py | 71 +++++++++++++++---- .../webdatamodule/test_webdatamodule.py | 44 +++++++++--- 6 files changed, 148 insertions(+), 43 deletions(-) diff --git a/sub-packages/bionemo-webdatamodule/README.md b/sub-packages/bionemo-webdatamodule/README.md index 2c0cdff631..b06442c66a 100644 --- a/sub-packages/bionemo-webdatamodule/README.md +++ b/sub-packages/bionemo-webdatamodule/README.md @@ -328,7 +328,7 @@ constructor and test sets will be suffixed with "train", "val" and "test" respectively. - `*args` - arguments passed to the parent WebDataModule - + Kwargs: - `n_tars_wds` _int_ - attempt to create at least this number of webdataset shards @@ -351,4 +351,3 @@ webdataset tar archive with the dict structure: {"__key__" : name.replace(".", "-"), suffix_pickles : pickled.dumps(data) }. Returns: None - diff --git a/sub-packages/bionemo-webdatamodule/pyproject.toml b/sub-packages/bionemo-webdatamodule/pyproject.toml index 1f3be3632d..30a44da5c5 100644 --- a/sub-packages/bionemo-webdatamodule/pyproject.toml +++ b/sub-packages/bionemo-webdatamodule/pyproject.toml @@ -27,3 +27,6 @@ test = [ [tool.setuptools.dynamic] dependencies = {file = ["requirements.txt"]} + +[tool.ruff] +lint.ignore = ["C901", "E741", "E501", "E731"] diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py index 80b938e969..73e5cc6e40 100644 --- a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py +++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py @@ -144,8 +144,12 @@ def __init__( suffix_keys_wds: Union[str, Iterable[str]], global_batch_size: int, prefix_tars_wds: str = "wdshards", - pipeline_wds: Optional[Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]]] = None, - pipeline_prebatch_wld: Optional[Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]]] = None, + pipeline_wds: Optional[ + Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]] + ] = None, + pipeline_prebatch_wld: Optional[ + Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]] + ] = None, kwargs_wds: Optional[Dict[Split, Dict[str, Any]]] = None, kwargs_wld: Optional[Dict[Split, Dict[str, Any]]] = None, ): @@ -203,7 +207,9 @@ def __init__( if n_samples.keys() != keys_subset: raise RuntimeError( - f"Input n_samples has different keys than " f"dirs_tars_wds: {n_samples.keys()} vs " f"{keys_subset}" + f"Input n_samples has different keys than " + f"dirs_tars_wds: {n_samples.keys()} vs " + f"{keys_subset}" ) self._n_samples = n_samples @@ -245,14 +251,23 @@ def _setup_wds(self, split: Split) -> wds.WebDataset: """ if split not in self._dirs_tars_wds.keys(): - raise RuntimeError(f"_setup_wds() is called with {split} " f"split that doesn't have the input tar dir") - urls = sorted(glob.glob(f"{self._dirs_tars_wds[split]}/{self._prefix_tars_wds}-*.tar")) + raise RuntimeError( + f"_setup_wds() is called with {split} " + f"split that doesn't have the input tar dir" + ) + urls = sorted( + glob.glob(f"{self._dirs_tars_wds[split]}/{self._prefix_tars_wds}-*.tar") + ) kwargs = self._kwargs_wds[split] if self._kwargs_wds is not None else None - dataset = wds.WebDataset(urls, **(kwargs if kwargs is not None else {})).decode() + dataset = wds.WebDataset( + urls, **(kwargs if kwargs is not None else {}) + ).decode() if isinstance(self._suffix_keys_wds, str): dataset = dataset.extract_keys(f"*.{self._suffix_keys_wds}") else: - dataset = dataset.extract_keys(*[f"*.{key}" for key in self._suffix_keys_wds]) + dataset = dataset.extract_keys( + *[f"*.{key}" for key in self._suffix_keys_wds] + ) if self._pipeline_wds is not None and self._pipeline_wds[split] is not None: if isinstance(self._pipeline_wds[split], Iterable): @@ -280,7 +295,9 @@ def setup(self, stage: str) -> None: elif stage == "predict": self._dataset[Split.test] = self._setup_wds(Split.test) else: - raise NotImplementedError(f"Data setup with stage = {stage} " f"is not implmented") + raise NotImplementedError( + f"Data setup with stage = {stage} " f"is not implmented" + ) def _setup_dataloader(self, split: Split) -> wds.WebLoader: """setup the dataloader for the input dataset split @@ -293,15 +310,21 @@ def _setup_dataloader(self, split: Split) -> wds.WebLoader: """ if self._dataset[split] is None: raise RuntimeError( - f"_setup_dataloader() is called with {split} " f"split without setting up the corresp. dataset" + f"_setup_dataloader() is called with {split} " + f"split without setting up the corresp. dataset" ) dataset = self._dataset[split] n_samples = self._n_samples[split] n_batches = (n_samples + self._global_batch_size - 1) // self._global_batch_size kwargs = self._kwargs_wld[split] if self._kwargs_wld is not None else None - loader = wds.WebLoader(dataset, batch_size=None, **(kwargs if kwargs is not None else {})) + loader = wds.WebLoader( + dataset, batch_size=None, **(kwargs if kwargs is not None else {}) + ) - if self._pipeline_prebatch_wld is not None and self._pipeline_prebatch_wld[split] is not None: + if ( + self._pipeline_prebatch_wld is not None + and self._pipeline_prebatch_wld[split] is not None + ): if isinstance(self._pipeline_prebatch_wld[split], Iterable): loader = loader.compose(*self._pipeline_prebatch_wld[split]) else: @@ -437,7 +460,10 @@ def __init__( """ super().__init__( - {split: f"{prefix_dir_tars_wds}{str(split).split('.')[-1]}" for split in names_subset.keys()}, + { + split: f"{prefix_dir_tars_wds}{str(split).split('.')[-1]}" + for split in names_subset.keys() + }, {split: len(names_subset[split]) for split in names_subset.keys()}, suffix_pickles, *args, diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py index 88d61bb45e..05d3ef0ca2 100644 --- a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py +++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py @@ -28,7 +28,9 @@ def pickles_to_tars( input_prefix_subset: List[str], dir_output: str, output_prefix: str, - func_output_data: Callable[[str, str, Any], Dict[str, Any]] = lambda prefix, suffix, data: { + func_output_data: Callable[[str, str, Any], Dict[str, Any]] = lambda prefix, + suffix, + data: { "__key__": prefix, suffix: pickle.dumps(data), }, @@ -88,19 +90,27 @@ def pickles_to_tars( total_size = 0 for name in input_prefix_subset: try: - total_size += os.stat(os.path.join(dir_input, f"{name}.{input_suffix}")).st_size + total_size += os.stat( + os.path.join(dir_input, f"{name}.{input_suffix}") + ).st_size except Exception: continue maxsize = min(total_size * 0.6 // min_num_shards, maxsize) - with wds.ShardWriter(wd_subset_pattern, encoder=False, maxsize=maxsize, compress=False, mode=0o777) as sink: + with wds.ShardWriter( + wd_subset_pattern, encoder=False, maxsize=maxsize, compress=False, mode=0o777 + ) as sink: for name in input_prefix_subset: try: - data = pickle.load(open(os.path.join(dir_input, f"{name}.{input_suffix}"), "rb")) + data = pickle.load( + open(os.path.join(dir_input, f"{name}.{input_suffix}"), "rb") + ) # the prefix name shouldn't contain any "." per webdataset's # specification sample = func_output_data(name.replace(".", "-"), input_suffix, data) except ModuleNotFoundError as e: - logging.error(f"Dependency for parsing input pickle data not " f"found: {e}") + logging.error( + f"Dependency for parsing input pickle data not " f"found: {e}" + ) raise e except Exception as e: logging.error(f"Failed to write {name} into tar files due to error {e}") diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py index 35ff363c9d..e2e80eea3b 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py @@ -45,12 +45,21 @@ def gen_test_data(tmp_path_factory): t = torch.tensor(i, dtype=torch.int32) pickle.dump(t, open(f"{dir_pickles}/{prefix}.{suffix_sample}", "wb")) # generate the tars - pickles_to_tars(dir_pickles, suffix_sample, prefix_subset, dir_tars, prefix_tar, min_num_shards=3) + pickles_to_tars( + dir_pickles, + suffix_sample, + prefix_subset, + dir_tars, + prefix_tar, + min_num_shards=3, + ) return (dir_pickles, dir_tars, prefix_sample, suffix_sample, prefix_tar, n_samples) def _create_webdatamodule(gen_test_data): - (_, dir_tars_wds, _, suffix_keys_wds, prefix_tars_wds, n_samples_in_tar) = gen_test_data + (_, dir_tars_wds, _, suffix_keys_wds, prefix_tars_wds, n_samples_in_tar) = ( + gen_test_data + ) local_batch_size = 2 global_batch_size = 2 seed_rng_shfl = 82838392 @@ -59,24 +68,36 @@ def _create_webdatamodule(gen_test_data): n_samples = {split: n_samples_in_tar for split in Split} - batch = batched(local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples)) + batch = batched( + local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples) + ) untuple = lambda source: (sample for (sample,) in source) pipeline_wds = { - Split.train: [untuple, shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl))], + Split.train: [ + untuple, + shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)), + ], Split.val: untuple, Split.test: untuple, } pipeline_prebatch_wld = { - Split.train: [shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)), batch], + Split.train: [ + shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)), + batch, + ], Split.val: batch, Split.test: batch, } kwargs_wds = { - split: {"shardshuffle": split == Split.train, "nodesplitter": wds.split_by_node, "seed": seed_rng_shfl} + split: { + "shardshuffle": split == Split.train, + "nodesplitter": wds.split_by_node, + "seed": seed_rng_shfl, + } for split in Split } @@ -139,13 +160,22 @@ def configure_optimizers(self): @pytest.fixture(scope="function") def create_trainer_and_model(): - trainer = L.Trainer(max_epochs=1, accelerator="gpu", devices=1, val_check_interval=1) + trainer = L.Trainer( + max_epochs=1, accelerator="gpu", devices=1, val_check_interval=1 + ) model = ModelTestWebDataModule() return trainer, model def _create_pickleddatawds(tmp_path_factory, gen_test_data): - (dir_pickles, _, prefix_sample, suffix_keys_wds, prefix_tars_wds, n_samples_in_tar) = gen_test_data + ( + dir_pickles, + _, + prefix_sample, + suffix_keys_wds, + prefix_tars_wds, + n_samples_in_tar, + ) = gen_test_data local_batch_size = 2 global_batch_size = 2 seed_rng_shfl = 82838392 @@ -153,28 +183,43 @@ def _create_pickleddatawds(tmp_path_factory, gen_test_data): prefix_dir_tars_wds = tmp_path_factory.mktemp("pickleddatawds_tars_wds").as_posix() - names = {split: [f"{prefix_sample}-{i:04d}" for i in range(n_samples_in_tar)] for split in Split} + names = { + split: [f"{prefix_sample}-{i:04d}" for i in range(n_samples_in_tar)] + for split in Split + } n_samples = {split: n_samples_in_tar for split in Split} - batch = batched(local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples)) + batch = batched( + local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples) + ) untuple = lambda source: (sample for (sample,) in source) pipeline_wds = { - Split.train: [untuple, shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl))], + Split.train: [ + untuple, + shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)), + ], Split.val: untuple, Split.test: untuple, } pipeline_prebatch_wld = { - Split.train: [shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)), batch], + Split.train: [ + shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)), + batch, + ], Split.val: batch, Split.test: batch, } kwargs_wds = { - split: {"shardshuffle": split == Split.train, "nodesplitter": wds.split_by_node, "seed": seed_rng_shfl} + split: { + "shardshuffle": split == Split.train, + "nodesplitter": wds.split_by_node, + "seed": seed_rng_shfl, + } for split in Split } diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py index bc82cc3202..1977196316 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py @@ -27,7 +27,9 @@ def test_webdatamodule_init(split, create_webdatamodule): data_module, prefix_dir_tars_wds = create_webdatamodule assert data_module._n_samples[split] == 10, ( - f"Wrong {split}-set size: " f"expected 10 " f"but got {data_module._n_samples[split]}" + f"Wrong {split}-set size: " + f"expected 10 " + f"but got {data_module._n_samples[split]}" ) assert data_module._dirs_tars_wds[split] == f"{prefix_dir_tars_wds}", ( f"Wrong tar files directory: " @@ -37,7 +39,9 @@ def test_webdatamodule_init(split, create_webdatamodule): @pytest.mark.parametrize("split", list(Split)) -def test_webdatamodule_setup_dataset(split, create_webdatamodule, create_another_webdatamodule): +def test_webdatamodule_setup_dataset( + split, create_webdatamodule, create_another_webdatamodule +): data_modules = [create_webdatamodule[0], create_another_webdatamodule[0]] lists_tensors = [] for m in data_modules: @@ -49,16 +53,22 @@ def test_webdatamodule_setup_dataset(split, create_webdatamodule, create_another L.seed_everything(2823828) tensors = [] for sample in m._dataset[split]: - assert isinstance(sample, torch.Tensor), "Sample yield from dataset is not tensor" + assert isinstance( + sample, torch.Tensor + ), "Sample yield from dataset is not tensor" tensors.append(sample) lists_tensors.append(tensors) assert len(lists_tensors[0]) > 0, "No names in {split} dataset" - torch.testing.assert_close(torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1])) + torch.testing.assert_close( + torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1]) + ) @pytest.mark.parametrize("split", list(Split)) -def test_webdatamodule_setup_dataloader(split, create_webdatamodule, create_another_webdatamodule): +def test_webdatamodule_setup_dataloader( + split, create_webdatamodule, create_another_webdatamodule +): data_modules = [create_webdatamodule[0], create_another_webdatamodule[0]] lists_tensors = [] for m in data_modules: @@ -81,12 +91,16 @@ def test_webdatamodule_setup_dataloader(split, create_webdatamodule, create_anot assert loader is not None, "dataloader not instantated" for samples in loader: # PyG's HeteroDataBatch is Batch inherited from HeteroData - assert isinstance(samples, torch.Tensor), "Sample object is not torch.Tensor" + assert isinstance( + samples, torch.Tensor + ), "Sample object is not torch.Tensor" tensors.append(samples) lists_tensors.append(tensors) assert len(lists_tensors[0]) > 0, "No names in {split} dataloader" - torch.testing.assert_close(torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1])) + torch.testing.assert_close( + torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1]) + ) class Stage(Enum): @@ -130,7 +144,9 @@ def test_webdatamodule_in_lightning( def test_pickleddatawds_init(split, create_pickleddatawds): data_module, prefix_dir_tars_wds = create_pickleddatawds assert data_module._n_samples[split] == 10, ( - f"Wrong {split}-set size: " f"expected 10 " f"but got {data_module._n_samples[split]}" + f"Wrong {split}-set size: " + f"expected 10 " + f"but got {data_module._n_samples[split]}" ) name_split = str(split).split(".")[-1] assert data_module._dirs_tars_wds[split] == f"{prefix_dir_tars_wds}{name_split}", ( @@ -141,7 +157,9 @@ def test_pickleddatawds_init(split, create_pickleddatawds): @pytest.mark.parametrize("split", list(Split)) -def test_pickleddatawds_setup_dataset(split, create_pickleddatawds, create_another_pickleddatawds): +def test_pickleddatawds_setup_dataset( + split, create_pickleddatawds, create_another_pickleddatawds +): data_modules = [create_pickleddatawds[0], create_another_pickleddatawds[0]] lists_tensors = [] for m in data_modules: @@ -153,9 +171,13 @@ def test_pickleddatawds_setup_dataset(split, create_pickleddatawds, create_anoth L.seed_everything(2823828) tensors = [] for sample in m._dataset[split]: - assert isinstance(sample, torch.Tensor), "Sample yield from dataset is not tensor" + assert isinstance( + sample, torch.Tensor + ), "Sample yield from dataset is not tensor" tensors.append(sample) lists_tensors.append(tensors) assert len(lists_tensors[0]) > 0, "No names in {split} dataset" - torch.testing.assert_close(torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1])) + torch.testing.assert_close( + torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1]) + ) From 5d01a19fd050c1741e22b8033656628dc3e9c6dd Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 20 Aug 2024 19:37:20 +0000 Subject: [PATCH 58/70] Regress: use PEP604 --- scripts/license_check.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/license_check.py b/scripts/license_check.py index ea681f5c9e..631285cce6 100644 --- a/scripts/license_check.py +++ b/scripts/license_check.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple import click @@ -66,7 +66,7 @@ def __str__(self) -> str: return f"{self.pyfile.name} does not have the license header!" -LicenseCheckError = Union[IOError, SyntaxError, HeaderNotFound] +LicenseCheckError = IOError | SyntaxError | HeaderNotFound """Errors that can be encountered during the license check process. Specific errors and their underlying causes: From 6bcccb6ba1c20dae6eebf698f1ea08e6cb538695 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 20 Aug 2024 23:11:53 +0000 Subject: [PATCH 59/70] Enhancement: pickle within fh context --- .../src/bionemo/webdatamodule/utils.py | 7 ++++--- .../tests/bionemo/webdatamodule/conftest.py | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py index 05d3ef0ca2..24dee9e42c 100644 --- a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py +++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py @@ -101,9 +101,10 @@ def pickles_to_tars( ) as sink: for name in input_prefix_subset: try: - data = pickle.load( - open(os.path.join(dir_input, f"{name}.{input_suffix}"), "rb") - ) + with open( + os.path.join(dir_input, f"{name}.{input_suffix}"), "rb" + ) as fh: + data = pickle.load(fh) # the prefix name shouldn't contain any "." per webdataset's # specification sample = func_output_data(name.replace(".", "-"), input_suffix, data) diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py index e2e80eea3b..e63cd36e80 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py @@ -43,7 +43,8 @@ def gen_test_data(tmp_path_factory): prefix = f"{prefix_sample}-{i:04}" prefix_subset.append(prefix) t = torch.tensor(i, dtype=torch.int32) - pickle.dump(t, open(f"{dir_pickles}/{prefix}.{suffix_sample}", "wb")) + with open(f"{dir_pickles}/{prefix}.{suffix_sample}", "wb") as fh: + pickle.dump(t, fh) # generate the tars pickles_to_tars( dir_pickles, From 168bc9bb11df72a37b458e8449b57f1678d996f3 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 20 Aug 2024 23:15:33 +0000 Subject: [PATCH 60/70] Bugfix: default to python 3.10 for consistency --- .pre-commit-config.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 80dce3eda6..06f0e784b4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,5 @@ +default_language_version: + python: python3.10 repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v2.3.0 From 4c611ba9ada662adbb1d8337fc6586433f19500b Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Tue, 20 Aug 2024 23:21:19 +0000 Subject: [PATCH 61/70] BugFix: f-string typo --- .../bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py index 24dee9e42c..79ac020253 100644 --- a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py +++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py @@ -110,7 +110,7 @@ def pickles_to_tars( sample = func_output_data(name.replace(".", "-"), input_suffix, data) except ModuleNotFoundError as e: logging.error( - f"Dependency for parsing input pickle data not " f"found: {e}" + f"Dependency for parsing input pickle data not found: {e}" ) raise e except Exception as e: From 86f97e34c189153b3d12db01b78073c32b07905a Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Fri, 23 Aug 2024 18:09:22 +0000 Subject: [PATCH 62/70] Enhancement: use sample counts instead of size to enforce number of tar created ... and add a test for it --- .../src/bionemo/webdatamodule/utils.py | 28 +++++++------------ .../tests/bionemo/webdatamodule/conftest.py | 2 +- .../webdatamodule/test_webdatamodule.py | 17 +++++++++-- 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py index 79ac020253..f4c5e50175 100644 --- a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py +++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py @@ -79,25 +79,17 @@ def pickles_to_tars( """ os.makedirs(dir_output, exist_ok=True) wd_subset_pattern = os.path.join(dir_output, f"{output_prefix}-%06d.tar") - maxsize = 1e8 - # Due to a Webdataset bug, number of shards should be >= number of workers - # (num. of gpus * num. of workers per gpu) - # TODO: this algorithm is not accurate enough because it doesn't take into - # account the block structure so I have to multiply the total_size with a - # small prefactor to purposely underestimate the size so that it ends up - # creating more tar files than min_num_shards - if min_num_shards is not None and min_num_shards > 1: - total_size = 0 - for name in input_prefix_subset: - try: - total_size += os.stat( - os.path.join(dir_input, f"{name}.{input_suffix}") - ).st_size - except Exception: - continue - maxsize = min(total_size * 0.6 // min_num_shards, maxsize) + n_samples_per_shard_max = 100000 + if min_num_shards is not None: + if min_num_shards <= 0: + raise ValueError(f"Invalid min_num_shards = {min_num_shards} <= 0") + n_samples_per_shard_max = len(input_prefix_subset) // min_num_shards with wds.ShardWriter( - wd_subset_pattern, encoder=False, maxsize=maxsize, compress=False, mode=0o777 + wd_subset_pattern, + encoder=False, + maxcount=n_samples_per_shard_max, + compress=False, + mode=0o777, ) as sink: for name in input_prefix_subset: try: diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py index e63cd36e80..e5e58eb095 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py @@ -240,7 +240,7 @@ def _create_pickleddatawds(tmp_path_factory, gen_test_data): kwargs_wld=kwargs_wld, ) - return data_module, prefix_dir_tars_wds + return data_module, prefix_dir_tars_wds, n_tars_wds @pytest.fixture(scope="module") diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py index 1977196316..c20c9345e4 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import glob from enum import Enum, auto import lightning as L @@ -142,7 +142,7 @@ def test_webdatamodule_in_lightning( @pytest.mark.parametrize("split", list(Split)) def test_pickleddatawds_init(split, create_pickleddatawds): - data_module, prefix_dir_tars_wds = create_pickleddatawds + data_module, prefix_dir_tars_wds, _ = create_pickleddatawds assert data_module._n_samples[split] == 10, ( f"Wrong {split}-set size: " f"expected 10 " @@ -156,6 +156,19 @@ def test_pickleddatawds_init(split, create_pickleddatawds): ) +@pytest.mark.parametrize("split", list(Split)) +def test_pickleddatawds_prepare_data(split, create_pickleddatawds): + data_module, _, n_tars_min = create_pickleddatawds + data_module.prepare_data() + dir_tars = f"{data_module._dirs_tars_wds[split]}" + tars = glob.glob(f"{dir_tars}/{data_module._prefix_tars_wds}-*.tar") + n_tars = len(tars) + assert n_tars_min <= n_tars and n_tars <= n_tars_min + 1, ( + f"Number of tar files: {n_tars} in {dir_tars} is outside the range " + f"[{n_tars_min}, {n_tars_min + 1}]" + ) + + @pytest.mark.parametrize("split", list(Split)) def test_pickleddatawds_setup_dataset( split, create_pickleddatawds, create_another_pickleddatawds From 553258dad957e8131368b0c877c0b819337e9a29 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Fri, 23 Aug 2024 22:19:14 +0000 Subject: [PATCH 63/70] Test: sample overlap tests via md5 hash of pickles ... per the requirement of the FW v2 alpha release --- .../tests/bionemo/webdatamodule/conftest.py | 69 ++++++++++++------- .../webdatamodule/test_webdatamodule.py | 31 ++++++++- 2 files changed, 73 insertions(+), 27 deletions(-) diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py index e5e58eb095..8eb742a5c2 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py @@ -31,44 +31,71 @@ @pytest.fixture(scope="module") def gen_test_data(tmp_path_factory): dir_pickles = tmp_path_factory.mktemp("pickleddatawds").as_posix() - dir_tars = tmp_path_factory.mktemp("webdatamodule").as_posix() + dir_tars_tmp = tmp_path_factory.mktemp("webdatamodule").as_posix() + dir_tars = {split: f"{dir_tars_tmp}{str(split).split('.')[-1]}" for split in Split} prefix_sample = "sample" suffix_sample = "tensor.pyd" prefix_tar = "tensor" - n_samples = 10 + n_samples_per_split = 10 + n_samples = {split: n_samples_per_split for split in Split} os.makedirs(dir_pickles, exist_ok=True) - prefix_subset = [] - # generate the pickles - for i in range(n_samples): + prefixes = [] + # generate the pickles for train, val, and test + for i in range(n_samples_per_split * 3): prefix = f"{prefix_sample}-{i:04}" - prefix_subset.append(prefix) + prefixes.append(prefix) t = torch.tensor(i, dtype=torch.int32) with open(f"{dir_pickles}/{prefix}.{suffix_sample}", "wb") as fh: pickle.dump(t, fh) + prefixes_pickle = { + Split.train: prefixes[0:n_samples_per_split], + Split.val: prefixes[n_samples_per_split : n_samples_per_split * 2], + Split.test: prefixes[n_samples_per_split * 2 : n_samples_per_split * 3], + } # generate the tars pickles_to_tars( dir_pickles, suffix_sample, - prefix_subset, - dir_tars, + prefixes_pickle[Split.train], + dir_tars[Split.train], + prefix_tar, + min_num_shards=3, + ) + pickles_to_tars( + dir_pickles, + suffix_sample, + prefixes_pickle[Split.val], + dir_tars[Split.val], prefix_tar, min_num_shards=3, ) - return (dir_pickles, dir_tars, prefix_sample, suffix_sample, prefix_tar, n_samples) + pickles_to_tars( + dir_pickles, + suffix_sample, + prefixes_pickle[Split.test], + dir_tars[Split.test], + prefix_tar, + min_num_shards=3, + ) + return ( + dir_pickles, + dir_tars, + prefix_sample, + suffix_sample, + prefix_tar, + n_samples, + prefixes_pickle, + ) def _create_webdatamodule(gen_test_data): - (_, dir_tars_wds, _, suffix_keys_wds, prefix_tars_wds, n_samples_in_tar) = ( + (_, dirs_tars_wds, _, suffix_keys_wds, prefix_tars_wds, n_samples, _) = ( gen_test_data ) local_batch_size = 2 global_batch_size = 2 seed_rng_shfl = 82838392 - dirs_tars_wds = {split: dir_tars_wds for split in Split} - - n_samples = {split: n_samples_in_tar for split in Split} - batch = batched( local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples) ) @@ -116,7 +143,7 @@ def _create_webdatamodule(gen_test_data): kwargs_wld=kwargs_wld, ) - return data_module, dir_tars_wds + return data_module, dirs_tars_wds @pytest.fixture(scope="module") @@ -172,10 +199,11 @@ def _create_pickleddatawds(tmp_path_factory, gen_test_data): ( dir_pickles, _, - prefix_sample, + _, suffix_keys_wds, prefix_tars_wds, - n_samples_in_tar, + n_samples, + names, ) = gen_test_data local_batch_size = 2 global_batch_size = 2 @@ -184,13 +212,6 @@ def _create_pickleddatawds(tmp_path_factory, gen_test_data): prefix_dir_tars_wds = tmp_path_factory.mktemp("pickleddatawds_tars_wds").as_posix() - names = { - split: [f"{prefix_sample}-{i:04d}" for i in range(n_samples_in_tar)] - for split in Split - } - - n_samples = {split: n_samples_in_tar for split in Split} - batch = batched( local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples) ) diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py index c20c9345e4..0412e1fbf2 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py @@ -25,15 +25,15 @@ @pytest.mark.parametrize("split", list(Split)) def test_webdatamodule_init(split, create_webdatamodule): - data_module, prefix_dir_tars_wds = create_webdatamodule + data_module, dirs_tars_wds = create_webdatamodule assert data_module._n_samples[split] == 10, ( f"Wrong {split}-set size: " f"expected 10 " f"but got {data_module._n_samples[split]}" ) - assert data_module._dirs_tars_wds[split] == f"{prefix_dir_tars_wds}", ( + assert data_module._dirs_tars_wds[split] == f"{dirs_tars_wds[split]}", ( f"Wrong tar files directory: " - f"expected {prefix_dir_tars_wds} " + f"expected {dirs_tars_wds[split]} " f"but got {data_module._dirs_tars_wds[split]}" ) @@ -194,3 +194,28 @@ def test_pickleddatawds_setup_dataset( torch.testing.assert_close( torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1]) ) + + +def test_pickleddatawds_sample_overlap(create_pickleddatawds): + data_module = create_pickleddatawds[0] + # this writes the tar files to disk + data_module.prepare_data() + # read the data back by setting up the dataset object and loop over it + data_module.setup("fit") + data_module.setup("test") + results = { + split: set([sample.item() for sample in data_module._dataset[split]]) + for split in Split + } + overlap_train_val = results[Split.train] & results[Split.val] + overlap_train_test = results[Split.train] & results[Split.test] + overlap_val_test = results[Split.val] & results[Split.test] + assert ( + len(overlap_train_val) == 0 + ), "Shared samples found between train and val datasets" + assert ( + len(overlap_train_test) == 0 + ), "Shared samples found between train and test datasets" + assert ( + len(overlap_val_test) == 0 + ), "Shared samples found between val and test datasets" From d5284e6c285504569848f957c03dcd9101786e8c Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Sat, 24 Aug 2024 00:07:16 +0000 Subject: [PATCH 64/70] Test: assert throw from webdataset when num_workers > num_shards --- .../tests/bionemo/webdatamodule/conftest.py | 9 +++- .../webdatamodule/test_webdatamodule.py | 45 +++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py index 8eb742a5c2..a7c5e03083 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py @@ -88,7 +88,7 @@ def gen_test_data(tmp_path_factory): ) -def _create_webdatamodule(gen_test_data): +def _create_webdatamodule(gen_test_data, num_workers=2): (_, dirs_tars_wds, _, suffix_keys_wds, prefix_tars_wds, n_samples, _) = ( gen_test_data ) @@ -129,7 +129,7 @@ def _create_webdatamodule(gen_test_data): for split in Split } - kwargs_wld = {split: {"num_workers": 2} for split in Split} + kwargs_wld = {split: {"num_workers": num_workers} for split in Split} data_module = WebDataModule( dirs_tars_wds, @@ -156,6 +156,11 @@ def create_another_webdatamodule(gen_test_data): return _create_webdatamodule(gen_test_data) +@pytest.fixture(scope="module") +def create_webdatamodule_with_5_workers(gen_test_data): + return _create_webdatamodule(gen_test_data, num_workers=5) + + class ModelTestWebDataModule(L.LightningModule): def __init__(self) -> None: super().__init__() diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py index 0412e1fbf2..07ab45e38a 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py @@ -103,6 +103,51 @@ def test_webdatamodule_setup_dataloader( ) +@pytest.mark.parametrize("split", list(Split)) +def test_webdatamodule_throw_on_many_workers( + split, create_webdatamodule_with_5_workers +): + data_module = create_webdatamodule_with_5_workers[0] + urls = glob.glob( + f"{data_module._dirs_tars_wds[split]}/" f"{data_module._prefix_tars_wds}-*.tar" + ) + n_tars = len(urls) + data_module._kwargs_wld[split]["num_workers"] = n_tars + 1 + data_module.prepare_data() + data_module.setup("fit") + data_module.setup("test") + loader = None + if split == Split.train: + loader = data_module.train_dataloader() + elif split == Split.val: + loader = data_module.val_dataloader() + elif split == Split.test: + loader = data_module.test_dataloader() + else: + raise RuntimeError(f"Test for split {split} not implemented") + assert loader is not None, "dataloader not instantated" + try: + for _ in loader: + pass + except ValueError as e: + # this is expected + assert "have fewer shards than workers" in str(e), ( + f"'have fewer shards than workers' not found in exception " + f"raised from data loading: {e}" + ) + except Exception as e: + raise RuntimeError( + f"WebLoader doesn't raise ValueError with fewer " + f"shards than workers but raise this instead: {e}" + ) + else: + raise NotImplementedError( + "WebLoader doesn't throw error with num_workers > num_shards " + "User should report this issue to webdataset and create " + "less shards than workers in practice as a workaround" + ) + + class Stage(Enum): fit = auto() validate = auto() From 6e51a7e7faef11660d811c251569176e132435ef Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Wed, 28 Aug 2024 18:15:04 +0000 Subject: [PATCH 65/70] Enhancement: reuse the webdatamodule dirs_tars_wds arg in PickledDataWDS --- .../src/bionemo/webdatamodule/datamodule.py | 28 ++++++++----------- .../tests/bionemo/webdatamodule/conftest.py | 7 +++-- .../webdatamodule/test_webdatamodule.py | 7 ++--- 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py index 73e5cc6e40..6ba7854c75 100644 --- a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py +++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py @@ -126,8 +126,8 @@ class WebDataModule(L.LightningDataModule): >>> } >>> >>> # construct the data module - >>> data_module = WebDataModule(dirs_of_tar_files, n_samples, suffix_keys_wds, - global_batch_size, + >>> data_module = WebDataModule(n_samples, suffix_keys_wds, + dirs_of_tar_files, global_batch_size, prefix_tars_wds=tar_file_prefix, pipeline_wds=pipeline_wds, pipeline_prebatch_wld=pipeline_prebatch_wld, @@ -139,9 +139,9 @@ class WebDataModule(L.LightningDataModule): def __init__( self, - dirs_tars_wds: Dict[Split, str], n_samples: Dict[Split, int], suffix_keys_wds: Union[str, Iterable[str]], + dirs_tars_wds: Dict[Split, str], global_batch_size: int, prefix_tars_wds: str = "wdshards", pipeline_wds: Optional[ @@ -156,14 +156,14 @@ def __init__( """constructor Args: - dirs_tars_wds (Dict[Split, str]): input dictionary: Split -> tar file - directory that contains the webdataset tar files for each split n_samples (Dict[Split, int]): input dictionary: Split -> number of data samples for each split suffix_keys_wds (Union[str, Iterable[str]]): a set of keys each corresponding to a data object in the webdataset tar file dictionary. The data objects of these keys will be extracted and tupled for each sample in the tar files + dirs_tars_wds (Dict[Split, str]): input dictionary: Split -> tar file + directory that contains the webdataset tar files for each split global_batch_size (int): size of batch summing across nodes in Data Distributed Parallel, i.e., local_batch_size * n_nodes. NOTE: this data module doesn't rely on the input `global_batch_size` @@ -391,7 +391,11 @@ class PickledDataWDS(WebDataModule): >>> n_tars_wds = 5 >>> prefix_tars_wds = "myshards" - >>> output_dir_tar_files = "/path/to/output/tars/dir" + >>> output_dir_tar_files = { + Split.train : "/path/to/output/tars/dir-train", + Split.val : "/path/to/output/tars/dir-val", + Split.test : "/path/to/output/tars/dir-test", + } >>> # see the `WebDataModule` API doc for the definition of global_batch_size >>> global_batch_size = 16 @@ -412,7 +416,7 @@ class PickledDataWDS(WebDataModule): >>> dir_pickles, >>> suffix_pickles, >>> names_subset, - >>> output_dir_tar_files, + >>> output_dir_tar_files, # `WebDataModule` args >>> global_batch_size, # `WebDataModule` args >>> n_tars_wds=n_tars_wds, >>> prefix_tars_wds=prefix_tars_wds, # `WebDataModule` kwargs @@ -431,7 +435,6 @@ def __init__( dir_pickles: str, suffix_pickles: str, names_subset: Dict[Split, List[str]], - prefix_dir_tars_wds: str, *args, n_tars_wds: Optional[int] = None, **kwargs, @@ -446,10 +449,6 @@ def __init__( names_subset (Dict[Split, List[str]]): list of filename prefix of the data samples to be loaded in the dataset and dataloader for each of the split - prefix_dir_tars_wds (str): directory name prefix to store the output - webdataset tar files. The actual directories storing the train, val - and test sets will be suffixed with "train", "val" and "test" - respectively. *args: arguments passed to the parent WebDataModule Kwargs: @@ -460,10 +459,6 @@ def __init__( """ super().__init__( - { - split: f"{prefix_dir_tars_wds}{str(split).split('.')[-1]}" - for split in names_subset.keys() - }, {split: len(names_subset[split]) for split in names_subset.keys()}, suffix_pickles, *args, @@ -472,7 +467,6 @@ def __init__( self._dir_pickles = dir_pickles self._suffix_pickles = suffix_pickles - self._prefix_dir_tars_wds = prefix_dir_tars_wds self._names_subset = names_subset diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py index a7c5e03083..b1c9411fd8 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py @@ -132,9 +132,9 @@ def _create_webdatamodule(gen_test_data, num_workers=2): kwargs_wld = {split: {"num_workers": num_workers} for split in Split} data_module = WebDataModule( - dirs_tars_wds, n_samples, suffix_keys_wds, + dirs_tars_wds, global_batch_size, prefix_tars_wds=prefix_tars_wds, pipeline_wds=pipeline_wds, @@ -216,6 +216,7 @@ def _create_pickleddatawds(tmp_path_factory, gen_test_data): n_tars_wds = 3 prefix_dir_tars_wds = tmp_path_factory.mktemp("pickleddatawds_tars_wds").as_posix() + dirs_tars_wds = {s: f"{prefix_dir_tars_wds}{str(s).split('.')[-1]}" for s in Split} batch = batched( local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples) @@ -256,7 +257,7 @@ def _create_pickleddatawds(tmp_path_factory, gen_test_data): dir_pickles, suffix_keys_wds, names, - prefix_dir_tars_wds, + dirs_tars_wds, global_batch_size, n_tars_wds=n_tars_wds, prefix_tars_wds=prefix_tars_wds, @@ -266,7 +267,7 @@ def _create_pickleddatawds(tmp_path_factory, gen_test_data): kwargs_wld=kwargs_wld, ) - return data_module, prefix_dir_tars_wds, n_tars_wds + return data_module, dirs_tars_wds, n_tars_wds @pytest.fixture(scope="module") diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py index 07ab45e38a..fd5831585b 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py @@ -187,16 +187,15 @@ def test_webdatamodule_in_lightning( @pytest.mark.parametrize("split", list(Split)) def test_pickleddatawds_init(split, create_pickleddatawds): - data_module, prefix_dir_tars_wds, _ = create_pickleddatawds + data_module, dirs_tars_wds, _ = create_pickleddatawds assert data_module._n_samples[split] == 10, ( f"Wrong {split}-set size: " f"expected 10 " f"but got {data_module._n_samples[split]}" ) - name_split = str(split).split(".")[-1] - assert data_module._dirs_tars_wds[split] == f"{prefix_dir_tars_wds}{name_split}", ( + assert data_module._dirs_tars_wds[split] == dirs_tars_wds[split], ( f"Wrong tar files directory: " - f"expected {prefix_dir_tars_wds}{name_split} " + f"expected {dirs_tars_wds[split]} " f"but got {data_module._dirs_tars_wds[split]}" ) From 5b6f3091d7be363f809f92061a568eda93bf9f43 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Wed, 28 Aug 2024 23:37:43 +0000 Subject: [PATCH 66/70] Enhancement: allow multiple data objects for each sample --- .../src/bionemo/webdatamodule/datamodule.py | 16 ++--- .../src/bionemo/webdatamodule/utils.py | 70 ++++++++++++------- .../tests/bionemo/webdatamodule/conftest.py | 8 +-- 3 files changed, 54 insertions(+), 40 deletions(-) diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py index 6ba7854c75..33fa7936b9 100644 --- a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py +++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py @@ -216,7 +216,7 @@ def __init__( self._global_batch_size = global_batch_size - if not isinstance(suffix_keys_wds, get_args(Union[str, Iterable[str]])): + if not isinstance(suffix_keys_wds, get_args(Union[str, Iterable])): raise TypeError("suffix_keys_wds can only be str or Iterable[str]") self._suffix_keys_wds = suffix_keys_wds @@ -414,8 +414,8 @@ class PickledDataWDS(WebDataModule): >>> # create the data module >>> data_module = PickledDataWDS( >>> dir_pickles, - >>> suffix_pickles, >>> names_subset, + >>> suffix_pickles, # `WebDataModule` args >>> output_dir_tar_files, # `WebDataModule` args >>> global_batch_size, # `WebDataModule` args >>> n_tars_wds=n_tars_wds, @@ -433,7 +433,6 @@ class PickledDataWDS(WebDataModule): def __init__( self, dir_pickles: str, - suffix_pickles: str, names_subset: Dict[Split, List[str]], *args, n_tars_wds: Optional[int] = None, @@ -443,13 +442,12 @@ def __init__( Args: dir_pickles (str): input directory of pickled data files - suffix_pickles (str): filename suffix of the input data in - dir_pickles. This is also used as the key mapped to the - tarballed pickled object in the webdataset names_subset (Dict[Split, List[str]]): list of filename prefix of the data samples to be loaded in the dataset and dataloader for each of the split - *args: arguments passed to the parent WebDataModule + *args: arguments passed to the parent WebDataModule after its + `n_samples` args (where `n_samples` is deduced from the length of + `names_subset` arg of this class) Kwargs: n_tars_wds (int): attempt to create at least this number of @@ -460,13 +458,11 @@ def __init__( """ super().__init__( {split: len(names_subset[split]) for split in names_subset.keys()}, - suffix_pickles, *args, **kwargs, ) self._dir_pickles = dir_pickles - self._suffix_pickles = suffix_pickles self._names_subset = names_subset @@ -487,8 +483,8 @@ def prepare_data(self) -> None: # create wds shards (tar files) for train set pickles_to_tars( self._dir_pickles, - self._suffix_pickles, self._names_subset[split], + self._suffix_keys_wds, self._dirs_tars_wds[split], self._prefix_tars_wds, min_num_shards=self._n_tars_wds, diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py index f4c5e50175..541957edd7 100644 --- a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py +++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py @@ -16,7 +16,8 @@ import os import pickle -from typing import Any, Callable, Dict, List, Optional +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, List, Optional, Union, get_args import webdataset as wds from nemo.utils import logging @@ -24,21 +25,22 @@ def pickles_to_tars( dir_input: str, - input_suffix: str, input_prefix_subset: List[str], + input_suffix: Union[str, Iterable[str]], dir_output: str, output_prefix: str, - func_output_data: Callable[[str, str, Any], Dict[str, Any]] = lambda prefix, - suffix, - data: { - "__key__": prefix, - suffix: pickle.dumps(data), - }, + func_output_data: Callable[[str, Dict[str, Any]], Dict[str, Any]] = lambda prefix, + suffix_to_data: {"__key__": prefix, **suffix_to_data}, min_num_shards: Optional[int] = None, ) -> None: """Convert a subset of pickle files from a directory to Webdataset tar files - Input path and name pattern: - f"{dir_input}/{input_prefix_subset}.{input_suffix}" + Input path and name pattern for sample 0: + f"{dir_input}/{input_prefix_subset[0]}.{input_suffix[0]}" + f"{dir_input}/{input_prefix_subset[0]}.{input_suffix[1]}" + Input path and name pattern for sample 1: + f"{dir_input}/{input_prefix_subset[1]}.{input_suffix[0]}" + f"{dir_input}/{input_prefix_subset[1]}.{input_suffix[1]}" + ... Output path and name pattern: f"{dir_output}/{output_prefix}-%06d.tar" @@ -52,21 +54,21 @@ def pickles_to_tars( so that parsing the tar archive is equivalent of reading {sample_filename_preifx}.{sample_filename_suffix_1} etc. - Here, the assumption is that there is only one sample data file, whose name - prefix is given in each of the elements of `input_prefix_subset` and whose - name suffix is given by `input_suffix`. Per the webdataset file format - specification, the `sample_filename_preifx` can't contain dots '.' so this - function removes it for the user by calling .replace(".", "-") on the - elements of `input_prefix_subset` + Here, each sample data get its name prefix from one element of + `input_prefix_subset` and its name suffixes from the list `input_suffix`. + Per the webdataset file format specification, the `sample_filename_preifx` + can't contain dots '.' so this function removes it for the user by calling + .replace(".", "-") on the elements of `input_prefix_subset` Args: dir_input (str): Input directory - input_suffix (str): Input pickle file name suffix input_prefix_subset (List[str]): Input subset of pickle files' prefix + input_suffix (Union[str, Iterable[str]]): Input pickle file name + suffixes, each for one type of data object, for all the samples dir_output (str): Output directory output_prefix (str): Output tar file name prefix - func_output_data (Callable[[str, str, Any], Dict[str, Any]]) : function - that maps the name prefix, name suffix and data object to a + func_output_data (Callable[[str, Dict[str, Any]], Dict[str, Any]]) : + function that maps the name prefix, name suffix and data object to a webdataset tar archive dictionary. Refer to the webdataset github repo for the archive file format specification. min_num_shards (int) : create at least this number of tar files. @@ -77,6 +79,8 @@ def pickles_to_tars( Returns: None """ + if not isinstance(input_suffix, get_args(Union[str, Iterable])): + raise TypeError("input_suffix can only be str or Iterable[str]") os.makedirs(dir_output, exist_ok=True) wd_subset_pattern = os.path.join(dir_output, f"{output_prefix}-%06d.tar") n_samples_per_shard_max = 100000 @@ -93,13 +97,29 @@ def pickles_to_tars( ) as sink: for name in input_prefix_subset: try: - with open( - os.path.join(dir_input, f"{name}.{input_suffix}"), "rb" - ) as fh: - data = pickle.load(fh) + if isinstance(input_suffix, str): + suffix_to_data = { + input_suffix: pickle.dumps( + pickle.loads( + ( + Path(dir_input) / f"{name}.{input_suffix}" + ).read_bytes() + ) + ) + } + else: + suffix_to_data = { + suffix: pickle.dumps( + pickle.loads( + (Path(dir_input) / f"{name}.{suffix}").read_bytes() + ) + ) + for suffix in input_suffix + } # the prefix name shouldn't contain any "." per webdataset's # specification - sample = func_output_data(name.replace(".", "-"), input_suffix, data) + sample = func_output_data(name.replace(".", "-"), suffix_to_data) + sink.write(sample) except ModuleNotFoundError as e: logging.error( f"Dependency for parsing input pickle data not found: {e}" @@ -108,5 +128,3 @@ def pickles_to_tars( except Exception as e: logging.error(f"Failed to write {name} into tar files due to error {e}") raise e - - sink.write(sample) diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py index b1c9411fd8..4e4e961e0e 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py @@ -55,24 +55,24 @@ def gen_test_data(tmp_path_factory): # generate the tars pickles_to_tars( dir_pickles, - suffix_sample, prefixes_pickle[Split.train], + suffix_sample, dir_tars[Split.train], prefix_tar, min_num_shards=3, ) pickles_to_tars( dir_pickles, - suffix_sample, prefixes_pickle[Split.val], + suffix_sample, dir_tars[Split.val], prefix_tar, min_num_shards=3, ) pickles_to_tars( dir_pickles, - suffix_sample, prefixes_pickle[Split.test], + suffix_sample, dir_tars[Split.test], prefix_tar, min_num_shards=3, @@ -255,8 +255,8 @@ def _create_pickleddatawds(tmp_path_factory, gen_test_data): data_module = PickledDataWDS( dir_pickles, - suffix_keys_wds, names, + suffix_keys_wds, dirs_tars_wds, global_batch_size, n_tars_wds=n_tars_wds, From 8522273edfc384eab39493dc248fa8c488c4254c Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Wed, 28 Aug 2024 23:39:30 +0000 Subject: [PATCH 67/70] Regress: remove default python version --- .pre-commit-config.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 06f0e784b4..80dce3eda6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,3 @@ -default_language_version: - python: python3.10 repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v2.3.0 From 26a59d64b8d6372cf7f73da9048c1d4edaeb66f3 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Wed, 28 Aug 2024 23:40:03 +0000 Subject: [PATCH 68/70] Test: rename test to be consistent with module --- .../webdatamodule/{test_webdatamodule.py => test_datamodule.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/{test_webdatamodule.py => test_datamodule.py} (100%) diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py similarity index 100% rename from sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_webdatamodule.py rename to sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py From 0d5c144079b750fecde2d7b14e302b3efe382182 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Thu, 29 Aug 2024 20:22:08 +0000 Subject: [PATCH 69/70] Test: factor pickle generation into a new fixture --- .../tests/bionemo/webdatamodule/conftest.py | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py index 4e4e961e0e..5551d6efe4 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py @@ -14,7 +14,6 @@ # limitations under the License. -import os import pickle import random @@ -29,16 +28,11 @@ @pytest.fixture(scope="module") -def gen_test_data(tmp_path_factory): +def gen_pickle_files(tmp_path_factory): dir_pickles = tmp_path_factory.mktemp("pickleddatawds").as_posix() - dir_tars_tmp = tmp_path_factory.mktemp("webdatamodule").as_posix() - dir_tars = {split: f"{dir_tars_tmp}{str(split).split('.')[-1]}" for split in Split} prefix_sample = "sample" suffix_sample = "tensor.pyd" - prefix_tar = "tensor" n_samples_per_split = 10 - n_samples = {split: n_samples_per_split for split in Split} - os.makedirs(dir_pickles, exist_ok=True) prefixes = [] # generate the pickles for train, val, and test for i in range(n_samples_per_split * 3): @@ -52,6 +46,24 @@ def gen_test_data(tmp_path_factory): Split.val: prefixes[n_samples_per_split : n_samples_per_split * 2], Split.test: prefixes[n_samples_per_split * 2 : n_samples_per_split * 3], } + return ( + dir_pickles, + prefix_sample, + suffix_sample, + prefixes_pickle, + n_samples_per_split, + ) + + +@pytest.fixture(scope="module") +def gen_test_data(tmp_path_factory, gen_pickle_files): + dir_pickles, prefix_sample, suffix_sample, prefixes_pickle, n_samples_per_split = ( + gen_pickle_files + ) + dir_tars_tmp = tmp_path_factory.mktemp("webdatamodule").as_posix() + dir_tars = {split: f"{dir_tars_tmp}{str(split).split('.')[-1]}" for split in Split} + prefix_tar = "tensor" + n_samples = {split: n_samples_per_split for split in Split} # generate the tars pickles_to_tars( dir_pickles, From a35765ff8eb9a520797be46585c5e62430ef3c77 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Thu, 29 Aug 2024 23:53:02 +0000 Subject: [PATCH 70/70] Test: multiple data objects per sample --- .../tests/bionemo/webdatamodule/conftest.py | 33 ++++++++++++------- .../bionemo/webdatamodule/test_datamodule.py | 7 ++-- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py index 5551d6efe4..a43f4c0be3 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py @@ -31,7 +31,7 @@ def gen_pickle_files(tmp_path_factory): dir_pickles = tmp_path_factory.mktemp("pickleddatawds").as_posix() prefix_sample = "sample" - suffix_sample = "tensor.pyd" + suffix_sample = ["tensor.pyd", "tensor_copy.pyd"] n_samples_per_split = 10 prefixes = [] # generate the pickles for train, val, and test @@ -39,8 +39,9 @@ def gen_pickle_files(tmp_path_factory): prefix = f"{prefix_sample}-{i:04}" prefixes.append(prefix) t = torch.tensor(i, dtype=torch.int32) - with open(f"{dir_pickles}/{prefix}.{suffix_sample}", "wb") as fh: - pickle.dump(t, fh) + for suffix in suffix_sample: + with open(f"{dir_pickles}/{prefix}.{suffix}", "wb") as fh: + pickle.dump(t, fh) prefixes_pickle = { Split.train: prefixes[0:n_samples_per_split], Split.val: prefixes[n_samples_per_split : n_samples_per_split * 2], @@ -55,11 +56,16 @@ def gen_pickle_files(tmp_path_factory): ) -@pytest.fixture(scope="module") -def gen_test_data(tmp_path_factory, gen_pickle_files): - dir_pickles, prefix_sample, suffix_sample, prefixes_pickle, n_samples_per_split = ( +@pytest.fixture(scope="module", params=[1, 2]) +def gen_test_data(tmp_path_factory, gen_pickle_files, request): + dir_pickles, prefix_sample, suffixes, prefixes_pickle, n_samples_per_split = ( gen_pickle_files ) + n_suffixes = request.param + if n_suffixes <= 1: + suffix_sample = suffixes[0] + else: + suffix_sample = suffixes[0:n_suffixes] dir_tars_tmp = tmp_path_factory.mktemp("webdatamodule").as_posix() dir_tars = {split: f"{dir_tars_tmp}{str(split).split('.')[-1]}" for split in Split} prefix_tar = "tensor" @@ -112,7 +118,10 @@ def _create_webdatamodule(gen_test_data, num_workers=2): local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples) ) - untuple = lambda source: (sample for (sample,) in source) + if isinstance(suffix_keys_wds, str): + untuple = lambda source: (sample[0] for sample in source) + elif isinstance(suffix_keys_wds, list): + untuple = lambda source: (torch.vstack(sample) for sample in source) pipeline_wds = { Split.train: [ @@ -183,19 +192,19 @@ def forward(self, x): return self._model(x.float()) def training_step(self, batch): - self._samples[Split.train].append(batch.name) + self._samples[Split.train].append(batch) loss = self(batch).sum() return loss def validation_step(self, batch, batch_index): - self._samples[Split.val].append(batch.name) + self._samples[Split.val].append(batch) return torch.zeros(1) def test_step(self, batch, batch_index): - self._samples[Split.test].append(batch.name) + self._samples[Split.test].append(batch) def predict_step(self, batch, batch_index): - self._samples[Split.test].append(batch.name) + self._samples[Split.test].append(batch) return torch.zeros(1) def configure_optimizers(self): @@ -234,7 +243,7 @@ def _create_pickleddatawds(tmp_path_factory, gen_test_data): local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples) ) - untuple = lambda source: (sample for (sample,) in source) + untuple = lambda source: (sample[0] for sample in source) pipeline_wds = { Split.train: [ diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py index fd5831585b..692905a416 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py @@ -178,11 +178,14 @@ def test_webdatamodule_in_lightning( # get the list of samples from the workflow get_dataloader = getattr(data_modules[0], f"{str(split).split('.')[-1]}_dataloader") loader = get_dataloader() - samples = [sample.name for sample in loader] L.seed_everything(2823828) workflow = getattr(trainer, name_stage) workflow(model, data_modules[1]) - assert model._samples[split] == samples + device = model._samples[split][0].device + samples = [sample.to(device=device) for sample in loader] + torch.testing.assert_close( + torch.stack(model._samples[split], dim=0), torch.stack(samples, dim=0) + ) @pytest.mark.parametrize("split", list(Split))