From b92910f3314a64f305ee6977e77dca342c90a2d7 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Mon, 2 Jan 2023 15:11:25 -0500 Subject: [PATCH 1/2] move functions outside of local scope --- muspy/datasets/base.py | 194 ++++++++++++++++++++++------------------- 1 file changed, 104 insertions(+), 90 deletions(-) diff --git a/muspy/datasets/base.py b/muspy/datasets/base.py index 6ab4249..dfdd26a 100644 --- a/muspy/datasets/base.py +++ b/muspy/datasets/base.py @@ -1,6 +1,7 @@ """Base Dataset classes.""" import json import warnings +from functools import partial from pathlib import Path from typing import ( TYPE_CHECKING, @@ -34,6 +35,14 @@ from tensorflow.data import Dataset as TFDataset from torch.utils.data import Dataset as TorchDataset +try: + # pylint: disable=import-outside-toplevel + from torch.utils.data import Dataset as TorchDataset + + TORCH_AVAILABLE = True +except ImportError: + TORCH_AVAILABLE = False + RemoteDatasetT = TypeVar("RemoteDatasetT", bound="RemoteDataset") FolderDatasetT = TypeVar("FolderDatasetT", bound="FolderDataset") @@ -284,96 +293,8 @@ def to_pytorch_dataset( "Only one of `representation` and `factory` can be given." ) - try: - # pylint: disable=import-outside-toplevel - from torch.utils.data import Dataset as TorchDataset - except ImportError as err: - raise ImportError("Optional package pytorch is required.") from err - - class TorchMusicFactoryDataset(TorchDataset): - """A PyTorch dataset built from a Music dataset. - - Parameters - ---------- - dataset : :class:`muspy.Dataset` - Dataset object to base on. - factory : Callable - Function to be applied to the Music objects. The input - is a Music object, and the output is an array or a - tensor. - - """ - - def __init__( - self, - dataset: Dataset, - factory: Callable, - subset: str = "Full", - indices: Sequence[int] = None, - ): - super().__init__() - self.dataset = dataset - self.factory = factory - self.subset = subset - self.indices = indices - if self.indices is not None: - self.indices = sorted( - idx for idx in self.indices if idx < len(self.dataset) - ) - - def __repr__(self) -> str: - return ( - f"TorchMusicFactoryDataset(dataset={self.dataset}, " - f"factory={self.subset}, subset={self.factory})" - ) - - def __getitem__(self, index): - if self.indices is None: - return self.factory(self.dataset[index]) - return self.factory(self.dataset[self.indices[index]]) - - def __len__(self) -> int: - if self.indices is None: - return len(self.dataset) - return len(self.indices) - - class TorchRepresentationDataset(TorchMusicFactoryDataset): - """A PyTorch music dataset. - - Parameters - ---------- - dataset : :class:`muspy.Dataset` - Dataset object to base on. - representation : str - Target representation. See - :func:`muspy.to_representation()` for available - representation. - - """ - - def __init__( - self, - dataset: Dataset, - representation: str, - subset: str = "Full", - indices: Sequence[int] = None, - **kwargs: Any, - ): - self.representation = representation - - def factory(music): - return music.to_representation(representation, **kwargs) - - super().__init__( - dataset, factory=factory, subset=subset, indices=indices - ) - - def __repr__(self) -> str: - return ( - f"TorchRepresentationDataset(dataset={self.dataset}, " - f"representation={self.representation}, " - f"subset={self.subset})" - ) + if not TORCH_AVAILABLE: + raise ImportError("Optional package pytorch is required.") # No split if splits is None: @@ -1241,3 +1162,96 @@ def __init__( ignore_exceptions=ignore_exceptions, use_converted=use_converted, ) + + +class TorchMusicFactoryDataset(TorchDataset): + """A PyTorch dataset built from a Music dataset. + + Parameters + ---------- + dataset : :class:`muspy.Dataset` + Dataset object to base on. + factory : Callable + Function to be applied to the Music objects. The input + is a Music object, and the output is an array or a + tensor. + + """ + + def __init__( + self, + dataset: Dataset, + factory: Callable, + subset: str = "Full", + indices: Sequence[int] = None, + ): + super().__init__() + self.dataset = dataset + self.factory = factory + self.subset = subset + self.indices = indices + if self.indices is not None: + self.indices = sorted( + idx for idx in self.indices if idx < len(self.dataset) + ) + + def __repr__(self) -> str: + return ( + f"TorchMusicFactoryDataset(dataset={self.dataset}, " + f"factory={self.subset}, subset={self.factory})" + ) + + def __getitem__(self, index): + if self.indices is None: + return self.factory(self.dataset[index]) + return self.factory(self.dataset[self.indices[index]]) + + def __len__(self) -> int: + if self.indices is None: + return len(self.dataset) + return len(self.indices) + + +def _torch_representation_factory(music, representation: str, kwargs): + return music.to_representation(representation, **kwargs) + + +class TorchRepresentationDataset(TorchMusicFactoryDataset): + """A PyTorch music dataset. + + Parameters + ---------- + dataset : :class:`muspy.Dataset` + Dataset object to base on. + representation : str + Target representation. See + :func:`muspy.to_representation()` for available + representation. + + """ + + def __init__( + self, + dataset: Dataset, + representation: str, + subset: str = "Full", + indices: Sequence[int] = None, + **kwargs: Any, + ): + self.representation = representation + factory = partial( + _torch_representation_factory, + representation=representation, + kwargs=kwargs, + ) + + super().__init__( + dataset, factory=factory, subset=subset, indices=indices + ) + + def __repr__(self) -> str: + return ( + f"TorchRepresentationDataset(dataset={self.dataset}, " + f"representation={self.representation}, " + f"subset={self.subset})" + ) From 68d46929a441755feff4ed9b49108ceacbd110a1 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Mon, 2 Jan 2023 15:19:57 -0500 Subject: [PATCH 2/2] add test to verify pickling works --- tests/test_datasets.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index bdc37b7..dd2cb37 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,4 +1,5 @@ """Test cases for dataset module.""" +import pickle import shutil import pytest @@ -102,6 +103,18 @@ def test_to_pytorch_dataset(): assert pytorch_dataset[0] is not None +def test_pickle_pytorch_dataset(): + """ + PyTorch datasets must support pickling so that the dataloader can + use multiple workers when assembling a batch. + """ + dataset = Music21Dataset("demos") + pytorch_dataset = dataset.to_pytorch_dataset(representation="pitch") + obj = pickle.dumps(pytorch_dataset) + dataset = pickle.loads(obj) + assert dataset[0] is not None + + def test_to_tensorflow_dataset(): tf.config.set_visible_devices([], "GPU") dataset = Music21Dataset("demos")