diff --git a/launch.sh b/launch.sh index 69a1aff00e..34d9e5855d 100755 --- a/launch.sh +++ b/launch.sh @@ -103,7 +103,7 @@ build() { } # Check Docker version - docker_version=$(docker --version | grep -oE '[0-9]+\.[0-9]+\.[0-9]+') + docker_version=$(docker --version | awk -F'[, ]' '{print $3}') required_docker_version="23.0.1" if ! version_ge "$docker_version" "$required_docker_version"; then @@ -112,7 +112,7 @@ build() { fi # Check Buildx version - buildx_version=$(docker buildx version | grep -oE '[0-9]+\.[0-9]+\.[0-9]+') + buildx_version=$(docker buildx version | awk '{print $2}') required_buildx_version="0.10.2" if ! version_ge "$buildx_version" "$required_buildx_version"; then diff --git a/sub-packages/bionemo-webdatamodule/LICENSE b/sub-packages/bionemo-webdatamodule/LICENSE new file mode 100644 index 0000000000..d645695673 --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/sub-packages/bionemo-webdatamodule/README.md b/sub-packages/bionemo-webdatamodule/README.md new file mode 100644 index 0000000000..b06442c66a --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/README.md @@ -0,0 +1,353 @@ +# bionemo-webdatamodule + +To install, execute the following: +```bash +pip install -e . +``` + +To run unit tests, execute: +```bash +pytest -v . +``` + +## WebDataModule + +```python +class WebDataModule(L.LightningDataModule) +``` + +A LightningDataModule for using webdataset tar files to setup dataset and +dataloader. This data module takes as input a dictionary: Split -> tar file +directory and vaiours webdataset config settings. In its setup() function, it +creates the webdataset object chaining up the input `pipeline_wds` workflow. In +its train/val/test_dataloader(), it creates the WebLoader object chaining up the +`pipeline_prebatch_wld` workflow + +Examples +-------- + +1. create the data module with input directory to webdataset tar files. +Depending on which of the downstream Lightning.Trainer methods are called, +e.g., `Trainer.fit()`, `Trainer.validate()`, `Trainer.test()` or +`Trainer.predict()`, only a subset of the train, val and test splits need to +be specified in the various input options to the data module: + +- `Trainer.fit()` requires the `train` and `val` splits +- `Trainer.validate()` requires the `val` split +- `Trainer.test()` requires the `test` splits +- `Trainer.predict()` requires the `test` splits + +Here is an example of constructing the data module for `Trainer.fit()`: +``` +>>> from bionemo.webdatamodule.datamodule import Split, WebDataModule +>>> +>>> tar_file_prefix = "shards" +>>> +>>> dirs_of_tar_files = { +>>> Split.train: "/path/to/train/split/tars", +>>> Split.val: "/path/to/val/split/tars", +>>> } +>>> +>>> n_samples { +>>> Split.train: 1000, +>>> Split.val: 100, +>>> } +>>> +>>> # this is the string to retrieve the corresponding data object from the +>>> # webdataset file (see +>>> # https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format +>>> # for details) +>>> suffix_keys_wds = "tensor.pyd" +>>> +>>> # see the API doc for the definition of global_batch_size +>>> global_batch_size = 16 +>>> +>>> seed = 27193781 +>>> +>>> # Specify the routines to process the samples in the WebDataset object. 
+>>> # The routine is a generator of an Iterable of generators that are chained +>>> # together by nested function calling. The following is equivalent of +>>> # defining a overall generator of `shuffle(untuple(...))` which +>>> # untuples the samples and shuffles them. See webdataset's Documentation +>>> # for details. +>>> # NOTE: the `untuple` is almost always necessary due to the webdataset's +>>> # file parsing rule. +>>> +>>> untuple = lambda source : (sample for (sample,) in source) +>>> +>>> from webdatast import shuffle +>>> pipeline_wds = { +>>> Split.train : [untuple, shuffle(n_samples[Split.train], +>>> rng=random.Random(seed_rng_shfl))], +>>> Split.val: untuple +>>> } +>>> +>>> # Similarly the user can optionally define the processing routine on the +>>> # WebLoader (the dataloader of webdataset). +>>> # NOTE: these routines by default take unbatched sample as input so the +>>> # user can customize their batching routines here +>>> +>>> batch = batched(local_batch_size, collation_fn=lambda + list_samples : torch.vstack(list_samples)) +>>> pipeline_prebatch_wld = { + Split.train: [shuffle(n_samples[Split.train], + rng=random.Random(seed_rng_shfl)), batch], + Split.val : batch, + Split.test : batch + } +>>> +>>> # the user can optionally specify the kwargs for WebDataset and +>>> # WebLoader +>>> +>>> kwargs_wds = { +>>> split : {'shardshuffle' : split == Split.train, +>>> 'nodesplitter' : wds.split_by_node, +>>> 'seed' : seed_rng_shfl} +>>> for split in Split +>>> } +>>> +>>> kwargs_wld = { +>>> split : {"num_workers": 2} for split in Split +>>> } +>>> +>>> # construct the data module +>>> data_module = WebDataModule(dirs_of_tar_files, n_samples, suffix_keys_wds, + global_batch_size, + prefix_tars_wds=tar_file_prefix, + pipeline_wds=pipeline_wds, + pipeline_prebatch_wld=pipeline_prebatch_wld, + kwargs_wds=kwargs_wds, + kwargs_wld=kwargs_wld) +``` + + + +#### \_\_init\_\_ + +```python +def __init__( + dirs_tars_wds: Dict[Split, str], + n_samples: Dict[Split, int], + suffix_keys_wds: Union[str, Iterable[str]], + global_batch_size: int, + prefix_tars_wds: str = "wdshards", + pipeline_wds: Optional[Dict[Split, Union[Iterable[Iterable[Any]], + Iterable[Any]]]] = None, + pipeline_prebatch_wld: Optional[Dict[Split, + Union[Iterable[Iterable[Any]], + Iterable[Any]]]] = None, + kwargs_wds: Optional[Dict[Split, Dict[str, Any]]] = None, + kwargs_wld: Optional[Dict[Split, Dict[str, Any]]] = None) +``` + +constructor + +**Arguments**: + +- `dirs_tars_wds` _Dict[Split, str]_ - input dictionary: Split -> tar file + directory that contains the webdataset tar files for each split +- `n_samples` _Dict[Split, int]_ - input dictionary: Split -> number of + data samples for each split +- `suffix_keys_wds` _Union[str, Iterable[str]]_ - a set of keys each + corresponding to a data object in the webdataset tar file + dictionary. The data objects of these keys will be extracted and + tupled for each sample in the tar files +- `global_batch_size` _int_ - size of batch summing across nodes in Data + Distributed Parallel, i.e., local_batch_size * n_nodes. NOTE: + this data module doesn't rely on the input `global_batch_size` + for batching the samples. The batching is supposed to be done as + a part of the input `pipeline_prebatch_wld`. `global_batch_size` + is only used to compute a (pseudo-) epoch length for the data + loader so that the loader yield approximately n_samples // + global_batch_size batches + Kwargs: +- `prefix_tars_wds` _str_ - name prefix of the input webdataset tar + files. 
The input tar files are globbed by + "{dirs_tars_wds[split]}/{prefix_tars_wds}-*.tar" + pipeline_wds (Optional[Dict[Split, Union[Iterable[Iterable[Any]], +- `Iterable[Any]]]])` - a dictionary of webdatast composable, i.e., + functor that maps a iterator to another iterator that + transforms the data sample yield from the dataset object, for + different splits, or an iterable to such a sequence of such + iterators. For example, this can be used to transform the + sample in the worker before sending it to the main process of + the dataloader + pipeline_prebatch_wld (Optional[Dict[Split, + Union[Iterable[Iterable[Any]], Iterable[Any]]]]): a dictionary + of webloader composable, i.e., functor that maps a iterator to + another iterator that transforms the data sample yield from the + WebLoader object, for different splits, or an iterable to a + seuqnence of such iterators. For example, this can be used for + batching the samples. NOTE: this is applied before batching is + yield from the WebLoader +- `kwargs_wds` _Optional[Dict[Split, Dict[str, Any]]]_ - kwargs for the + WebDataset.__init__() +- `kwargs_wld` _Optional[Dict[Split, Dict[str, Any]]]_ - kwargs for the + WebLoader.__init__(), e.g., num_workers, of each split + + + +#### prepare\_data + +```python +def prepare_data() -> None +``` + +This is called only by the main process by the Lightning workflow. Do +not rely on this data module object's state update here as there is no +way to communicate the state update to other subprocesses. + +Returns: None + + + +#### setup + +```python +def setup(stage: str) -> None +``` + +This is called on all Lightning-managed nodes in a multi-node +training session + + +**Arguments**: + +- `stage` _str_ - "fit", "test" or "predict" +- `Returns` - None + +## PickledDataWDS + +```python +class PickledDataWDS(WebDataModule) +``` + +A LightningDataModule to process pickled data into webdataset tar files +and setup dataset and dataloader. This inherits the webdataset setup from +its parent module `WebDataModule`. This data module takes a directory of +pickled data files, data filename prefixes for train/val/test splits, data +filename suffixes and prepare webdataset tar files by globbing the specific +pickle data files `{dir_pickles}/{name_subset[split]}.{suffix_pickles}` and +outputing to webdataset tar file with the dict structure: +``` + {"__key__" : name.replace(".", "-"), + suffix_pickles : pickled.dumps(data) } +``` +NOTE: this assumes only one pickled file is processed for each sample. In +its setup() function, it creates the webdataset object chaining up the input +`pipeline_wds` workflow. In its train/val/test_dataloader(), it creates the +WebLoader object chaining up the `pipeline_prebatch_wld` workflow. + +Examples +-------- + +1. 
create the data module with a directory of pickle files and the file name +prefix thereof for different splits to used by `Lightning.Trainer.fit()` + +``` +>>> from bionemo.webdatamodule.datamodule import Split, PickledDataWDS + +>>> dir_pickles = "/path/to/my/pickles/dir" + +>>> # the following will use `sample1.mydata.pt` and `sample2.mydata.pt` as the +>>> # training dataset and `sample4.mydata.pt` and `sample5.mydata.pt` as the +>>> # validation dataset + +>>> suffix_pickles = "mydata.pt" + +>>> names_subset = { +>>> Split.train: [sample1, sample2], +>>> Split.val: [sample4, sample5], +>>> } + +>>> # the following setting will attempt to create at least 5 tar files in +>>> # `/path/to/output/tars/dir/myshards-00000{0-5}.tar` + +>>> n_tars_wds = 5 +>>> prefix_tars_wds = "myshards" +>>> output_dir_tar_files = "/path/to/output/tars/dir" + +>>> # see the `WebDataModule` API doc for the definition of global_batch_size +>>> global_batch_size = 16 + +>>> # user can optionally customize the data processing routines and kwargs used +>>> # in the WebDataset and WebLoader (see the examples in `WebDataModule`) + +>>> pipeline_wds = { Split.train: ... } + +>>> pipeline_prebatch_wld = { Split.train: ... } + +>>> kwargs_wds = { Split.train: ..., Split.val: ... } + +>>> kwargs_wld = { Split.train: ..., Split.val: ... } + +>>> # create the data module +>>> data_module = PickledDataWDS( +>>> dir_pickles, +>>> suffix_pickles, +>>> names_subset, +>>> output_dir_tar_files, +>>> global_batch_size, # `WebDataModule` args +>>> n_tars_wds=n_tars_wds, +>>> prefix_tars_wds=prefix_tars_wds, # `WebDataModule` kwargs +>>> pipeline_wds=pipeline_wds, # `WebDataModule` kwargs +>>> pipeline_prebatch_wld=pipelines_wdl_batch, # `WebDataModule` kwargs +>>> kwargs_wds=kwargs_wds, # `WebDataModule` kwargs +>>> kwargs_wld=kwargs_wld, # `WebDataModule` kwargs +>>> ) + +``` + + + +#### \_\_init\_\_ + +```python +def __init__(dir_pickles: str, + suffix_pickles: str, + names_subset: Dict[Split, List[str]], + prefix_dir_tars_wds: str, + *args, + n_tars_wds: Optional[int] = None, + **kwargs) +``` + +constructor + +**Arguments**: + +- `dir_pickles` _str_ - input directory of pickled data files +- `suffix_pickles` _str_ - filename suffix of the input data in + dir_pickles. This is also used as the key mapped to the + tarballed pickled object in the webdataset +- `names_subset` _Dict[Split, List[str]]_ - list of filename prefix of + the data samples to be loaded in the dataset and dataloader for + each of the split +- `prefix_dir_tars_wds` _str_ - directory name prefix to store the output + webdataset tar files. The actual directories storing the train, val + and test sets will be suffixed with "train", "val" and "test" + respectively. +- `*args` - arguments passed to the parent WebDataModule + + Kwargs: +- `n_tars_wds` _int_ - attempt to create at least this number of + webdataset shards +- `**kwargs` - arguments passed to the parent WebDataModule + + + +#### prepare\_data + +```python +def prepare_data() -> None +``` + +This is called only by the main process by the Lightning workflow. Do +not rely on this data module object's state update here as there is no +way to communicate the state update to other subprocesses. The nesting +`pickles_to_tars` function goes through the data name prefixes in the +different splits, read the corresponding pickled file and output a +webdataset tar archive with the dict structure: {"__key__" : +name.replace(".", "-"), suffix_pickles : pickled.dumps(data) }. 
+ +Returns: None diff --git a/sub-packages/bionemo-webdatamodule/pyproject.toml b/sub-packages/bionemo-webdatamodule/pyproject.toml new file mode 100644 index 0000000000..30a44da5c5 --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/pyproject.toml @@ -0,0 +1,32 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +# For guidance, see: https://packaging.python.org/en/latest/guides/writing-pyproject-toml/ +[project] +name = "bionemo-webdatamodule" +version = "0.0.1" +authors = [ + { name = "Dejun Lin", email = "dejunl@nvidia.com" }, +] +description = "" +readme = "README.md" +requires-python = ">=3.10" +keywords = [] +license = {file = "LICENSE"} +classifiers = [ + "Programming Language :: Python :: 3.10", + "Private :: Do Not Upload", +] +dynamic = ["dependencies"] + +[project.optional-dependencies] +test = [ + "pytest", +] + +[tool.setuptools.dynamic] +dependencies = {file = ["requirements.txt"]} + +[tool.ruff] +lint.ignore = ["C901", "E741", "E501", "E731"] diff --git a/sub-packages/bionemo-webdatamodule/requirements.txt b/sub-packages/bionemo-webdatamodule/requirements.txt new file mode 100644 index 0000000000..24ef528b0d --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/requirements.txt @@ -0,0 +1 @@ +webdataset==0.2.96 diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/__init__.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/__init__.py new file mode 100644 index 0000000000..25e6abfbc5 --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py new file mode 100644 index 0000000000..33fa7936b9 --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py @@ -0,0 +1,491 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
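Before the implementation below, a minimal, hedged sketch of how a `WebDataModule` such as the one documented in the README might be wired into `Trainer.fit()`. The tar directories, sample counts, and the tiny `TinyModule` LightningModule are placeholders introduced only for illustration, not part of this package; the tars are assumed to decode into tensors (e.g. the `tensor.pyd` objects used in the tests).

```python
# Illustrative sketch only: paths, counts, and TinyModule are hypothetical.
import lightning as L
import torch

from bionemo.webdatamodule.datamodule import Split, WebDataModule

# extract_keys() yields 1-tuples, so an "untuple" composable is almost always needed
untuple = lambda source: (sample for (sample,) in source)

data_module = WebDataModule(
    n_samples={Split.train: 1000, Split.val: 100},
    suffix_keys_wds="tensor.pyd",
    dirs_tars_wds={Split.train: "/path/to/train/tars", Split.val: "/path/to/val/tars"},
    global_batch_size=16,
    prefix_tars_wds="shards",  # tar files are globbed as {dir}/shards-*.tar
    pipeline_wds={Split.train: untuple, Split.val: untuple},
    # in practice a webdataset.filters.batched(...) composable would go into
    # pipeline_prebatch_wld; without it the loader yields unbatched samples
)


class TinyModule(L.LightningModule):
    """Throwaway model so the sketch is self-contained."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch.float().reshape(-1, 1)).sum()

    def validation_step(self, batch, batch_idx):
        return torch.zeros(1)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


trainer = L.Trainer(max_epochs=1, accelerator="gpu", devices=1)
trainer.fit(TinyModule(), datamodule=data_module)
```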
+ + +import glob +from enum import Enum, auto +from typing import Any, Dict, Iterable, List, Optional, Union, get_args + +import lightning as L +import webdataset as wds + +from bionemo.webdatamodule.utils import pickles_to_tars + + +class Split(Enum): + train = auto() + val = auto() + test = auto() + + +class WebDataModule(L.LightningDataModule): + """A LightningDataModule for using webdataset tar files to setup dataset and + dataloader. This data module takes as input a dictionary: Split -> tar file + directory and vaiours webdataset config settings. In its setup() function, + it creates the webdataset object chaining up the input `pipeline_wds` + workflow. In its train/val/test_dataloader(), it creates the WebLoader + object chaining up the `pipeline_prebatch_wld` workflow + + Examples + -------- + + 1. create the data module with input directory to webdataset tar files. + Depending on which of the downstream Lightning.Trainer methods are called, + e.g., `Trainer.fit()`, `Trainer.validate()`, `Trainer.test()` or + `Trainer.predict()`, only a subset of the train, val and test splits need to + be specified in the various input options to the data module: + + - `Trainer.fit()` requires the `train` and `val` splits + - `Trainer.validate()` requires the `val` split + - `Trainer.test()` requires the `test` splits + - `Trainer.predict()` requires the `test` splits + + Here is an example of constructing the data module for `Trainer.fit()`: + ``` + >>> from bionemo.core.data.datamodule import Split, WebDataModule + >>> + >>> tar_file_prefix = "shards" + >>> + >>> dirs_of_tar_files = { + >>> Split.train: "/path/to/train/split/tars", + >>> Split.val: "/path/to/val/split/tars", + >>> } + >>> + >>> n_samples { + >>> Split.train: 1000, + >>> Split.val: 100, + >>> } + >>> + >>> # this is the string to retrieve the corresponding data object from the + >>> # webdataset file (see + >>> # https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format + >>> # for details) + >>> suffix_keys_wds = "tensor.pyd" + >>> + >>> # see the API doc for the definition of global_batch_size + >>> global_batch_size = 16 + >>> + >>> seed = 27193781 + >>> + >>> # Specify the routines to process the samples in the WebDataset object. + >>> # The routine is a generator of an Iterable of generators that are chained + >>> # together by nested function calling. The following is equivalent of + >>> # defining a overall generator of `shuffle(untuple(...))` which + >>> # untuples the samples and shuffles them. See webdataset's Documentation + >>> # for details. + >>> # NOTE: the `untuple` is almost always necessary due to the webdataset's + >>> # file parsing rule. + >>> + >>> untuple = lambda source : (sample for (sample,) in source) + >>> + >>> from webdatast import shuffle + >>> pipeline_wds = { + >>> Split.train : [untuple, shuffle(n_samples[Split.train], + >>> rng=random.Random(seed_rng_shfl))], + >>> Split.val: untuple + >>> } + >>> + >>> # Similarly the user can optionally define the processing routine on the + >>> # WebLoader (the dataloader of webdataset). 
+ >>> # NOTE: these routines by default take unbatched sample as input so the + >>> # user can customize their batching routines here + >>> + >>> batch = batched(local_batch_size, collation_fn=lambda + list_samples : torch.vstack(list_samples)) + >>> pipeline_prebatch_wld = { + Split.train: [shuffle(n_samples[Split.train], + rng=random.Random(seed_rng_shfl)), batch], + Split.val : batch, + Split.test : batch + } + >>> + >>> # the user can optionally specify the kwargs for WebDataset and + >>> # WebLoader + >>> + >>> kwargs_wds = { + >>> split : {'shardshuffle' : split == Split.train, + >>> 'nodesplitter' : wds.split_by_node, + >>> 'seed' : seed_rng_shfl} + >>> for split in Split + >>> } + >>> + >>> kwargs_wld = { + >>> split : {"num_workers": 2} for split in Split + >>> } + >>> + >>> # construct the data module + >>> data_module = WebDataModule(n_samples, suffix_keys_wds, + dirs_of_tar_files, global_batch_size, + prefix_tars_wds=tar_file_prefix, + pipeline_wds=pipeline_wds, + pipeline_prebatch_wld=pipeline_prebatch_wld, + kwargs_wds=kwargs_wds, + kwargs_wld=kwargs_wld) + ``` + + """ + + def __init__( + self, + n_samples: Dict[Split, int], + suffix_keys_wds: Union[str, Iterable[str]], + dirs_tars_wds: Dict[Split, str], + global_batch_size: int, + prefix_tars_wds: str = "wdshards", + pipeline_wds: Optional[ + Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]] + ] = None, + pipeline_prebatch_wld: Optional[ + Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]] + ] = None, + kwargs_wds: Optional[Dict[Split, Dict[str, Any]]] = None, + kwargs_wld: Optional[Dict[Split, Dict[str, Any]]] = None, + ): + """constructor + + Args: + n_samples (Dict[Split, int]): input dictionary: Split -> number of + data samples for each split + suffix_keys_wds (Union[str, Iterable[str]]): a set of keys each + corresponding to a data object in the webdataset tar file + dictionary. The data objects of these keys will be extracted and + tupled for each sample in the tar files + dirs_tars_wds (Dict[Split, str]): input dictionary: Split -> tar file + directory that contains the webdataset tar files for each split + global_batch_size (int): size of batch summing across nodes in Data + Distributed Parallel, i.e., local_batch_size * n_nodes. NOTE: + this data module doesn't rely on the input `global_batch_size` + for batching the samples. The batching is supposed to be done as + a part of the input `pipeline_prebatch_wld`. `global_batch_size` + is only used to compute a (pseudo-) epoch length for the data + loader so that the loader yield approximately n_samples // + global_batch_size batches + Kwargs: + prefix_tars_wds (str): name prefix of the input webdataset tar + files. The input tar files are globbed by + "{dirs_tars_wds[split]}/{prefix_tars_wds}-*.tar" + pipeline_wds (Optional[Dict[Split, Union[Iterable[Iterable[Any]], + Iterable[Any]]]]): a dictionary of webdatast composable, i.e., + functor that maps a iterator to another iterator that + transforms the data sample yield from the dataset object, for + different splits, or an iterable to such a sequence of such + iterators. 
For example, this can be used to transform the + sample in the worker before sending it to the main process of + the dataloader + pipeline_prebatch_wld (Optional[Dict[Split, + Union[Iterable[Iterable[Any]], Iterable[Any]]]]): a dictionary + of webloader composable, i.e., functor that maps a iterator to + another iterator that transforms the data sample yield from the + WebLoader object, for different splits, or an iterable to a + seuqnence of such iterators. For example, this can be used for + batching the samples. NOTE: this is applied before batching is + yield from the WebLoader + kwargs_wds (Optional[Dict[Split, Dict[str, Any]]]): kwargs for the + WebDataset.__init__() + kwargs_wld (Optional[Dict[Split, Dict[str, Any]]]): kwargs for the + WebLoader.__init__(), e.g., num_workers, of each split + + + """ + super().__init__() + + self._dirs_tars_wds = dirs_tars_wds + + keys_subset = self._dirs_tars_wds.keys() + + if n_samples.keys() != keys_subset: + raise RuntimeError( + f"Input n_samples has different keys than " + f"dirs_tars_wds: {n_samples.keys()} vs " + f"{keys_subset}" + ) + + self._n_samples = n_samples + + self._global_batch_size = global_batch_size + + if not isinstance(suffix_keys_wds, get_args(Union[str, Iterable])): + raise TypeError("suffix_keys_wds can only be str or Iterable[str]") + + self._suffix_keys_wds = suffix_keys_wds + + self._prefix_tars_wds = prefix_tars_wds + self._pipeline_wds = pipeline_wds + self._pipeline_prebatch_wld = pipeline_prebatch_wld + + self._kwargs_wld = kwargs_wld + + self._kwargs_wds = kwargs_wds + + # to be created later in setup + self._dataset = {} + + def prepare_data(self) -> None: + """This is called only by the main process by the Lightning workflow. Do + not rely on this data module object's state update here as there is no + way to communicate the state update to other subprocesses. + + Returns: None + """ + pass + + def _setup_wds(self, split: Split) -> wds.WebDataset: + """setup webdataset and webloader. 
This is called by setup() + + Args: + split (Split): train, val or test split + + Returns: WebDataset + + """ + if split not in self._dirs_tars_wds.keys(): + raise RuntimeError( + f"_setup_wds() is called with {split} " + f"split that doesn't have the input tar dir" + ) + urls = sorted( + glob.glob(f"{self._dirs_tars_wds[split]}/{self._prefix_tars_wds}-*.tar") + ) + kwargs = self._kwargs_wds[split] if self._kwargs_wds is not None else None + dataset = wds.WebDataset( + urls, **(kwargs if kwargs is not None else {}) + ).decode() + if isinstance(self._suffix_keys_wds, str): + dataset = dataset.extract_keys(f"*.{self._suffix_keys_wds}") + else: + dataset = dataset.extract_keys( + *[f"*.{key}" for key in self._suffix_keys_wds] + ) + + if self._pipeline_wds is not None and self._pipeline_wds[split] is not None: + if isinstance(self._pipeline_wds[split], Iterable): + dataset = dataset.compose(*self._pipeline_wds[split]) + else: + dataset = dataset.compose(self._pipeline_wds[split]) + return dataset + + def setup(self, stage: str) -> None: + """This is called on all Lightning-managed nodes in a multi-node + training session + + + Args: + stage (str): "fit", "test" or "predict" + Returns: None + """ + if stage == "fit": + self._dataset[Split.train] = self._setup_wds(Split.train) + self._dataset[Split.val] = self._setup_wds(Split.val) + elif stage == "validate": + self._dataset[Split.val] = self._setup_wds(Split.val) + elif stage == "test": + self._dataset[Split.test] = self._setup_wds(Split.test) + elif stage == "predict": + self._dataset[Split.test] = self._setup_wds(Split.test) + else: + raise NotImplementedError( + f"Data setup with stage = {stage} " f"is not implmented" + ) + + def _setup_dataloader(self, split: Split) -> wds.WebLoader: + """setup the dataloader for the input dataset split + + Args: + split (Split): input split type + + Returns: WebLoader object + + """ + if self._dataset[split] is None: + raise RuntimeError( + f"_setup_dataloader() is called with {split} " + f"split without setting up the corresp. dataset" + ) + dataset = self._dataset[split] + n_samples = self._n_samples[split] + n_batches = (n_samples + self._global_batch_size - 1) // self._global_batch_size + kwargs = self._kwargs_wld[split] if self._kwargs_wld is not None else None + loader = wds.WebLoader( + dataset, batch_size=None, **(kwargs if kwargs is not None else {}) + ) + + if ( + self._pipeline_prebatch_wld is not None + and self._pipeline_prebatch_wld[split] is not None + ): + if isinstance(self._pipeline_prebatch_wld[split], Iterable): + loader = loader.compose(*self._pipeline_prebatch_wld[split]) + else: + loader = loader.compose(self._pipeline_prebatch_wld[split]) + + loader = loader.with_epoch(n_batches) + + return loader + + def train_dataloader(self) -> wds.WebLoader: + return self._setup_dataloader(Split.train) + + def val_dataloader(self) -> wds.WebLoader: + return self._setup_dataloader(Split.val) + + def test_dataloader(self) -> wds.WebLoader: + return self._setup_dataloader(Split.test) + + def predict_dataloader(self) -> wds.WebLoader: + return self._setup_dataloader(Split.test) + + +class PickledDataWDS(WebDataModule): + """A LightningDataModule to process pickled data into webdataset tar files + and setup dataset and dataloader. This inherits the webdataset setup from + its parent module `WebDataModule`. 
This data module takes a directory of + pickled data files, data filename prefixes for train/val/test splits, data + filename suffixes and prepare webdataset tar files by globbing the specific + pickle data files `{dir_pickles}/{name_subset[split]}.{suffix_pickles}` and + outputing to webdataset tar file with the dict structure: + ``` + {"__key__" : name.replace(".", "-"), + suffix_pickles : pickled.dumps(data) } + ``` + NOTE: this assumes only one pickled file is processed for each sample. In + its setup() function, it creates the webdataset object chaining up the input + `pipeline_wds` workflow. In its train/val/test_dataloader(), it creates the + WebLoader object chaining up the `pipeline_prebatch_wld` workflow. + + Examples + -------- + + 1. create the data module with a directory of pickle files and the file name + prefix thereof for different splits to used by `Lightning.Trainer.fit()` + + ``` + >>> from bionemo.core.data.datamodule import Split, PickledDataWDS + + >>> dir_pickles = "/path/to/my/pickles/dir" + + >>> # the following will use `sample1.mydata.pt` and `sample2.mydata.pt` as the + >>> # training dataset and `sample4.mydata.pt` and `sample5.mydata.pt` as the + >>> # validation dataset + + >>> suffix_pickles = "mydata.pt" + + >>> names_subset = { + >>> Split.train: [sample1, sample2], + >>> Split.val: [sample4, sample5], + >>> } + + >>> # the following setting will attempt to create at least 5 tar files in + >>> # `/path/to/output/tars/dir/myshards-00000{0-5}.tar` + + >>> n_tars_wds = 5 + >>> prefix_tars_wds = "myshards" + >>> output_dir_tar_files = { + Split.train : "/path/to/output/tars/dir-train", + Split.val : "/path/to/output/tars/dir-val", + Split.test : "/path/to/output/tars/dir-test", + } + + >>> # see the `WebDataModule` API doc for the definition of global_batch_size + >>> global_batch_size = 16 + + >>> # user can optionally customize the data processing routines and kwargs used + >>> # in the WebDataset and WebLoader (see the examples in `WebDataModule`) + + >>> pipeline_wds = { Split.train: ... } + + >>> pipeline_prebatch_wld = { Split.train: ... } + + >>> kwargs_wds = { Split.train: ..., Split.val: ... } + + >>> kwargs_wld = { Split.train: ..., Split.val: ... 
} + + >>> # create the data module + >>> data_module = PickledDataWDS( + >>> dir_pickles, + >>> names_subset, + >>> suffix_pickles, # `WebDataModule` args + >>> output_dir_tar_files, # `WebDataModule` args + >>> global_batch_size, # `WebDataModule` args + >>> n_tars_wds=n_tars_wds, + >>> prefix_tars_wds=prefix_tars_wds, # `WebDataModule` kwargs + >>> pipeline_wds=pipeline_wds, # `WebDataModule` kwargs + >>> pipeline_prebatch_wld=pipelines_wdl_batch, # `WebDataModule` kwargs + >>> kwargs_wds=kwargs_wds, # `WebDataModule` kwargs + >>> kwargs_wld=kwargs_wld, # `WebDataModule` kwargs + >>> ) + + ``` + + """ + + def __init__( + self, + dir_pickles: str, + names_subset: Dict[Split, List[str]], + *args, + n_tars_wds: Optional[int] = None, + **kwargs, + ): + """constructor + + Args: + dir_pickles (str): input directory of pickled data files + names_subset (Dict[Split, List[str]]): list of filename prefix of + the data samples to be loaded in the dataset and dataloader for + each of the split + *args: arguments passed to the parent WebDataModule after its + `n_samples` args (where `n_samples` is deduced from the length of + `names_subset` arg of this class) + + Kwargs: + n_tars_wds (int): attempt to create at least this number of + webdataset shards + **kwargs: arguments passed to the parent WebDataModule + + + """ + super().__init__( + {split: len(names_subset[split]) for split in names_subset.keys()}, + *args, + **kwargs, + ) + + self._dir_pickles = dir_pickles + + self._names_subset = names_subset + + self._n_tars_wds = n_tars_wds + + def prepare_data(self) -> None: + """This is called only by the main process by the Lightning workflow. Do + not rely on this data module object's state update here as there is no + way to communicate the state update to other subprocesses. The nesting + `pickles_to_tars` function goes through the data name prefixes in the + different splits, read the corresponding pickled file and output a + webdataset tar archive with the dict structure: {"__key__" : + name.replace(".", "-"), suffix_pickles : pickled.dumps(data) }. + + Returns: None + """ + for split in self._names_subset.keys(): + # create wds shards (tar files) for train set + pickles_to_tars( + self._dir_pickles, + self._names_subset[split], + self._suffix_keys_wds, + self._dirs_tars_wds[split], + self._prefix_tars_wds, + min_num_shards=self._n_tars_wds, + ) diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py new file mode 100644 index 0000000000..541957edd7 --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
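As a usage note for the class above, here is a hedged, concrete sketch of constructing `PickledDataWDS`; the pickle directory, sample names, and output tar directories are placeholders. The positional arguments after `names_subset` are forwarded to the parent `WebDataModule` as `suffix_keys_wds`, `dirs_tars_wds`, and `global_batch_size`, with `n_samples` deduced from the lengths of the `names_subset` lists.

```python
# Hypothetical paths and sample names, for illustration only.
from bionemo.webdatamodule.datamodule import PickledDataWDS, Split

data_module = PickledDataWDS(
    "/path/to/pickles",                              # dir_pickles
    {                                                # names_subset
        Split.train: ["sample1", "sample2"],
        Split.val: ["sample4", "sample5"],
    },
    "mydata.pt",                                     # -> suffix_keys_wds
    {                                                # -> dirs_tars_wds
        Split.train: "/path/to/tars/train",
        Split.val: "/path/to/tars/val",
    },
    16,                                              # -> global_batch_size
    n_tars_wds=3,
    prefix_tars_wds="myshards",
)

data_module.prepare_data()  # pickles -> {dir}/myshards-*.tar for each split
data_module.setup("fit")    # builds the WebDataset objects for train and val
```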
+ + +import os +import pickle +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, List, Optional, Union, get_args + +import webdataset as wds +from nemo.utils import logging + + +def pickles_to_tars( + dir_input: str, + input_prefix_subset: List[str], + input_suffix: Union[str, Iterable[str]], + dir_output: str, + output_prefix: str, + func_output_data: Callable[[str, Dict[str, Any]], Dict[str, Any]] = lambda prefix, + suffix_to_data: {"__key__": prefix, **suffix_to_data}, + min_num_shards: Optional[int] = None, +) -> None: + """Convert a subset of pickle files from a directory to Webdataset tar files + Input path and name pattern for sample 0: + f"{dir_input}/{input_prefix_subset[0]}.{input_suffix[0]}" + f"{dir_input}/{input_prefix_subset[0]}.{input_suffix[1]}" + Input path and name pattern for sample 1: + f"{dir_input}/{input_prefix_subset[1]}.{input_suffix[0]}" + f"{dir_input}/{input_prefix_subset[1]}.{input_suffix[1]}" + ... + Output path and name pattern: + f"{dir_output}/{output_prefix}-%06d.tar" + + The webdataset tar archive is specified by the dictionary: + { + "__key__" : sample_filename_preifx, + sample_filename_suffix_1 : data_1, + sample_filename_suffix_2 : data_2, + ... + } + so that parsing the tar archive is equivalent of reading + {sample_filename_preifx}.{sample_filename_suffix_1} etc. + + Here, each sample data get its name prefix from one element of + `input_prefix_subset` and its name suffixes from the list `input_suffix`. + Per the webdataset file format specification, the `sample_filename_preifx` + can't contain dots '.' so this function removes it for the user by calling + .replace(".", "-") on the elements of `input_prefix_subset` + + Args: + dir_input (str): Input directory + input_prefix_subset (List[str]): Input subset of pickle files' prefix + input_suffix (Union[str, Iterable[str]]): Input pickle file name + suffixes, each for one type of data object, for all the samples + dir_output (str): Output directory + output_prefix (str): Output tar file name prefix + func_output_data (Callable[[str, Dict[str, Any]], Dict[str, Any]]) : + function that maps the name prefix, name suffix and data object to a + webdataset tar archive dictionary. Refer to the webdataset github + repo for the archive file format specification. + min_num_shards (int) : create at least this number of tar files. 
+ WebDataset has bugs when reading small number of tar files in a + multi-node lightening + DDP setting so this option can be used to + guarantee the tar file counts + + Returns: None + + """ + if not isinstance(input_suffix, get_args(Union[str, Iterable])): + raise TypeError("input_suffix can only be str or Iterable[str]") + os.makedirs(dir_output, exist_ok=True) + wd_subset_pattern = os.path.join(dir_output, f"{output_prefix}-%06d.tar") + n_samples_per_shard_max = 100000 + if min_num_shards is not None: + if min_num_shards <= 0: + raise ValueError(f"Invalid min_num_shards = {min_num_shards} <= 0") + n_samples_per_shard_max = len(input_prefix_subset) // min_num_shards + with wds.ShardWriter( + wd_subset_pattern, + encoder=False, + maxcount=n_samples_per_shard_max, + compress=False, + mode=0o777, + ) as sink: + for name in input_prefix_subset: + try: + if isinstance(input_suffix, str): + suffix_to_data = { + input_suffix: pickle.dumps( + pickle.loads( + ( + Path(dir_input) / f"{name}.{input_suffix}" + ).read_bytes() + ) + ) + } + else: + suffix_to_data = { + suffix: pickle.dumps( + pickle.loads( + (Path(dir_input) / f"{name}.{suffix}").read_bytes() + ) + ) + for suffix in input_suffix + } + # the prefix name shouldn't contain any "." per webdataset's + # specification + sample = func_output_data(name.replace(".", "-"), suffix_to_data) + sink.write(sample) + except ModuleNotFoundError as e: + logging.error( + f"Dependency for parsing input pickle data not found: {e}" + ) + raise e + except Exception as e: + logging.error(f"Failed to write {name} into tar files due to error {e}") + raise e diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py new file mode 100644 index 0000000000..a43f4c0be3 --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py @@ -0,0 +1,301 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
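Before the fixtures below, a hedged round-trip sketch of `pickles_to_tars` (the temporary directories are placeholders): pickle a few tensors, shard them into tar files, then read them back the same way `WebDataModule` does, via `decode()` and `extract_keys()`.

```python
# Placeholder /tmp paths; illustrates the write-then-read round trip only.
import glob
import os
import pickle

import torch
import webdataset as wds

from bionemo.webdatamodule.utils import pickles_to_tars

os.makedirs("/tmp/pickles", exist_ok=True)
for i in range(6):
    with open(f"/tmp/pickles/sample-{i:04}.tensor.pyd", "wb") as fh:
        pickle.dump(torch.tensor(i), fh)

pickles_to_tars(
    "/tmp/pickles",
    [f"sample-{i:04}" for i in range(6)],
    "tensor.pyd",
    "/tmp/tars",
    "shards",
    min_num_shards=3,
)

# each tar member is named "{prefix}.tensor.pyd"; decode() unpickles *.pyd entries
# and extract_keys() yields 1-tuples of the stored tensors
dataset = (
    wds.WebDataset(sorted(glob.glob("/tmp/tars/shards-*.tar")))
    .decode()
    .extract_keys("*.tensor.pyd")
)
assert sorted(t.item() for (t,) in dataset) == list(range(6))
```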
+ + +import pickle +import random + +import lightning as L +import pytest +import torch +import webdataset as wds +from webdataset.filters import batched, shuffle + +from bionemo.webdatamodule.datamodule import PickledDataWDS, Split, WebDataModule +from bionemo.webdatamodule.utils import pickles_to_tars + + +@pytest.fixture(scope="module") +def gen_pickle_files(tmp_path_factory): + dir_pickles = tmp_path_factory.mktemp("pickleddatawds").as_posix() + prefix_sample = "sample" + suffix_sample = ["tensor.pyd", "tensor_copy.pyd"] + n_samples_per_split = 10 + prefixes = [] + # generate the pickles for train, val, and test + for i in range(n_samples_per_split * 3): + prefix = f"{prefix_sample}-{i:04}" + prefixes.append(prefix) + t = torch.tensor(i, dtype=torch.int32) + for suffix in suffix_sample: + with open(f"{dir_pickles}/{prefix}.{suffix}", "wb") as fh: + pickle.dump(t, fh) + prefixes_pickle = { + Split.train: prefixes[0:n_samples_per_split], + Split.val: prefixes[n_samples_per_split : n_samples_per_split * 2], + Split.test: prefixes[n_samples_per_split * 2 : n_samples_per_split * 3], + } + return ( + dir_pickles, + prefix_sample, + suffix_sample, + prefixes_pickle, + n_samples_per_split, + ) + + +@pytest.fixture(scope="module", params=[1, 2]) +def gen_test_data(tmp_path_factory, gen_pickle_files, request): + dir_pickles, prefix_sample, suffixes, prefixes_pickle, n_samples_per_split = ( + gen_pickle_files + ) + n_suffixes = request.param + if n_suffixes <= 1: + suffix_sample = suffixes[0] + else: + suffix_sample = suffixes[0:n_suffixes] + dir_tars_tmp = tmp_path_factory.mktemp("webdatamodule").as_posix() + dir_tars = {split: f"{dir_tars_tmp}{str(split).split('.')[-1]}" for split in Split} + prefix_tar = "tensor" + n_samples = {split: n_samples_per_split for split in Split} + # generate the tars + pickles_to_tars( + dir_pickles, + prefixes_pickle[Split.train], + suffix_sample, + dir_tars[Split.train], + prefix_tar, + min_num_shards=3, + ) + pickles_to_tars( + dir_pickles, + prefixes_pickle[Split.val], + suffix_sample, + dir_tars[Split.val], + prefix_tar, + min_num_shards=3, + ) + pickles_to_tars( + dir_pickles, + prefixes_pickle[Split.test], + suffix_sample, + dir_tars[Split.test], + prefix_tar, + min_num_shards=3, + ) + return ( + dir_pickles, + dir_tars, + prefix_sample, + suffix_sample, + prefix_tar, + n_samples, + prefixes_pickle, + ) + + +def _create_webdatamodule(gen_test_data, num_workers=2): + (_, dirs_tars_wds, _, suffix_keys_wds, prefix_tars_wds, n_samples, _) = ( + gen_test_data + ) + local_batch_size = 2 + global_batch_size = 2 + seed_rng_shfl = 82838392 + + batch = batched( + local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples) + ) + + if isinstance(suffix_keys_wds, str): + untuple = lambda source: (sample[0] for sample in source) + elif isinstance(suffix_keys_wds, list): + untuple = lambda source: (torch.vstack(sample) for sample in source) + + pipeline_wds = { + Split.train: [ + untuple, + shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)), + ], + Split.val: untuple, + Split.test: untuple, + } + + pipeline_prebatch_wld = { + Split.train: [ + shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)), + batch, + ], + Split.val: batch, + Split.test: batch, + } + + kwargs_wds = { + split: { + "shardshuffle": split == Split.train, + "nodesplitter": wds.split_by_node, + "seed": seed_rng_shfl, + } + for split in Split + } + + kwargs_wld = {split: {"num_workers": num_workers} for split in Split} + + data_module = WebDataModule( + 
n_samples, + suffix_keys_wds, + dirs_tars_wds, + global_batch_size, + prefix_tars_wds=prefix_tars_wds, + pipeline_wds=pipeline_wds, + pipeline_prebatch_wld=pipeline_prebatch_wld, + kwargs_wds=kwargs_wds, + kwargs_wld=kwargs_wld, + ) + + return data_module, dirs_tars_wds + + +@pytest.fixture(scope="module") +def create_webdatamodule(gen_test_data): + return _create_webdatamodule(gen_test_data) + + +@pytest.fixture(scope="module") +def create_another_webdatamodule(gen_test_data): + return _create_webdatamodule(gen_test_data) + + +@pytest.fixture(scope="module") +def create_webdatamodule_with_5_workers(gen_test_data): + return _create_webdatamodule(gen_test_data, num_workers=5) + + +class ModelTestWebDataModule(L.LightningModule): + def __init__(self) -> None: + super().__init__() + self._model = torch.nn.Linear(1, 1) + self._samples = {split: [] for split in Split} + + def forward(self, x): + return self._model(x.float()) + + def training_step(self, batch): + self._samples[Split.train].append(batch) + loss = self(batch).sum() + return loss + + def validation_step(self, batch, batch_index): + self._samples[Split.val].append(batch) + return torch.zeros(1) + + def test_step(self, batch, batch_index): + self._samples[Split.test].append(batch) + + def predict_step(self, batch, batch_index): + self._samples[Split.test].append(batch) + return torch.zeros(1) + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=2e-4) + return optimizer + + +@pytest.fixture(scope="function") +def create_trainer_and_model(): + trainer = L.Trainer( + max_epochs=1, accelerator="gpu", devices=1, val_check_interval=1 + ) + model = ModelTestWebDataModule() + return trainer, model + + +def _create_pickleddatawds(tmp_path_factory, gen_test_data): + ( + dir_pickles, + _, + _, + suffix_keys_wds, + prefix_tars_wds, + n_samples, + names, + ) = gen_test_data + local_batch_size = 2 + global_batch_size = 2 + seed_rng_shfl = 82838392 + n_tars_wds = 3 + + prefix_dir_tars_wds = tmp_path_factory.mktemp("pickleddatawds_tars_wds").as_posix() + dirs_tars_wds = {s: f"{prefix_dir_tars_wds}{str(s).split('.')[-1]}" for s in Split} + + batch = batched( + local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples) + ) + + untuple = lambda source: (sample[0] for sample in source) + + pipeline_wds = { + Split.train: [ + untuple, + shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)), + ], + Split.val: untuple, + Split.test: untuple, + } + + pipeline_prebatch_wld = { + Split.train: [ + shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)), + batch, + ], + Split.val: batch, + Split.test: batch, + } + + kwargs_wds = { + split: { + "shardshuffle": split == Split.train, + "nodesplitter": wds.split_by_node, + "seed": seed_rng_shfl, + } + for split in Split + } + + kwargs_wld = {split: {"num_workers": 2} for split in Split} + + data_module = PickledDataWDS( + dir_pickles, + names, + suffix_keys_wds, + dirs_tars_wds, + global_batch_size, + n_tars_wds=n_tars_wds, + prefix_tars_wds=prefix_tars_wds, + pipeline_wds=pipeline_wds, + pipeline_prebatch_wld=pipeline_prebatch_wld, + kwargs_wds=kwargs_wds, + kwargs_wld=kwargs_wld, + ) + + return data_module, dirs_tars_wds, n_tars_wds + + +@pytest.fixture(scope="module") +def create_pickleddatawds(tmp_path_factory, gen_test_data): + return _create_pickleddatawds(tmp_path_factory, gen_test_data) + + +@pytest.fixture(scope="module") +def create_another_pickleddatawds(tmp_path_factory, gen_test_data): + return 
_create_pickleddatawds(tmp_path_factory, gen_test_data) diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py new file mode 100644 index 0000000000..692905a416 --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py @@ -0,0 +1,268 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +from enum import Enum, auto + +import lightning as L +import pytest +import torch + +from bionemo.webdatamodule.datamodule import Split + + +@pytest.mark.parametrize("split", list(Split)) +def test_webdatamodule_init(split, create_webdatamodule): + data_module, dirs_tars_wds = create_webdatamodule + assert data_module._n_samples[split] == 10, ( + f"Wrong {split}-set size: " + f"expected 10 " + f"but got {data_module._n_samples[split]}" + ) + assert data_module._dirs_tars_wds[split] == f"{dirs_tars_wds[split]}", ( + f"Wrong tar files directory: " + f"expected {dirs_tars_wds[split]} " + f"but got {data_module._dirs_tars_wds[split]}" + ) + + +@pytest.mark.parametrize("split", list(Split)) +def test_webdatamodule_setup_dataset( + split, create_webdatamodule, create_another_webdatamodule +): + data_modules = [create_webdatamodule[0], create_another_webdatamodule[0]] + lists_tensors = [] + for m in data_modules: + m.prepare_data() + # run through all the possible stages first to setup all the correps. + # dataset objects + m.setup("fit") + m.setup("test") + L.seed_everything(2823828) + tensors = [] + for sample in m._dataset[split]: + assert isinstance( + sample, torch.Tensor + ), "Sample yield from dataset is not tensor" + tensors.append(sample) + lists_tensors.append(tensors) + + assert len(lists_tensors[0]) > 0, "No names in {split} dataset" + torch.testing.assert_close( + torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1]) + ) + + +@pytest.mark.parametrize("split", list(Split)) +def test_webdatamodule_setup_dataloader( + split, create_webdatamodule, create_another_webdatamodule +): + data_modules = [create_webdatamodule[0], create_another_webdatamodule[0]] + lists_tensors = [] + for m in data_modules: + m.prepare_data() + # run through all the possible stages first to setup all the correps. 
+ # dataset objects + m.setup("fit") + m.setup("test") + L.seed_everything(2823828) + tensors = [] + loader = None + if split == Split.train: + loader = m.train_dataloader() + elif split == Split.val: + loader = m.val_dataloader() + elif split == Split.test: + loader = m.test_dataloader() + else: + raise RuntimeError(f"Test for split {split} not implemented") + assert loader is not None, "dataloader not instantiated" + for samples in loader: + # each batch from the loader should have been collated into a single torch.Tensor + assert isinstance( + samples, torch.Tensor + ), "Sample object is not a torch.Tensor" + tensors.append(samples) + lists_tensors.append(tensors) + + assert len(lists_tensors[0]) > 0, f"No samples in {split} dataloader" + torch.testing.assert_close( + torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1]) + ) + +
+@pytest.mark.parametrize("split", list(Split)) +def test_webdatamodule_throw_on_many_workers( + split, create_webdatamodule_with_5_workers +): + data_module = create_webdatamodule_with_5_workers[0] + urls = glob.glob( + f"{data_module._dirs_tars_wds[split]}/" f"{data_module._prefix_tars_wds}-*.tar" + ) + n_tars = len(urls) + data_module._kwargs_wld[split]["num_workers"] = n_tars + 1 + data_module.prepare_data() + data_module.setup("fit") + data_module.setup("test") + loader = None + if split == Split.train: + loader = data_module.train_dataloader() + elif split == Split.val: + loader = data_module.val_dataloader() + elif split == Split.test: + loader = data_module.test_dataloader() + else: + raise RuntimeError(f"Test for split {split} not implemented") + assert loader is not None, "dataloader not instantiated" + try: + for _ in loader: + pass + except ValueError as e: + # this is expected + assert "have fewer shards than workers" in str(e), ( + f"'have fewer shards than workers' not found in exception " + f"raised from data loading: {e}" + ) + except Exception as e: + raise RuntimeError( + f"WebLoader did not raise ValueError with fewer " + f"shards than workers but raised this instead: {e}" + ) + else: + raise NotImplementedError( + "WebLoader doesn't throw an error with num_workers > num_shards. " + "User should report this issue to webdataset and, as a workaround, " + "keep the number of workers at or below the number of shards" + ) + +
+class Stage(Enum): + fit = auto() + validate = auto() + test = auto() + predict = auto() + +
+@pytest.mark.parametrize("stage", list(Stage)) +def test_webdatamodule_in_lightning( + stage, create_webdatamodule, create_another_webdatamodule, create_trainer_and_model +): + data_modules = [create_webdatamodule[0], create_another_webdatamodule[0]] + trainer, model = create_trainer_and_model + # get the list of samples from the loader + L.seed_everything(2823828) + data_modules[0].prepare_data() + split = None + if stage == Stage.fit: + split = Split.train + elif stage == Stage.validate: + split = Split.val + elif stage == Stage.test or stage == Stage.predict: + split = Split.test + else: + raise RuntimeError(f"{stage} stage not implemented") + name_stage = str(stage).split(".")[-1] + data_modules[0].setup(name_stage) + # get the list of samples from the workflow + get_dataloader = getattr(data_modules[0], f"{str(split).split('.')[-1]}_dataloader") + loader = get_dataloader() + L.seed_everything(2823828) + workflow = getattr(trainer, name_stage) + workflow(model, data_modules[1]) + device = model._samples[split][0].device + samples = [sample.to(device=device) for sample in loader] + torch.testing.assert_close( + torch.stack(model._samples[split], dim=0),
torch.stack(samples, dim=0) + ) + +
+@pytest.mark.parametrize("split", list(Split)) +def test_pickleddatawds_init(split, create_pickleddatawds): + data_module, dirs_tars_wds, _ = create_pickleddatawds + assert data_module._n_samples[split] == 10, ( + f"Wrong {split}-set size: " + f"expected 10 " + f"but got {data_module._n_samples[split]}" + ) + assert data_module._dirs_tars_wds[split] == dirs_tars_wds[split], ( + f"Wrong tar files directory: " + f"expected {dirs_tars_wds[split]} " + f"but got {data_module._dirs_tars_wds[split]}" + ) + +
+@pytest.mark.parametrize("split", list(Split)) +def test_pickleddatawds_prepare_data(split, create_pickleddatawds): + data_module, _, n_tars_min = create_pickleddatawds + data_module.prepare_data() + dir_tars = f"{data_module._dirs_tars_wds[split]}" + tars = glob.glob(f"{dir_tars}/{data_module._prefix_tars_wds}-*.tar") + n_tars = len(tars) + assert n_tars_min <= n_tars <= n_tars_min + 1, ( + f"Number of tar files: {n_tars} in {dir_tars} is outside the range " + f"[{n_tars_min}, {n_tars_min + 1}]" + ) + +
+@pytest.mark.parametrize("split", list(Split)) +def test_pickleddatawds_setup_dataset( + split, create_pickleddatawds, create_another_pickleddatawds +): + data_modules = [create_pickleddatawds[0], create_another_pickleddatawds[0]] + lists_tensors = [] + for m in data_modules: + m.prepare_data() + # run through all the possible stages first to set up all the corresponding + # dataset objects + m.setup("fit") + m.setup("test") + L.seed_everything(2823828) + tensors = [] + for sample in m._dataset[split]: + assert isinstance( + sample, torch.Tensor + ), "Sample yielded from dataset is not a tensor" + tensors.append(sample) + lists_tensors.append(tensors) + + assert len(lists_tensors[0]) > 0, f"No samples in {split} dataset" + torch.testing.assert_close( + torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1]) + ) + +
+def test_pickleddatawds_sample_overlap(create_pickleddatawds): + data_module = create_pickleddatawds[0] + # this writes the tar files to disk + data_module.prepare_data() + # read the data back by setting up the dataset objects and looping over them + data_module.setup("fit") + data_module.setup("test") + results = { + split: set([sample.item() for sample in data_module._dataset[split]]) + for split in Split + } + overlap_train_val = results[Split.train] & results[Split.val] + overlap_train_test = results[Split.train] & results[Split.test] + overlap_val_test = results[Split.val] & results[Split.test] + assert ( + len(overlap_train_val) == 0 + ), "Shared samples found between train and val datasets" + assert ( + len(overlap_train_test) == 0 + ), "Shared samples found between train and test datasets" + assert ( + len(overlap_val_test) == 0 + ), "Shared samples found between val and test datasets"
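
Note on the worker-count behavior exercised by test_webdatamodule_throw_on_many_workers: WebDataset's WebLoader raises a ValueError when a split has fewer tar shards than dataloader workers, which is why the tests above keep num_workers at or below the shard count. As a minimal sketch (not part of this PR; the helper name clamp_num_workers and its arguments are hypothetical), a caller could clamp the per-split worker count before filling in kwargs_wld:

    import glob

    def clamp_num_workers(dir_tars: str, prefix_tars: str, requested_workers: int) -> int:
        # Count the tar shards written for one split, e.g. "<dir_tars>/<prefix_tars>-000000.tar".
        n_shards = len(glob.glob(f"{dir_tars}/{prefix_tars}-*.tar"))
        # WebLoader errors out when workers outnumber shards, so cap the request at the shard count.
        return max(1, min(requested_workers, n_shards))

With kwargs_wld[split]["num_workers"] derived this way, the ValueError path covered by the test should not be hit during normal runs.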