diff --git a/launch.sh b/launch.sh
index 69a1aff00e..34d9e5855d 100755
--- a/launch.sh
+++ b/launch.sh
@@ -103,7 +103,7 @@ build() {
}
# Check Docker version
- docker_version=$(docker --version | grep -oE '[0-9]+\.[0-9]+\.[0-9]+')
+ docker_version=$(docker --version | awk -F'[, ]' '{print $3}')
required_docker_version="23.0.1"
if ! version_ge "$docker_version" "$required_docker_version"; then
@@ -112,7 +112,7 @@ build() {
fi
# Check Buildx version
- buildx_version=$(docker buildx version | grep -oE '[0-9]+\.[0-9]+\.[0-9]+')
+ buildx_version=$(docker buildx version | awk '{print $2}')
required_buildx_version="0.10.2"
if ! version_ge "$buildx_version" "$required_buildx_version"; then
diff --git a/sub-packages/bionemo-webdatamodule/LICENSE b/sub-packages/bionemo-webdatamodule/LICENSE
new file mode 100644
index 0000000000..d645695673
--- /dev/null
+++ b/sub-packages/bionemo-webdatamodule/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/sub-packages/bionemo-webdatamodule/README.md b/sub-packages/bionemo-webdatamodule/README.md
new file mode 100644
index 0000000000..b06442c66a
--- /dev/null
+++ b/sub-packages/bionemo-webdatamodule/README.md
@@ -0,0 +1,353 @@
+# bionemo-webdatamodule
+
+To install, execute the following:
+```bash
+pip install -e .
+```
+
+To run unit tests, execute:
+```bash
+pytest -v .
+```
+
+## WebDataModule
+
+```python
+class WebDataModule(L.LightningDataModule)
+```
+
+A LightningDataModule for using webdataset tar files to set up the dataset and
+dataloader. This data module takes as input a dictionary: Split -> tar file
+directory and various webdataset config settings. In its setup() function, it
+creates the webdataset object chaining up the input `pipeline_wds` workflow. In
+its train/val/test_dataloader(), it creates the WebLoader object chaining up the
+`pipeline_prebatch_wld` workflow.
+
+Examples
+--------
+
+1. Create the data module with the input directories of webdataset tar files.
+Depending on which of the downstream Lightning.Trainer methods are called,
+e.g., `Trainer.fit()`, `Trainer.validate()`, `Trainer.test()` or
+`Trainer.predict()`, only a subset of the train, val and test splits need to
+be specified in the various input options to the data module:
+
+- `Trainer.fit()` requires the `train` and `val` splits
+- `Trainer.validate()` requires the `val` split
+- `Trainer.test()` requires the `test` split
+- `Trainer.predict()` requires the `test` split
+
+Here is an example of constructing the data module for `Trainer.fit()`:
+```
+>>> from bionemo.webdatamodule.datamodule import Split, WebDataModule
+>>>
+>>> tar_file_prefix = "shards"
+>>>
+>>> dirs_of_tar_files = {
+>>> Split.train: "/path/to/train/split/tars",
+>>> Split.val: "/path/to/val/split/tars",
+>>> }
+>>>
+>>> n_samples = {
+>>> Split.train: 1000,
+>>> Split.val: 100,
+>>> }
+>>>
+>>> # this is the string to retrieve the corresponding data object from the
+>>> # webdataset file (see
+>>> # https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format
+>>> # for details)
+>>> suffix_keys_wds = "tensor.pyd"
+>>>
+>>> # see the API doc for the definition of global_batch_size
+>>> global_batch_size = 16
+>>>
+>>> seed_rng_shfl = 27193781
+>>>
+>>> # Specify the routines to process the samples in the WebDataset object.
+>>> # Each routine is a webdataset composable (or an iterable thereof) that is
+>>> # chained together by nested function calls. The following is equivalent to
+>>> # defining an overall generator of `shuffle(untuple(...))`, which
+>>> # untuples the samples and shuffles them. See webdataset's documentation
+>>> # for details.
+>>> # NOTE: the `untuple` is almost always necessary due to webdataset's
+>>> # file-parsing rules.
+>>>
+>>> untuple = lambda source : (sample for (sample,) in source)
+>>>
+>>> import random
+>>> import torch
+>>> import webdataset as wds
+>>> from webdataset.filters import shuffle, batched
+>>> pipeline_wds = {
+>>> Split.train : [untuple, shuffle(n_samples[Split.train],
+>>> rng=random.Random(seed_rng_shfl))],
+>>> Split.val: untuple
+>>> }
+>>>
+>>> # Similarly the user can optionally define the processing routine on the
+>>> # WebLoader (the dataloader of webdataset).
+>>> # NOTE: these routines by default take unbatched samples as input so the
+>>> # user can customize their batching routines here
+>>>
+>>> local_batch_size = global_batch_size  # assuming a single node
+>>> batch = batched(local_batch_size, collation_fn=lambda
+>>>     list_samples : torch.vstack(list_samples))
+>>> pipeline_prebatch_wld = {
+>>>     Split.train: [shuffle(n_samples[Split.train],
+>>>                           rng=random.Random(seed_rng_shfl)), batch],
+>>>     Split.val : batch,
+>>>     Split.test : batch
+>>> }
+>>>
+>>> # the user can optionally specify the kwargs for WebDataset and
+>>> # WebLoader
+>>>
+>>> kwargs_wds = {
+>>> split : {'shardshuffle' : split == Split.train,
+>>> 'nodesplitter' : wds.split_by_node,
+>>> 'seed' : seed_rng_shfl}
+>>> for split in Split
+>>> }
+>>>
+>>> kwargs_wld = {
+>>> split : {"num_workers": 2} for split in Split
+>>> }
+>>>
+>>> # construct the data module
+>>> data_module = WebDataModule(n_samples, suffix_keys_wds, dirs_of_tar_files,
+                                global_batch_size,
+ prefix_tars_wds=tar_file_prefix,
+ pipeline_wds=pipeline_wds,
+ pipeline_prebatch_wld=pipeline_prebatch_wld,
+ kwargs_wds=kwargs_wds,
+ kwargs_wld=kwargs_wld)
+```
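+
+Once constructed, the data module can be handed directly to a `lightning.Trainer`.
+The following is a minimal sketch; `MyLightningModule` is a hypothetical
+`LightningModule` standing in for the user's model:
+
+```
+>>> import lightning as L
+>>>
+>>> model = MyLightningModule()  # hypothetical user-defined LightningModule
+>>> trainer = L.Trainer(max_epochs=1, accelerator="gpu", devices=1)
+>>>
+>>> # Lightning calls data_module.setup("fit") on each node and then draws
+>>> # batches from data_module.train_dataloader() / val_dataloader()
+>>> trainer.fit(model, datamodule=data_module)
+```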
+
+
+
+#### \_\_init\_\_
+
+```python
+def __init__(
+        n_samples: Dict[Split, int],
+        suffix_keys_wds: Union[str, Iterable[str]],
+        dirs_tars_wds: Dict[Split, str],
+        global_batch_size: int,
+        prefix_tars_wds: str = "wdshards",
+        pipeline_wds: Optional[Dict[Split, Union[Iterable[Iterable[Any]],
+                                                 Iterable[Any]]]] = None,
+        pipeline_prebatch_wld: Optional[Dict[Split,
+                                             Union[Iterable[Iterable[Any]],
+                                                   Iterable[Any]]]] = None,
+        kwargs_wds: Optional[Dict[Split, Dict[str, Any]]] = None,
+        kwargs_wld: Optional[Dict[Split, Dict[str, Any]]] = None)
+```
+
+constructor
+
+**Arguments**:
+
+- `n_samples` _Dict[Split, int]_ - input dictionary: Split -> number of
+  data samples for each split
+- `suffix_keys_wds` _Union[str, Iterable[str]]_ - a set of keys, each
+  corresponding to a data object in the webdataset tar file
+  dictionary. The data objects of these keys will be extracted and
+  tupled for each sample in the tar files
+- `dirs_tars_wds` _Dict[Split, str]_ - input dictionary: Split -> tar file
+  directory that contains the webdataset tar files for each split
+- `global_batch_size` _int_ - batch size summed across nodes in Data
+  Distributed Parallel, i.e., local_batch_size * n_nodes. NOTE:
+  this data module doesn't rely on the input `global_batch_size`
+  for batching the samples. The batching is supposed to be done as
+  part of the input `pipeline_prebatch_wld`. `global_batch_size`
+  is only used to compute a (pseudo-) epoch length for the data
+  loader so that the loader yields approximately n_samples //
+  global_batch_size batches
+  Kwargs:
+- `prefix_tars_wds` _str_ - name prefix of the input webdataset tar
+  files. The input tar files are globbed by
+  "{dirs_tars_wds[split]}/{prefix_tars_wds}-*.tar"
+- `pipeline_wds` _Optional[Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]]]_ -
+  a dictionary of webdataset composables, i.e., functors that map an
+  iterator to another iterator transforming the data samples yielded
+  from the dataset object, for different splits, or an iterable of
+  such composables. For example, this can be used to transform the
+  sample in the worker before sending it to the main process of
+  the dataloader
+- `pipeline_prebatch_wld` _Optional[Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]]]_ -
+  a dictionary of WebLoader composables, i.e., functors that map an
+  iterator to another iterator transforming the data samples yielded
+  from the WebLoader object, for different splits, or an iterable of
+  such composables. For example, this can be used for batching the
+  samples. NOTE: this is applied before the batches are yielded
+  from the WebLoader
+- `kwargs_wds` _Optional[Dict[Split, Dict[str, Any]]]_ - kwargs for the
+  WebDataset.__init__()
+- `kwargs_wld` _Optional[Dict[Split, Dict[str, Any]]]_ - kwargs for the
+  WebLoader.__init__(), e.g., num_workers, of each split
+
+
+
+#### prepare\_data
+
+```python
+def prepare_data() -> None
+```
+
+This is called only by the main process in the Lightning workflow. Do
+not rely on updating this data module object's state here, as there is no
+way to communicate the state update to other subprocesses.
+
+Returns: None
+
+
+
+#### setup
+
+```python
+def setup(stage: str) -> None
+```
+
+This is called on all Lightning-managed nodes in a multi-node
+training session.
+
+
+**Arguments**:
+
+- `stage` _str_ - "fit", "validate", "test" or "predict"
+
+Returns: None
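+
+Outside of a `Trainer`, the stages can also be driven manually, e.g., to inspect
+the samples yielded for a given split. A minimal sketch, assuming the
+`data_module` constructed in the example above:
+
+```
+>>> data_module.prepare_data()
+>>> data_module.setup("fit")   # sets up the train and val datasets
+>>> data_module.setup("test")  # sets up the test dataset
+>>>
+>>> for batch in data_module.train_dataloader():
+>>>     # each batch is whatever `pipeline_prebatch_wld` produces, e.g., a
+>>>     # stacked torch.Tensor in the example above
+>>>     print(batch)
+```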
+
+## PickledDataWDS
+
+```python
+class PickledDataWDS(WebDataModule)
+```
+
+A LightningDataModule to process pickled data into webdataset tar files
+and to set up the dataset and dataloader. This inherits the webdataset setup
+from its parent module `WebDataModule`. This data module takes a directory of
+pickled data files, data filename prefixes for the train/val/test splits and
+data filename suffixes, and prepares the webdataset tar files by globbing the
+specific pickle data files `{dir_pickles}/{names_subset[split]}.{suffix_pickles}`
+and outputting them to webdataset tar files with the dict structure:
+```
+ {"__key__" : name.replace(".", "-"),
+ suffix_pickles : pickle.dumps(data) }
+```
+NOTE: this assumes only one pickled file is processed for each sample. In
+its setup() function, it creates the webdataset object chaining up the input
+`pipeline_wds` workflow. In its train/val/test_dataloader(), it creates the
+WebLoader object chaining up the `pipeline_prebatch_wld` workflow.
+
+Examples
+--------
+
+1. Create the data module with a directory of pickle files and the file name
+prefixes thereof for the different splits, to be used by `Lightning.Trainer.fit()`
+
+```
+>>> from bionemo.webdatamodule.datamodule import Split, PickledDataWDS
+
+>>> dir_pickles = "/path/to/my/pickles/dir"
+
+>>> # the following will use `sample1.mydata.pt` and `sample2.mydata.pt` as the
+>>> # training dataset and `sample4.mydata.pt` and `sample5.mydata.pt` as the
+>>> # validation dataset
+
+>>> suffix_pickles = "mydata.pt"
+
+>>> names_subset = {
+>>>     Split.train: ["sample1", "sample2"],
+>>>     Split.val: ["sample4", "sample5"],
+>>> }
+
+>>> # the following setting will attempt to create at least 5 tar files in
+>>> # `/path/to/output/tars/dir/myshards-00000{0-5}.tar`
+
+>>> n_tars_wds = 5
+>>> prefix_tars_wds = "myshards"
+>>> output_dir_tar_files = {
+>>>     Split.train : "/path/to/output/tars/dir-train",
+>>>     Split.val : "/path/to/output/tars/dir-val",
+>>>     Split.test : "/path/to/output/tars/dir-test",
+>>> }
+
+>>> # see the `WebDataModule` API doc for the definition of global_batch_size
+>>> global_batch_size = 16
+
+>>> # user can optionally customize the data processing routines and kwargs used
+>>> # in the WebDataset and WebLoader (see the examples in `WebDataModule`)
+
+>>> pipeline_wds = { Split.train: ... }
+
+>>> pipeline_prebatch_wld = { Split.train: ... }
+
+>>> kwargs_wds = { Split.train: ..., Split.val: ... }
+
+>>> kwargs_wld = { Split.train: ..., Split.val: ... }
+
+>>> # create the data module
+>>> data_module = PickledDataWDS(
+>>>     dir_pickles,
+>>>     names_subset,
+>>>     suffix_pickles,       # `WebDataModule` args
+>>>     output_dir_tar_files, # `WebDataModule` args
+>>>     global_batch_size,    # `WebDataModule` args
+>>>     n_tars_wds=n_tars_wds,
+>>>     prefix_tars_wds=prefix_tars_wds,             # `WebDataModule` kwargs
+>>>     pipeline_wds=pipeline_wds,                   # `WebDataModule` kwargs
+>>>     pipeline_prebatch_wld=pipeline_prebatch_wld, # `WebDataModule` kwargs
+>>>     kwargs_wds=kwargs_wds,                       # `WebDataModule` kwargs
+>>>     kwargs_wld=kwargs_wld,                       # `WebDataModule` kwargs
+>>> )
+
+```
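+
+Unlike the plain `WebDataModule`, `PickledDataWDS.prepare_data()` does real work:
+it converts the pickles into the webdataset tar shards before training starts.
+A minimal sketch of using the resulting data module (`MyLightningModule` is again
+a hypothetical `LightningModule`):
+
+```
+>>> import lightning as L
+>>>
+>>> # generate the webdataset tar shards from the pickles; in a Lightning run
+>>> # this is invoked automatically on the main process only
+>>> data_module.prepare_data()
+>>>
+>>> # then use the data module as usual, e.g., with a Trainer
+>>> trainer = L.Trainer(max_epochs=1, accelerator="gpu", devices=1)
+>>> trainer.fit(MyLightningModule(), datamodule=data_module)
+```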
+
+
+
+#### \_\_init\_\_
+
+```python
+def __init__(dir_pickles: str,
+             names_subset: Dict[Split, List[str]],
+             *args,
+             n_tars_wds: Optional[int] = None,
+             **kwargs)
+```
+
+constructor
+
+**Arguments**:
+
+- `dir_pickles` _str_ - input directory of pickled data files
+- `names_subset` _Dict[Split, List[str]]_ - list of filename prefixes of
+  the data samples to be loaded in the dataset and dataloader for
+  each of the splits
+- `*args` - arguments passed to the parent WebDataModule after its
+  `n_samples` argument (`n_samples` is deduced from the lengths of the
+  `names_subset` lists), i.e., the pickle filename suffix (also used as
+  the webdataset key of the pickled object), the output tar directories
+  and the global batch size
+
+  Kwargs:
+- `n_tars_wds` _int_ - attempt to create at least this number of
+  webdataset shards
+- `**kwargs` - arguments passed to the parent WebDataModule
+
+
+
+#### prepare\_data
+
+```python
+def prepare_data() -> None
+```
+
+This is called only by the main process in the Lightning workflow. Do
+not rely on updating this data module object's state here, as there is no
+way to communicate the state update to other subprocesses. The nested
+`pickles_to_tars` function goes through the data name prefixes in the
+different splits, reads the corresponding pickled files and outputs
+webdataset tar archives with the dict structure: {"__key__" :
+name.replace(".", "-"), suffix_pickles : pickle.dumps(data) }.
+
+Returns: None
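+
+The underlying conversion is done by `bionemo.webdatamodule.utils.pickles_to_tars`,
+which can also be called directly. A minimal sketch, assuming pickled tensors
+named `sample-0000.tensor.pyd`, `sample-0001.tensor.pyd`, ... in `dir_pickles`
+(the names and paths below are illustrative only):
+
+```
+>>> from bionemo.webdatamodule.utils import pickles_to_tars
+>>>
+>>> pickles_to_tars(
+>>>     dir_pickles,                       # input directory of pickle files
+>>>     ["sample-0000", "sample-0001"],    # filename prefixes to include
+>>>     "tensor.pyd",                      # filename suffix, also the wds key
+>>>     "/path/to/output/tars/dir-train",  # output directory
+>>>     "myshards",                        # output tar filename prefix
+>>>     min_num_shards=3,
+>>> )
+```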
diff --git a/sub-packages/bionemo-webdatamodule/pyproject.toml b/sub-packages/bionemo-webdatamodule/pyproject.toml
new file mode 100644
index 0000000000..30a44da5c5
--- /dev/null
+++ b/sub-packages/bionemo-webdatamodule/pyproject.toml
@@ -0,0 +1,32 @@
+[build-system]
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
+
+# For guidance, see: https://packaging.python.org/en/latest/guides/writing-pyproject-toml/
+[project]
+name = "bionemo-webdatamodule"
+version = "0.0.1"
+authors = [
+ { name = "Dejun Lin", email = "dejunl@nvidia.com" },
+]
+description = ""
+readme = "README.md"
+requires-python = ">=3.10"
+keywords = []
+license = {file = "LICENSE"}
+classifiers = [
+ "Programming Language :: Python :: 3.10",
+ "Private :: Do Not Upload",
+]
+dynamic = ["dependencies"]
+
+[project.optional-dependencies]
+test = [
+ "pytest",
+]
+
+[tool.setuptools.dynamic]
+dependencies = {file = ["requirements.txt"]}
+
+[tool.ruff]
+lint.ignore = ["C901", "E741", "E501", "E731"]
diff --git a/sub-packages/bionemo-webdatamodule/requirements.txt b/sub-packages/bionemo-webdatamodule/requirements.txt
new file mode 100644
index 0000000000..24ef528b0d
--- /dev/null
+++ b/sub-packages/bionemo-webdatamodule/requirements.txt
@@ -0,0 +1 @@
+webdataset==0.2.96
diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/__init__.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/__init__.py
new file mode 100644
index 0000000000..25e6abfbc5
--- /dev/null
+++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/__init__.py
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py
new file mode 100644
index 0000000000..33fa7936b9
--- /dev/null
+++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py
@@ -0,0 +1,491 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import glob
+from enum import Enum, auto
+from typing import Any, Dict, Iterable, List, Optional, Union, get_args
+
+import lightning as L
+import webdataset as wds
+
+from bionemo.webdatamodule.utils import pickles_to_tars
+
+
+class Split(Enum):
+ train = auto()
+ val = auto()
+ test = auto()
+
+
+class WebDataModule(L.LightningDataModule):
+ """A LightningDataModule for using webdataset tar files to setup dataset and
+ dataloader. This data module takes as input a dictionary: Split -> tar file
+ directory and vaiours webdataset config settings. In its setup() function,
+ it creates the webdataset object chaining up the input `pipeline_wds`
+ workflow. In its train/val/test_dataloader(), it creates the WebLoader
+ object chaining up the `pipeline_prebatch_wld` workflow
+
+ Examples
+ --------
+
+ 1. Create the data module with the input directories of webdataset tar files.
+ Depending on which of the downstream Lightning.Trainer methods are called,
+ e.g., `Trainer.fit()`, `Trainer.validate()`, `Trainer.test()` or
+ `Trainer.predict()`, only a subset of the train, val and test splits need to
+ be specified in the various input options to the data module:
+
+ - `Trainer.fit()` requires the `train` and `val` splits
+ - `Trainer.validate()` requires the `val` split
+ - `Trainer.test()` requires the `test` split
+ - `Trainer.predict()` requires the `test` split
+
+ Here is an example of constructing the data module for `Trainer.fit()`:
+ ```
+ >>> from bionemo.webdatamodule.datamodule import Split, WebDataModule
+ >>>
+ >>> tar_file_prefix = "shards"
+ >>>
+ >>> dirs_of_tar_files = {
+ >>> Split.train: "/path/to/train/split/tars",
+ >>> Split.val: "/path/to/val/split/tars",
+ >>> }
+ >>>
+ >>> n_samples = {
+ >>> Split.train: 1000,
+ >>> Split.val: 100,
+ >>> }
+ >>>
+ >>> # this is the string to retrieve the corresponding data object from the
+ >>> # webdataset file (see
+ >>> # https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format
+ >>> # for details)
+ >>> suffix_keys_wds = "tensor.pyd"
+ >>>
+ >>> # see the API doc for the definition of global_batch_size
+ >>> global_batch_size = 16
+ >>>
+ >>> seed_rng_shfl = 27193781
+ >>>
+ >>> # Specify the routines to process the samples in the WebDataset object.
+ >>> # Each routine is a webdataset composable (or an iterable thereof) that is
+ >>> # chained together by nested function calls. The following is equivalent to
+ >>> # defining an overall generator of `shuffle(untuple(...))`, which
+ >>> # untuples the samples and shuffles them. See webdataset's documentation
+ >>> # for details.
+ >>> # NOTE: the `untuple` is almost always necessary due to webdataset's
+ >>> # file-parsing rules.
+ >>>
+ >>> untuple = lambda source : (sample for (sample,) in source)
+ >>>
+ >>> import random
+ >>> import torch
+ >>> import webdataset as wds
+ >>> from webdataset.filters import shuffle, batched
+ >>> pipeline_wds = {
+ >>> Split.train : [untuple, shuffle(n_samples[Split.train],
+ >>> rng=random.Random(seed_rng_shfl))],
+ >>> Split.val: untuple
+ >>> }
+ >>>
+ >>> # Similarly the user can optionally define the processing routine on the
+ >>> # WebLoader (the dataloader of webdataset).
+ >>> # NOTE: these routines by default take unbatched samples as input so the
+ >>> # user can customize their batching routines here
+ >>>
+ >>> local_batch_size = global_batch_size  # assuming a single node
+ >>> batch = batched(local_batch_size, collation_fn=lambda
+ >>>     list_samples : torch.vstack(list_samples))
+ >>> pipeline_prebatch_wld = {
+ >>>     Split.train: [shuffle(n_samples[Split.train],
+ >>>                           rng=random.Random(seed_rng_shfl)), batch],
+ >>>     Split.val : batch,
+ >>>     Split.test : batch
+ >>> }
+ >>>
+ >>> # the user can optionally specify the kwargs for WebDataset and
+ >>> # WebLoader
+ >>>
+ >>> kwargs_wds = {
+ >>> split : {'shardshuffle' : split == Split.train,
+ >>> 'nodesplitter' : wds.split_by_node,
+ >>> 'seed' : seed_rng_shfl}
+ >>> for split in Split
+ >>> }
+ >>>
+ >>> kwargs_wld = {
+ >>> split : {"num_workers": 2} for split in Split
+ >>> }
+ >>>
+ >>> # construct the data module
+ >>> data_module = WebDataModule(n_samples, suffix_keys_wds,
+ dirs_of_tar_files, global_batch_size,
+ prefix_tars_wds=tar_file_prefix,
+ pipeline_wds=pipeline_wds,
+ pipeline_prebatch_wld=pipeline_prebatch_wld,
+ kwargs_wds=kwargs_wds,
+ kwargs_wld=kwargs_wld)
+ ```
+
+ """
+
+ def __init__(
+ self,
+ n_samples: Dict[Split, int],
+ suffix_keys_wds: Union[str, Iterable[str]],
+ dirs_tars_wds: Dict[Split, str],
+ global_batch_size: int,
+ prefix_tars_wds: str = "wdshards",
+ pipeline_wds: Optional[
+ Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]]
+ ] = None,
+ pipeline_prebatch_wld: Optional[
+ Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]]
+ ] = None,
+ kwargs_wds: Optional[Dict[Split, Dict[str, Any]]] = None,
+ kwargs_wld: Optional[Dict[Split, Dict[str, Any]]] = None,
+ ):
+ """constructor
+
+ Args:
+ n_samples (Dict[Split, int]): input dictionary: Split -> number of
+ data samples for each split
+ suffix_keys_wds (Union[str, Iterable[str]]): a set of keys each
+ corresponding to a data object in the webdataset tar file
+ dictionary. The data objects of these keys will be extracted and
+ tupled for each sample in the tar files
+ dirs_tars_wds (Dict[Split, str]): input dictionary: Split -> tar file
+ directory that contains the webdataset tar files for each split
+ global_batch_size (int): batch size summed across nodes in Data
+ Distributed Parallel, i.e., local_batch_size * n_nodes. NOTE:
+ this data module doesn't rely on the input `global_batch_size`
+ for batching the samples. The batching is supposed to be done as
+ part of the input `pipeline_prebatch_wld`. `global_batch_size`
+ is only used to compute a (pseudo-) epoch length for the data
+ loader so that the loader yields approximately n_samples //
+ global_batch_size batches
+ Kwargs:
+ prefix_tars_wds (str): name prefix of the input webdataset tar
+ files. The input tar files are globbed by
+ "{dirs_tars_wds[split]}/{prefix_tars_wds}-*.tar"
+ pipeline_wds (Optional[Dict[Split, Union[Iterable[Iterable[Any]],
+ Iterable[Any]]]]): a dictionary of webdataset composables, i.e.,
+ functors that map an iterator to another iterator transforming
+ the data samples yielded from the dataset object, for different
+ splits, or an iterable of such composables. For example, this
+ can be used to transform the sample in the worker before sending
+ it to the main process of the dataloader
+ pipeline_prebatch_wld (Optional[Dict[Split,
+ Union[Iterable[Iterable[Any]], Iterable[Any]]]]): a dictionary
+ of WebLoader composables, i.e., functors that map an iterator to
+ another iterator transforming the data samples yielded from the
+ WebLoader object, for different splits, or an iterable of such
+ composables. For example, this can be used for batching the
+ samples. NOTE: this is applied before the batches are yielded
+ from the WebLoader
+ kwargs_wds (Optional[Dict[Split, Dict[str, Any]]]): kwargs for the
+ WebDataset.__init__()
+ kwargs_wld (Optional[Dict[Split, Dict[str, Any]]]): kwargs for the
+ WebLoader.__init__(), e.g., num_workers, of each split
+
+
+ """
+ super().__init__()
+
+ self._dirs_tars_wds = dirs_tars_wds
+
+ keys_subset = self._dirs_tars_wds.keys()
+
+ if n_samples.keys() != keys_subset:
+ raise RuntimeError(
+ f"Input n_samples has different keys than "
+ f"dirs_tars_wds: {n_samples.keys()} vs "
+ f"{keys_subset}"
+ )
+
+ self._n_samples = n_samples
+
+ self._global_batch_size = global_batch_size
+
+ if not isinstance(suffix_keys_wds, get_args(Union[str, Iterable])):
+ raise TypeError("suffix_keys_wds can only be str or Iterable[str]")
+
+ self._suffix_keys_wds = suffix_keys_wds
+
+ self._prefix_tars_wds = prefix_tars_wds
+ self._pipeline_wds = pipeline_wds
+ self._pipeline_prebatch_wld = pipeline_prebatch_wld
+
+ self._kwargs_wld = kwargs_wld
+
+ self._kwargs_wds = kwargs_wds
+
+ # to be created later in setup
+ self._dataset = {}
+
+ def prepare_data(self) -> None:
+ """This is called only by the main process by the Lightning workflow. Do
+ not rely on this data module object's state update here as there is no
+ way to communicate the state update to other subprocesses.
+
+ Returns: None
+ """
+ pass
+
+ def _setup_wds(self, split: Split) -> wds.WebDataset:
+ """setup webdataset and webloader. This is called by setup()
+
+ Args:
+ split (Split): train, val or test split
+
+ Returns: WebDataset
+
+ """
+ if split not in self._dirs_tars_wds.keys():
+ raise RuntimeError(
+ f"_setup_wds() is called with {split} "
+ f"split that doesn't have the input tar dir"
+ )
+ urls = sorted(
+ glob.glob(f"{self._dirs_tars_wds[split]}/{self._prefix_tars_wds}-*.tar")
+ )
+ kwargs = self._kwargs_wds[split] if self._kwargs_wds is not None else None
+ dataset = wds.WebDataset(
+ urls, **(kwargs if kwargs is not None else {})
+ ).decode()
+ if isinstance(self._suffix_keys_wds, str):
+ dataset = dataset.extract_keys(f"*.{self._suffix_keys_wds}")
+ else:
+ dataset = dataset.extract_keys(
+ *[f"*.{key}" for key in self._suffix_keys_wds]
+ )
+
+ if self._pipeline_wds is not None and self._pipeline_wds[split] is not None:
+ if isinstance(self._pipeline_wds[split], Iterable):
+ dataset = dataset.compose(*self._pipeline_wds[split])
+ else:
+ dataset = dataset.compose(self._pipeline_wds[split])
+ return dataset
+
+ def setup(self, stage: str) -> None:
+ """This is called on all Lightning-managed nodes in a multi-node
+ training session
+
+
+ Args:
+ stage (str): "fit", "test" or "predict"
+ Returns: None
+ """
+ if stage == "fit":
+ self._dataset[Split.train] = self._setup_wds(Split.train)
+ self._dataset[Split.val] = self._setup_wds(Split.val)
+ elif stage == "validate":
+ self._dataset[Split.val] = self._setup_wds(Split.val)
+ elif stage == "test":
+ self._dataset[Split.test] = self._setup_wds(Split.test)
+ elif stage == "predict":
+ self._dataset[Split.test] = self._setup_wds(Split.test)
+ else:
+ raise NotImplementedError(
+ f"Data setup with stage = {stage} " f"is not implmented"
+ )
+
+ def _setup_dataloader(self, split: Split) -> wds.WebLoader:
+ """setup the dataloader for the input dataset split
+
+ Args:
+ split (Split): input split type
+
+ Returns: WebLoader object
+
+ """
+ if self._dataset[split] is None:
+ raise RuntimeError(
+ f"_setup_dataloader() is called with {split} "
+ f"split without setting up the corresp. dataset"
+ )
+ dataset = self._dataset[split]
+ n_samples = self._n_samples[split]
+ n_batches = (n_samples + self._global_batch_size - 1) // self._global_batch_size
+ kwargs = self._kwargs_wld[split] if self._kwargs_wld is not None else None
+ loader = wds.WebLoader(
+ dataset, batch_size=None, **(kwargs if kwargs is not None else {})
+ )
+
+ if (
+ self._pipeline_prebatch_wld is not None
+ and self._pipeline_prebatch_wld[split] is not None
+ ):
+ if isinstance(self._pipeline_prebatch_wld[split], Iterable):
+ loader = loader.compose(*self._pipeline_prebatch_wld[split])
+ else:
+ loader = loader.compose(self._pipeline_prebatch_wld[split])
+
+ loader = loader.with_epoch(n_batches)
+
+ return loader
+
+ def train_dataloader(self) -> wds.WebLoader:
+ return self._setup_dataloader(Split.train)
+
+ def val_dataloader(self) -> wds.WebLoader:
+ return self._setup_dataloader(Split.val)
+
+ def test_dataloader(self) -> wds.WebLoader:
+ return self._setup_dataloader(Split.test)
+
+ def predict_dataloader(self) -> wds.WebLoader:
+ return self._setup_dataloader(Split.test)
+
+
+class PickledDataWDS(WebDataModule):
+ """A LightningDataModule to process pickled data into webdataset tar files
+ and setup dataset and dataloader. This inherits the webdataset setup from
+ its parent module `WebDataModule`. This data module takes a directory of
+ pickled data files, data filename prefixes for train/val/test splits, data
+ filename suffixes and prepare webdataset tar files by globbing the specific
+ pickle data files `{dir_pickles}/{name_subset[split]}.{suffix_pickles}` and
+ outputing to webdataset tar file with the dict structure:
+ ```
+ {"__key__" : name.replace(".", "-"),
+ suffix_pickles : pickle.dumps(data) }
+ ```
+ NOTE: this assumes only one pickled file is processed for each sample. In
+ its setup() function, it creates the webdataset object chaining up the input
+ `pipeline_wds` workflow. In its train/val/test_dataloader(), it creates the
+ WebLoader object chaining up the `pipeline_prebatch_wld` workflow.
+
+ Examples
+ --------
+
+ 1. Create the data module with a directory of pickle files and the file name
+ prefixes thereof for the different splits, to be used by `Lightning.Trainer.fit()`
+
+ ```
+ >>> from bionemo.webdatamodule.datamodule import Split, PickledDataWDS
+
+ >>> dir_pickles = "/path/to/my/pickles/dir"
+
+ >>> # the following will use `sample1.mydata.pt` and `sample2.mydata.pt` as the
+ >>> # training dataset and `sample4.mydata.pt` and `sample5.mydata.pt` as the
+ >>> # validation dataset
+
+ >>> suffix_pickles = "mydata.pt"
+
+ >>> names_subset = {
+ >>>     Split.train: ["sample1", "sample2"],
+ >>>     Split.val: ["sample4", "sample5"],
+ >>> }
+
+ >>> # the following setting will attempt to create at least 5 tar files in
+ >>> # `/path/to/output/tars/dir/myshards-00000{0-5}.tar`
+
+ >>> n_tars_wds = 5
+ >>> prefix_tars_wds = "myshards"
+ >>> output_dir_tar_files = {
+ Split.train : "/path/to/output/tars/dir-train",
+ Split.val : "/path/to/output/tars/dir-val",
+ Split.test : "/path/to/output/tars/dir-test",
+ }
+
+ >>> # see the `WebDataModule` API doc for the definition of global_batch_size
+ >>> global_batch_size = 16
+
+ >>> # user can optionally customize the data processing routines and kwargs used
+ >>> # in the WebDataset and WebLoader (see the examples in `WebDataModule`)
+
+ >>> pipeline_wds = { Split.train: ... }
+
+ >>> pipeline_prebatch_wld = { Split.train: ... }
+
+ >>> kwargs_wds = { Split.train: ..., Split.val: ... }
+
+ >>> kwargs_wld = { Split.train: ..., Split.val: ... }
+
+ >>> # create the data module
+ >>> data_module = PickledDataWDS(
+ >>> dir_pickles,
+ >>> names_subset,
+ >>> suffix_pickles, # `WebDataModule` args
+ >>> output_dir_tar_files, # `WebDataModule` args
+ >>> global_batch_size, # `WebDataModule` args
+ >>> n_tars_wds=n_tars_wds,
+ >>> prefix_tars_wds=prefix_tars_wds, # `WebDataModule` kwargs
+ >>> pipeline_wds=pipeline_wds, # `WebDataModule` kwargs
+ >>> pipeline_prebatch_wld=pipeline_prebatch_wld, # `WebDataModule` kwargs
+ >>> kwargs_wds=kwargs_wds, # `WebDataModule` kwargs
+ >>> kwargs_wld=kwargs_wld, # `WebDataModule` kwargs
+ >>> )
+
+ ```
+
+ """
+
+ def __init__(
+ self,
+ dir_pickles: str,
+ names_subset: Dict[Split, List[str]],
+ *args,
+ n_tars_wds: Optional[int] = None,
+ **kwargs,
+ ):
+ """constructor
+
+ Args:
+ dir_pickles (str): input directory of pickled data files
+ names_subset (Dict[Split, List[str]]): list of filename prefixes of
+ the data samples to be loaded in the dataset and dataloader for
+ each of the splits
+ *args: arguments passed to the parent WebDataModule after its
+ `n_samples` argument (`n_samples` is deduced from the lengths of
+ the `names_subset` lists of this class)
+
+ Kwargs:
+ n_tars_wds (int): attempt to create at least this number of
+ webdataset shards
+ **kwargs: arguments passed to the parent WebDataModule
+
+
+ """
+ super().__init__(
+ {split: len(names_subset[split]) for split in names_subset.keys()},
+ *args,
+ **kwargs,
+ )
+
+ self._dir_pickles = dir_pickles
+
+ self._names_subset = names_subset
+
+ self._n_tars_wds = n_tars_wds
+
+ def prepare_data(self) -> None:
+ """This is called only by the main process by the Lightning workflow. Do
+ not rely on this data module object's state update here as there is no
+ way to communicate the state update to other subprocesses. The nesting
+ `pickles_to_tars` function goes through the data name prefixes in the
+ different splits, read the corresponding pickled file and output a
+ webdataset tar archive with the dict structure: {"__key__" :
+ name.replace(".", "-"), suffix_pickles : pickled.dumps(data) }.
+
+ Returns: None
+ """
+ for split in self._names_subset.keys():
+ # create wds shards (tar files) for each split
+ pickles_to_tars(
+ self._dir_pickles,
+ self._names_subset[split],
+ self._suffix_keys_wds,
+ self._dirs_tars_wds[split],
+ self._prefix_tars_wds,
+ min_num_shards=self._n_tars_wds,
+ )
diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py
new file mode 100644
index 0000000000..541957edd7
--- /dev/null
+++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py
@@ -0,0 +1,130 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import pickle
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union, get_args
+
+import webdataset as wds
+from nemo.utils import logging
+
+
+def pickles_to_tars(
+ dir_input: str,
+ input_prefix_subset: List[str],
+ input_suffix: Union[str, Iterable[str]],
+ dir_output: str,
+ output_prefix: str,
+ func_output_data: Callable[[str, Dict[str, Any]], Dict[str, Any]] = lambda prefix,
+ suffix_to_data: {"__key__": prefix, **suffix_to_data},
+ min_num_shards: Optional[int] = None,
+) -> None:
+ """Convert a subset of pickle files from a directory to Webdataset tar files
+ Input path and name pattern for sample 0:
+ f"{dir_input}/{input_prefix_subset[0]}.{input_suffix[0]}"
+ f"{dir_input}/{input_prefix_subset[0]}.{input_suffix[1]}"
+ Input path and name pattern for sample 1:
+ f"{dir_input}/{input_prefix_subset[1]}.{input_suffix[0]}"
+ f"{dir_input}/{input_prefix_subset[1]}.{input_suffix[1]}"
+ ...
+ Output path and name pattern:
+ f"{dir_output}/{output_prefix}-%06d.tar"
+
+ The webdataset tar archive is specified by the dictionary:
+ {
+ "__key__" : sample_filename_prefix,
+ sample_filename_suffix_1 : data_1,
+ sample_filename_suffix_2 : data_2,
+ ...
+ }
+ so that parsing the tar archive is equivalent to reading
+ {sample_filename_prefix}.{sample_filename_suffix_1} etc.
+
+ Here, each sample data gets its name prefix from one element of
+ `input_prefix_subset` and its name suffixes from the list `input_suffix`.
+ Per the webdataset file format specification, the `sample_filename_prefix`
+ can't contain dots '.', so this function removes them for the user by calling
+ .replace(".", "-") on the elements of `input_prefix_subset`
+
+ Args:
+ dir_input (str): Input directory
+ input_prefix_subset (List[str]): Input subset of pickle files' prefix
+ input_suffix (Union[str, Iterable[str]]): Input pickle file name
+ suffixes, each for one type of data object, for all the samples
+ dir_output (str): Output directory
+ output_prefix (str): Output tar file name prefix
+ func_output_data (Callable[[str, Dict[str, Any]], Dict[str, Any]]) :
+ function that maps the name prefix, name suffix and data object to a
+ webdataset tar archive dictionary. Refer to the webdataset github
+ repo for the archive file format specification.
+ min_num_shards (int) : create at least this number of tar files.
+ WebDataset has bugs when reading a small number of tar files in a
+ multi-node lightning + DDP setting, so this option can be used to
+ guarantee the tar file count
+
+ Returns: None
+
+ """
+ if not isinstance(input_suffix, get_args(Union[str, Iterable])):
+ raise TypeError("input_suffix can only be str or Iterable[str]")
+ os.makedirs(dir_output, exist_ok=True)
+ wd_subset_pattern = os.path.join(dir_output, f"{output_prefix}-%06d.tar")
+ n_samples_per_shard_max = 100000
+ if min_num_shards is not None:
+ if min_num_shards <= 0:
+ raise ValueError(f"Invalid min_num_shards = {min_num_shards} <= 0")
+ n_samples_per_shard_max = len(input_prefix_subset) // min_num_shards
+ with wds.ShardWriter(
+ wd_subset_pattern,
+ encoder=False,
+ maxcount=n_samples_per_shard_max,
+ compress=False,
+ mode=0o777,
+ ) as sink:
+ for name in input_prefix_subset:
+ try:
+ if isinstance(input_suffix, str):
+ suffix_to_data = {
+ input_suffix: pickle.dumps(
+ pickle.loads(
+ (
+ Path(dir_input) / f"{name}.{input_suffix}"
+ ).read_bytes()
+ )
+ )
+ }
+ else:
+ suffix_to_data = {
+ suffix: pickle.dumps(
+ pickle.loads(
+ (Path(dir_input) / f"{name}.{suffix}").read_bytes()
+ )
+ )
+ for suffix in input_suffix
+ }
+ # the prefix name shouldn't contain any "." per webdataset's
+ # specification
+ sample = func_output_data(name.replace(".", "-"), suffix_to_data)
+ sink.write(sample)
+ except ModuleNotFoundError as e:
+ logging.error(
+ f"Dependency for parsing input pickle data not found: {e}"
+ )
+ raise e
+ except Exception as e:
+ logging.error(f"Failed to write {name} into tar files due to error {e}")
+ raise e
diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py
new file mode 100644
index 0000000000..a43f4c0be3
--- /dev/null
+++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py
@@ -0,0 +1,301 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pickle
+import random
+
+import lightning as L
+import pytest
+import torch
+import webdataset as wds
+from webdataset.filters import batched, shuffle
+
+from bionemo.webdatamodule.datamodule import PickledDataWDS, Split, WebDataModule
+from bionemo.webdatamodule.utils import pickles_to_tars
+
+
+@pytest.fixture(scope="module")
+def gen_pickle_files(tmp_path_factory):
+ dir_pickles = tmp_path_factory.mktemp("pickleddatawds").as_posix()
+ prefix_sample = "sample"
+ suffix_sample = ["tensor.pyd", "tensor_copy.pyd"]
+ n_samples_per_split = 10
+ prefixes = []
+ # generate the pickles for train, val, and test
+ for i in range(n_samples_per_split * 3):
+ prefix = f"{prefix_sample}-{i:04}"
+ prefixes.append(prefix)
+ t = torch.tensor(i, dtype=torch.int32)
+ for suffix in suffix_sample:
+ with open(f"{dir_pickles}/{prefix}.{suffix}", "wb") as fh:
+ pickle.dump(t, fh)
+ prefixes_pickle = {
+ Split.train: prefixes[0:n_samples_per_split],
+ Split.val: prefixes[n_samples_per_split : n_samples_per_split * 2],
+ Split.test: prefixes[n_samples_per_split * 2 : n_samples_per_split * 3],
+ }
+ return (
+ dir_pickles,
+ prefix_sample,
+ suffix_sample,
+ prefixes_pickle,
+ n_samples_per_split,
+ )
+
+
+@pytest.fixture(scope="module", params=[1, 2])
+def gen_test_data(tmp_path_factory, gen_pickle_files, request):
+ dir_pickles, prefix_sample, suffixes, prefixes_pickle, n_samples_per_split = (
+ gen_pickle_files
+ )
+ n_suffixes = request.param
+ if n_suffixes <= 1:
+ suffix_sample = suffixes[0]
+ else:
+ suffix_sample = suffixes[0:n_suffixes]
+ dir_tars_tmp = tmp_path_factory.mktemp("webdatamodule").as_posix()
+ dir_tars = {split: f"{dir_tars_tmp}{str(split).split('.')[-1]}" for split in Split}
+ prefix_tar = "tensor"
+ n_samples = {split: n_samples_per_split for split in Split}
+ # generate the tars
+ pickles_to_tars(
+ dir_pickles,
+ prefixes_pickle[Split.train],
+ suffix_sample,
+ dir_tars[Split.train],
+ prefix_tar,
+ min_num_shards=3,
+ )
+ pickles_to_tars(
+ dir_pickles,
+ prefixes_pickle[Split.val],
+ suffix_sample,
+ dir_tars[Split.val],
+ prefix_tar,
+ min_num_shards=3,
+ )
+ pickles_to_tars(
+ dir_pickles,
+ prefixes_pickle[Split.test],
+ suffix_sample,
+ dir_tars[Split.test],
+ prefix_tar,
+ min_num_shards=3,
+ )
+ return (
+ dir_pickles,
+ dir_tars,
+ prefix_sample,
+ suffix_sample,
+ prefix_tar,
+ n_samples,
+ prefixes_pickle,
+ )
+
+
+def _create_webdatamodule(gen_test_data, num_workers=2):
+ (_, dirs_tars_wds, _, suffix_keys_wds, prefix_tars_wds, n_samples, _) = (
+ gen_test_data
+ )
+ local_batch_size = 2
+ global_batch_size = 2
+ seed_rng_shfl = 82838392
+
+ batch = batched(
+ local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples)
+ )
+
+ if isinstance(suffix_keys_wds, str):
+ untuple = lambda source: (sample[0] for sample in source)
+ elif isinstance(suffix_keys_wds, list):
+ untuple = lambda source: (torch.vstack(sample) for sample in source)
+
+ pipeline_wds = {
+ Split.train: [
+ untuple,
+ shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)),
+ ],
+ Split.val: untuple,
+ Split.test: untuple,
+ }
+
+ pipeline_prebatch_wld = {
+ Split.train: [
+ shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)),
+ batch,
+ ],
+ Split.val: batch,
+ Split.test: batch,
+ }
+
+ kwargs_wds = {
+ split: {
+ "shardshuffle": split == Split.train,
+ "nodesplitter": wds.split_by_node,
+ "seed": seed_rng_shfl,
+ }
+ for split in Split
+ }
+
+ kwargs_wld = {split: {"num_workers": num_workers} for split in Split}
+
+ data_module = WebDataModule(
+ n_samples,
+ suffix_keys_wds,
+ dirs_tars_wds,
+ global_batch_size,
+ prefix_tars_wds=prefix_tars_wds,
+ pipeline_wds=pipeline_wds,
+ pipeline_prebatch_wld=pipeline_prebatch_wld,
+ kwargs_wds=kwargs_wds,
+ kwargs_wld=kwargs_wld,
+ )
+
+ return data_module, dirs_tars_wds
+
+
+@pytest.fixture(scope="module")
+def create_webdatamodule(gen_test_data):
+ return _create_webdatamodule(gen_test_data)
+
+
+@pytest.fixture(scope="module")
+def create_another_webdatamodule(gen_test_data):
+ return _create_webdatamodule(gen_test_data)
+
+
+@pytest.fixture(scope="module")
+def create_webdatamodule_with_5_workers(gen_test_data):
+ return _create_webdatamodule(gen_test_data, num_workers=5)
+
+
+class ModelTestWebDataModule(L.LightningModule):
+ def __init__(self) -> None:
+ super().__init__()
+ self._model = torch.nn.Linear(1, 1)
+ self._samples = {split: [] for split in Split}
+
+ def forward(self, x):
+ return self._model(x.float())
+
+ def training_step(self, batch):
+ self._samples[Split.train].append(batch)
+ loss = self(batch).sum()
+ return loss
+
+ def validation_step(self, batch, batch_index):
+ self._samples[Split.val].append(batch)
+ return torch.zeros(1)
+
+ def test_step(self, batch, batch_index):
+ self._samples[Split.test].append(batch)
+
+ def predict_step(self, batch, batch_index):
+ self._samples[Split.test].append(batch)
+ return torch.zeros(1)
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.Adam(self.parameters(), lr=2e-4)
+ return optimizer
+
+
+@pytest.fixture(scope="function")
+def create_trainer_and_model():
+ trainer = L.Trainer(
+ max_epochs=1, accelerator="gpu", devices=1, val_check_interval=1
+ )
+ model = ModelTestWebDataModule()
+ return trainer, model
+
+
+def _create_pickleddatawds(tmp_path_factory, gen_test_data):
+ (
+ dir_pickles,
+ _,
+ _,
+ suffix_keys_wds,
+ prefix_tars_wds,
+ n_samples,
+ names,
+ ) = gen_test_data
+ local_batch_size = 2
+ global_batch_size = 2
+ seed_rng_shfl = 82838392
+ n_tars_wds = 3
+
+ prefix_dir_tars_wds = tmp_path_factory.mktemp("pickleddatawds_tars_wds").as_posix()
+ dirs_tars_wds = {s: f"{prefix_dir_tars_wds}{str(s).split('.')[-1]}" for s in Split}
+
+ batch = batched(
+ local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples)
+ )
+
+ untuple = lambda source: (sample[0] for sample in source)
+
+ pipeline_wds = {
+ Split.train: [
+ untuple,
+ shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)),
+ ],
+ Split.val: untuple,
+ Split.test: untuple,
+ }
+
+ pipeline_prebatch_wld = {
+ Split.train: [
+ shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)),
+ batch,
+ ],
+ Split.val: batch,
+ Split.test: batch,
+ }
+
+ kwargs_wds = {
+ split: {
+ "shardshuffle": split == Split.train,
+ "nodesplitter": wds.split_by_node,
+ "seed": seed_rng_shfl,
+ }
+ for split in Split
+ }
+
+ kwargs_wld = {split: {"num_workers": 2} for split in Split}
+
+ data_module = PickledDataWDS(
+ dir_pickles,
+ names,
+ suffix_keys_wds,
+ dirs_tars_wds,
+ global_batch_size,
+ n_tars_wds=n_tars_wds,
+ prefix_tars_wds=prefix_tars_wds,
+ pipeline_wds=pipeline_wds,
+ pipeline_prebatch_wld=pipeline_prebatch_wld,
+ kwargs_wds=kwargs_wds,
+ kwargs_wld=kwargs_wld,
+ )
+
+ return data_module, dirs_tars_wds, n_tars_wds
+
+
+@pytest.fixture(scope="module")
+def create_pickleddatawds(tmp_path_factory, gen_test_data):
+ return _create_pickleddatawds(tmp_path_factory, gen_test_data)
+
+
+@pytest.fixture(scope="module")
+def create_another_pickleddatawds(tmp_path_factory, gen_test_data):
+ return _create_pickleddatawds(tmp_path_factory, gen_test_data)
diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py
new file mode 100644
index 0000000000..692905a416
--- /dev/null
+++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py
@@ -0,0 +1,268 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import glob
+from enum import Enum, auto
+
+import lightning as L
+import pytest
+import torch
+
+from bionemo.webdatamodule.datamodule import Split
+
+
+@pytest.mark.parametrize("split", list(Split))
+def test_webdatamodule_init(split, create_webdatamodule):
+ data_module, dirs_tars_wds = create_webdatamodule
+ assert data_module._n_samples[split] == 10, (
+ f"Wrong {split}-set size: "
+ f"expected 10 "
+ f"but got {data_module._n_samples[split]}"
+ )
+ assert data_module._dirs_tars_wds[split] == f"{dirs_tars_wds[split]}", (
+ f"Wrong tar files directory: "
+ f"expected {dirs_tars_wds[split]} "
+ f"but got {data_module._dirs_tars_wds[split]}"
+ )
+
+
+@pytest.mark.parametrize("split", list(Split))
+def test_webdatamodule_setup_dataset(
+ split, create_webdatamodule, create_another_webdatamodule
+):
+ data_modules = [create_webdatamodule[0], create_another_webdatamodule[0]]
+ lists_tensors = []
+ for m in data_modules:
+ m.prepare_data()
+ # run through all the possible stages first to set up all the
+ # corresponding dataset objects
+ m.setup("fit")
+ m.setup("test")
+ L.seed_everything(2823828)
+ tensors = []
+ for sample in m._dataset[split]:
+ assert isinstance(
+ sample, torch.Tensor
+ ), "Sample yield from dataset is not tensor"
+ tensors.append(sample)
+ lists_tensors.append(tensors)
+
+ assert len(lists_tensors[0]) > 0, f"No samples yielded from the {split} dataset"
+ torch.testing.assert_close(
+ torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1])
+ )
+
+
+@pytest.mark.parametrize("split", list(Split))
+def test_webdatamodule_setup_dataloader(
+ split, create_webdatamodule, create_another_webdatamodule
+):
+ data_modules = [create_webdatamodule[0], create_another_webdatamodule[0]]
+ lists_tensors = []
+ for m in data_modules:
+ m.prepare_data()
+ # run through all the possible stages first to set up all the
+ # corresponding dataset objects
+ m.setup("fit")
+ m.setup("test")
+ L.seed_everything(2823828)
+ tensors = []
+ loader = None
+ if split == Split.train:
+ loader = m.train_dataloader()
+ elif split == Split.val:
+ loader = m.val_dataloader()
+ elif split == Split.test:
+ loader = m.test_dataloader()
+ else:
+ raise RuntimeError(f"Test for split {split} not implemented")
+ assert loader is not None, "dataloader not instantiated"
+ for samples in loader:
+ # the dataloader should yield collated torch.Tensor batches
+ assert isinstance(
+ samples, torch.Tensor
+ ), "Batch object yielded from dataloader is not a torch.Tensor"
+ tensors.append(samples)
+ lists_tensors.append(tensors)
+
+ assert len(lists_tensors[0]) > 0, f"No samples yielded from the {split} dataloader"
+ torch.testing.assert_close(
+ torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1])
+ )
+
+
+@pytest.mark.parametrize("split", list(Split))
+def test_webdatamodule_throw_on_many_workers(
+ split, create_webdatamodule_with_5_workers
+):
+ data_module = create_webdatamodule_with_5_workers[0]
+ urls = glob.glob(
+ f"{data_module._dirs_tars_wds[split]}/" f"{data_module._prefix_tars_wds}-*.tar"
+ )
+ n_tars = len(urls)
+ data_module._kwargs_wld[split]["num_workers"] = n_tars + 1
+ data_module.prepare_data()
+ data_module.setup("fit")
+ data_module.setup("test")
+ loader = None
+ if split == Split.train:
+ loader = data_module.train_dataloader()
+ elif split == Split.val:
+ loader = data_module.val_dataloader()
+ elif split == Split.test:
+ loader = data_module.test_dataloader()
+ else:
+ raise RuntimeError(f"Test for split {split} not implemented")
+ assert loader is not None, "dataloader not instantiated"
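+ # iterating the loader should raise ValueError because num_workers now
+ # exceeds the number of available shards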
+ try:
+ for _ in loader:
+ pass
+ except ValueError as e:
+ # this is expected
+ assert "have fewer shards than workers" in str(e), (
+ f"'have fewer shards than workers' not found in exception "
+ f"raised from data loading: {e}"
+ )
+ except Exception as e:
+ raise RuntimeError(
+ f"WebLoader didn't raise ValueError when there are fewer "
+ f"shards than workers but raised this instead: {e}"
+ )
+ else:
+ raise NotImplementedError(
+ "WebLoader doesn't raise an error with num_workers > num_shards. "
+ "This should be reported to webdataset; as a workaround, use no "
+ "more workers than shards in practice"
+ )
+
+
+class Stage(Enum):
+ fit = auto()
+ validate = auto()
+ test = auto()
+ predict = auto()
+
+
+@pytest.mark.parametrize("stage", list(Stage))
+def test_webdatamodule_in_lightning(
+ stage, create_webdatamodule, create_another_webdatamodule, create_trainer_and_model
+):
+ data_modules = [create_webdatamodule[0], create_another_webdatamodule[0]]
+ trainer, model = create_trainer_and_model
+ # get the list of samples from the loader
+ L.seed_everything(2823828)
+ data_modules[0].prepare_data()
+ split = None
+ if stage == Stage.fit:
+ split = Split.train
+ elif stage == Stage.validate:
+ split = Split.val
+ elif stage == Stage.test or stage == Stage.predict:
+ split = Split.test
+ else:
+ raise RuntimeError(f"{stage} stage not implemented")
+ name_stage = str(stage).split(".")[-1]
+ data_modules[0].setup(name_stage)
+ # get the list of samples from the workflow
+ get_dataloader = getattr(data_modules[0], f"{str(split).split('.')[-1]}_dataloader")
+ loader = get_dataloader()
+ L.seed_everything(2823828)
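+ # run the Lightning workflow on the second, identically-built datamodule so
+ # the batches it collects can be compared against the first module's loader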
+ workflow = getattr(trainer, name_stage)
+ workflow(model, data_modules[1])
+ device = model._samples[split][0].device
+ samples = [sample.to(device=device) for sample in loader]
+ torch.testing.assert_close(
+ torch.stack(model._samples[split], dim=0), torch.stack(samples, dim=0)
+ )
+
+
+@pytest.mark.parametrize("split", list(Split))
+def test_pickleddatawds_init(split, create_pickleddatawds):
+ data_module, dirs_tars_wds, _ = create_pickleddatawds
+ assert data_module._n_samples[split] == 10, (
+ f"Wrong {split}-set size: "
+ f"expected 10 "
+ f"but got {data_module._n_samples[split]}"
+ )
+ assert data_module._dirs_tars_wds[split] == dirs_tars_wds[split], (
+ f"Wrong tar files directory: "
+ f"expected {dirs_tars_wds[split]} "
+ f"but got {data_module._dirs_tars_wds[split]}"
+ )
+
+
+@pytest.mark.parametrize("split", list(Split))
+def test_pickleddatawds_prepare_data(split, create_pickleddatawds):
+ data_module, _, n_tars_min = create_pickleddatawds
+ data_module.prepare_data()
+ dir_tars = f"{data_module._dirs_tars_wds[split]}"
+ tars = glob.glob(f"{dir_tars}/{data_module._prefix_tars_wds}-*.tar")
+ n_tars = len(tars)
+ assert n_tars_min <= n_tars and n_tars <= n_tars_min + 1, (
+ f"Number of tar files: {n_tars} in {dir_tars} is outside the range "
+ f"[{n_tars_min}, {n_tars_min + 1}]"
+ )
+
+
+@pytest.mark.parametrize("split", list(Split))
+def test_pickleddatawds_setup_dataset(
+ split, create_pickleddatawds, create_another_pickleddatawds
+):
+ data_modules = [create_pickleddatawds[0], create_another_pickleddatawds[0]]
+ lists_tensors = []
+ for m in data_modules:
+ m.prepare_data()
+ # run through all the possible stages first to set up all the
+ # corresponding dataset objects
+ m.setup("fit")
+ m.setup("test")
+ L.seed_everything(2823828)
+ tensors = []
+ for sample in m._dataset[split]:
+ assert isinstance(
+ sample, torch.Tensor
+ ), "Sample yield from dataset is not tensor"
+ tensors.append(sample)
+ lists_tensors.append(tensors)
+
+ assert len(lists_tensors[0]) > 0, f"No samples yielded from the {split} dataset"
+ torch.testing.assert_close(
+ torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1])
+ )
+
+
+def test_pickleddatawds_sample_overlap(create_pickleddatawds):
+ data_module = create_pickleddatawds[0]
+ # this writes the tar files to disk
+ data_module.prepare_data()
+ # read the data back by setting up the dataset object and loop over it
+ data_module.setup("fit")
+ data_module.setup("test")
+ results = {
+ split: {sample.item() for sample in data_module._dataset[split]}
+ for split in Split
+ }
+ overlap_train_val = results[Split.train] & results[Split.val]
+ overlap_train_test = results[Split.train] & results[Split.test]
+ overlap_val_test = results[Split.val] & results[Split.test]
+ assert (
+ len(overlap_train_val) == 0
+ ), "Shared samples found between train and val datasets"
+ assert (
+ len(overlap_train_test) == 0
+ ), "Shared samples found between train and test datasets"
+ assert (
+ len(overlap_val_test) == 0
+ ), "Shared samples found between val and test datasets"