diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py index 5551d6efe4..a43f4c0be3 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py @@ -31,7 +31,7 @@ def gen_pickle_files(tmp_path_factory): dir_pickles = tmp_path_factory.mktemp("pickleddatawds").as_posix() prefix_sample = "sample" - suffix_sample = "tensor.pyd" + suffix_sample = ["tensor.pyd", "tensor_copy.pyd"] n_samples_per_split = 10 prefixes = [] # generate the pickles for train, val, and test @@ -39,8 +39,9 @@ def gen_pickle_files(tmp_path_factory): prefix = f"{prefix_sample}-{i:04}" prefixes.append(prefix) t = torch.tensor(i, dtype=torch.int32) - with open(f"{dir_pickles}/{prefix}.{suffix_sample}", "wb") as fh: - pickle.dump(t, fh) + for suffix in suffix_sample: + with open(f"{dir_pickles}/{prefix}.{suffix}", "wb") as fh: + pickle.dump(t, fh) prefixes_pickle = { Split.train: prefixes[0:n_samples_per_split], Split.val: prefixes[n_samples_per_split : n_samples_per_split * 2], @@ -55,11 +56,16 @@ def gen_pickle_files(tmp_path_factory): ) -@pytest.fixture(scope="module") -def gen_test_data(tmp_path_factory, gen_pickle_files): - dir_pickles, prefix_sample, suffix_sample, prefixes_pickle, n_samples_per_split = ( +@pytest.fixture(scope="module", params=[1, 2]) +def gen_test_data(tmp_path_factory, gen_pickle_files, request): + dir_pickles, prefix_sample, suffixes, prefixes_pickle, n_samples_per_split = ( gen_pickle_files ) + n_suffixes = request.param + if n_suffixes <= 1: + suffix_sample = suffixes[0] + else: + suffix_sample = suffixes[0:n_suffixes] dir_tars_tmp = tmp_path_factory.mktemp("webdatamodule").as_posix() dir_tars = {split: f"{dir_tars_tmp}{str(split).split('.')[-1]}" for split in Split} prefix_tar = "tensor" @@ -112,7 +118,10 @@ def _create_webdatamodule(gen_test_data, num_workers=2): local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples) ) - untuple = lambda source: (sample for (sample,) in source) + if isinstance(suffix_keys_wds, str): + untuple = lambda source: (sample[0] for sample in source) + elif isinstance(suffix_keys_wds, list): + untuple = lambda source: (torch.vstack(sample) for sample in source) pipeline_wds = { Split.train: [ @@ -183,19 +192,19 @@ def forward(self, x): return self._model(x.float()) def training_step(self, batch): - self._samples[Split.train].append(batch.name) + self._samples[Split.train].append(batch) loss = self(batch).sum() return loss def validation_step(self, batch, batch_index): - self._samples[Split.val].append(batch.name) + self._samples[Split.val].append(batch) return torch.zeros(1) def test_step(self, batch, batch_index): - self._samples[Split.test].append(batch.name) + self._samples[Split.test].append(batch) def predict_step(self, batch, batch_index): - self._samples[Split.test].append(batch.name) + self._samples[Split.test].append(batch) return torch.zeros(1) def configure_optimizers(self): @@ -234,7 +243,7 @@ def _create_pickleddatawds(tmp_path_factory, gen_test_data): local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples) ) - untuple = lambda source: (sample for (sample,) in source) + untuple = lambda source: (sample[0] for sample in source) pipeline_wds = { Split.train: [ diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py index fd5831585b..692905a416 100644 --- a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py @@ -178,11 +178,14 @@ def test_webdatamodule_in_lightning( # get the list of samples from the workflow get_dataloader = getattr(data_modules[0], f"{str(split).split('.')[-1]}_dataloader") loader = get_dataloader() - samples = [sample.name for sample in loader] L.seed_everything(2823828) workflow = getattr(trainer, name_stage) workflow(model, data_modules[1]) - assert model._samples[split] == samples + device = model._samples[split][0].device + samples = [sample.to(device=device) for sample in loader] + torch.testing.assert_close( + torch.stack(model._samples[split], dim=0), torch.stack(samples, dim=0) + ) @pytest.mark.parametrize("split", list(Split))