Skip to content

Commit

Permalink
HDF5TrajectoryDataset (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 authored Jul 9, 2024
1 parent 2cf4ee7 commit 438eb1d
Show file tree
Hide file tree
Showing 10 changed files with 1,948 additions and 1 deletion.
3 changes: 3 additions & 0 deletions docs/beignet.datasets.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# beignet.datasets

::: beignet.datasets.FASTADataset
::: beignet.datasets.HDF5TrajectoryDataset
::: beignet.datasets.PDBTrajectoryDataset
::: beignet.datasets.RandomEulerAngleDataset
::: beignet.datasets.RandomQuaternionDataset
::: beignet.datasets.RandomRotationMatrixDataset
Expand All @@ -9,6 +11,7 @@
::: beignet.datasets.SizedSequenceDataset
::: beignet.datasets.SwissProtDataset
::: beignet.datasets.TrEMBLDataset
::: beignet.datasets.TrajectoryDataset
::: beignet.datasets.UniProtDataset
::: beignet.datasets.UniRef50Dataset
::: beignet.datasets.UniRef90Dataset
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@ requires-python = ">=3.10"

[project.optional-dependencies]
all = [
"beignet[docs,test]",
"beignet[docs,mdtraj,test]",
]
docs = [
"mkdocs-material",
"mkdocstrings[python]",
]
mdtraj = [
"mdtraj",
]
test = [
"hypothesis",
"numpy==1.26.4",
Expand Down
6 changes: 6 additions & 0 deletions src/beignet/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from ._fasta_dataset import FASTADataset
from ._hdf5_trajectory_dataset import HDF5TrajectoryDataset
from ._pdb_trajectory_dataset import PDBTrajectoryDataset
from ._random_euler_angle_dataset import RandomEulerAngleDataset
from ._random_quaternion_dataset import RandomQuaternionDataset
from ._random_rotation_matrix_dataset import RandomRotationMatrixDataset
from ._random_rotation_vector_dataset import RandomRotationVectorDataset
from ._sequence_dataset import SequenceDataset
from ._sized_sequence_dataset import SizedSequenceDataset
from ._swissprot_dataset import SwissProtDataset
from ._trajectory_dataset import TrajectoryDataset
from ._trembl_dataset import TrEMBLDataset
from ._uniprot_dataset import UniProtDataset
from ._uniref50_dataset import UniRef50Dataset
Expand All @@ -14,6 +17,8 @@

__all__ = [
"FASTADataset",
"HDF5TrajectoryDataset",
"PDBTrajectoryDataset",
"RandomEulerAngleDataset",
"RandomQuaternionDataset",
"RandomRotationMatrixDataset",
Expand All @@ -22,6 +27,7 @@
"SizedSequenceDataset",
"SwissProtDataset",
"TrEMBLDataset",
"TrajectoryDataset",
"UniProtDataset",
"UniRef100Dataset",
"UniRef50Dataset",
Expand Down
25 changes: 25 additions & 0 deletions src/beignet/datasets/_hdf5_trajectory_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from os import PathLike
from typing import Any, Callable

import mdtraj
from mdtraj import Trajectory

from ._trajectory_dataset import TrajectoryDataset


class HDF5TrajectoryDataset(TrajectoryDataset):
    """Trajectory dataset over the ``*.hdf5`` files of a directory.

    A thin specialisation of ``TrajectoryDataset`` that fixes the loader to
    ``mdtraj.load_hdf5`` and the file extension to ``"hdf5"``.
    """

    def __init__(
        self,
        root: str | PathLike,
        transform: Callable[[Trajectory], Any] | None = None,
        stride: int | None = None,
        **kwargs,
    ):
        """Configure the dataset.

        Parameters
        ----------
        root : str | PathLike
            Directory that holds the ``.hdf5`` files.
        transform : Callable[[Trajectory], Any] | None
            Optional callable handed to the base class together with the
            other arguments.
        stride : int | None
            Handed unchanged to the base class.
        **kwargs
            Extra keyword arguments handed unchanged to the base class.
        """
        # Everything is passed by keyword, so the order here is irrelevant.
        super().__init__(
            root=root,
            extension="hdf5",
            func=mdtraj.load_hdf5,
            stride=stride,
            transform=transform,
            **kwargs,
        )
25 changes: 25 additions & 0 deletions src/beignet/datasets/_pdb_trajectory_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from os import PathLike
from typing import Any, Callable

import mdtraj
from mdtraj import Trajectory

from ._trajectory_dataset import TrajectoryDataset


class PDBTrajectoryDataset(TrajectoryDataset):
    """Trajectory dataset over the ``*.pdb`` files of a directory.

    A thin specialisation of ``TrajectoryDataset`` that fixes the loader to
    ``mdtraj.load_pdb`` and the file extension to ``"pdb"``.
    """

    def __init__(
        self,
        root: str | PathLike,
        transform: Callable[[Trajectory], Any] | None = None,
        stride: int | None = None,
        **kwargs,
    ):
        """Configure the dataset.

        Parameters
        ----------
        root : str | PathLike
            Directory that holds the ``.pdb`` files.
        transform : Callable[[Trajectory], Any] | None
            Optional callable handed to the base class together with the
            other arguments.
        stride : int | None
            Handed unchanged to the base class.
        **kwargs
            Extra keyword arguments handed unchanged to the base class.
        """
        # Everything is passed by keyword, so the order here is irrelevant.
        super().__init__(
            root=root,
            extension="pdb",
            func=mdtraj.load_pdb,
            stride=stride,
            transform=transform,
            **kwargs,
        )
44 changes: 44 additions & 0 deletions src/beignet/datasets/_trajectory_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import functools
from os import PathLike
from pathlib import Path
from typing import Any, Callable

from mdtraj import Trajectory
from torch.utils.data import Dataset


class TrajectoryDataset(Dataset):
def __init__(
self,
func: Callable,
extension: str,
root: str | PathLike,
transform: Callable[[Trajectory], Any] | None = None,
stride: int | None = None,
**kwargs,
):
self.func = functools.partial(func, **kwargs)

if isinstance(root, str):
root = Path(root)

self.root = root.resolve()

self.transform = transform

self.stride = stride

self.paths = [*self.root.glob(f"*.{extension}")]

super().__init__()

def __getitem__(self, index: int) -> Trajectory:
item = self.func(self.paths[index], stride=self.stride)

if self.transform:
item = self.transform(item)

return item

def __len__(self) -> int:
return len(self.paths)
10 changes: 10 additions & 0 deletions tests/beignet/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import os.path

import pytest


@pytest.fixture
def data_path(request):
    """Return the ``../data`` path relative to the requesting test module.

    The result is the directory of ``request.module.__file__`` joined with
    ``"../data"``; the path is returned as a string and is not normalised.
    """
    module_file = request.module.__file__

    return os.path.join(os.path.dirname(module_file), "../data")
Binary file added tests/beignet/data/trajectory.hdf5
Binary file not shown.
Loading

0 comments on commit 438eb1d

Please sign in to comment.