Skip to content

Commit

Permalink
Test: multiple data objects per sample
Browse files Browse the repository at this point in the history
  • Loading branch information
DejunL committed Aug 29, 2024
1 parent 0d5c144 commit a35765f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,17 @@
def gen_pickle_files(tmp_path_factory):
dir_pickles = tmp_path_factory.mktemp("pickleddatawds").as_posix()
prefix_sample = "sample"
suffix_sample = "tensor.pyd"
suffix_sample = ["tensor.pyd", "tensor_copy.pyd"]
n_samples_per_split = 10
prefixes = []
# generate the pickles for train, val, and test
for i in range(n_samples_per_split * 3):
prefix = f"{prefix_sample}-{i:04}"
prefixes.append(prefix)
t = torch.tensor(i, dtype=torch.int32)
with open(f"{dir_pickles}/{prefix}.{suffix_sample}", "wb") as fh:
pickle.dump(t, fh)
for suffix in suffix_sample:
with open(f"{dir_pickles}/{prefix}.{suffix}", "wb") as fh:
pickle.dump(t, fh)
prefixes_pickle = {
Split.train: prefixes[0:n_samples_per_split],
Split.val: prefixes[n_samples_per_split : n_samples_per_split * 2],
Expand All @@ -55,11 +56,16 @@ def gen_pickle_files(tmp_path_factory):
)


@pytest.fixture(scope="module")
def gen_test_data(tmp_path_factory, gen_pickle_files):
dir_pickles, prefix_sample, suffix_sample, prefixes_pickle, n_samples_per_split = (
@pytest.fixture(scope="module", params=[1, 2])
def gen_test_data(tmp_path_factory, gen_pickle_files, request):
dir_pickles, prefix_sample, suffixes, prefixes_pickle, n_samples_per_split = (
gen_pickle_files
)
n_suffixes = request.param
if n_suffixes <= 1:
suffix_sample = suffixes[0]
else:
suffix_sample = suffixes[0:n_suffixes]
dir_tars_tmp = tmp_path_factory.mktemp("webdatamodule").as_posix()
dir_tars = {split: f"{dir_tars_tmp}{str(split).split('.')[-1]}" for split in Split}
prefix_tar = "tensor"
Expand Down Expand Up @@ -112,7 +118,10 @@ def _create_webdatamodule(gen_test_data, num_workers=2):
local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples)
)

untuple = lambda source: (sample for (sample,) in source)
if isinstance(suffix_keys_wds, str):
untuple = lambda source: (sample[0] for sample in source)
elif isinstance(suffix_keys_wds, list):
untuple = lambda source: (torch.vstack(sample) for sample in source)

pipeline_wds = {
Split.train: [
Expand Down Expand Up @@ -183,19 +192,19 @@ def forward(self, x):
return self._model(x.float())

def training_step(self, batch):
    """One optimization step: record the incoming batch and return a scalar loss.

    The raw batch tensor (not a derived attribute) is appended to
    ``self._samples[Split.train]`` so the test harness can later compare the
    exact sample content seen during the workflow against the dataloader.
    """
    # Diff artifact removed: the pre-commit line appended ``batch.name``;
    # only the post-commit append of the batch itself is kept, otherwise
    # every batch would be recorded twice.
    self._samples[Split.train].append(batch)
    # Sum of the forward output serves as a trivial loss for the smoke test.
    loss = self(batch).sum()
    return loss

def validation_step(self, batch, batch_index):
    """Record the validation batch and return a dummy zero metric.

    Appends the raw batch tensor to ``self._samples[Split.val]`` for later
    content comparison by the tests; the returned value is a placeholder.
    """
    # Diff artifact removed: only the post-commit append of the batch itself
    # is kept (the stale ``batch.name`` append would double-record).
    self._samples[Split.val].append(batch)
    return torch.zeros(1)

def test_step(self, batch, batch_index):
    """Record the test batch for later comparison; no metric is returned.

    Appends the raw batch tensor to ``self._samples[Split.test]``.
    """
    # Diff artifact removed: only the post-commit append of the batch itself
    # is kept (the stale ``batch.name`` append would double-record).
    self._samples[Split.test].append(batch)

def predict_step(self, batch, batch_index):
    """Record the predicted-over batch and return a dummy zero tensor.

    NOTE(review): predictions are accumulated under ``Split.test`` (there is
    no separate predict bucket in ``self._samples``) — this mirrors the
    original behavior; confirm it is intentional.
    """
    # Diff artifact removed: only the post-commit append of the batch itself
    # is kept (the stale ``batch.name`` append would double-record).
    self._samples[Split.test].append(batch)
    return torch.zeros(1)

def configure_optimizers(self):
Expand Down Expand Up @@ -234,7 +243,7 @@ def _create_pickleddatawds(tmp_path_factory, gen_test_data):
local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples)
)

untuple = lambda source: (sample for (sample,) in source)
untuple = lambda source: (sample[0] for sample in source)

pipeline_wds = {
Split.train: [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,14 @@ def test_webdatamodule_in_lightning(
# get the list of samples from the workflow
get_dataloader = getattr(data_modules[0], f"{str(split).split('.')[-1]}_dataloader")
loader = get_dataloader()
samples = [sample.name for sample in loader]
L.seed_everything(2823828)
workflow = getattr(trainer, name_stage)
workflow(model, data_modules[1])
assert model._samples[split] == samples
device = model._samples[split][0].device
samples = [sample.to(device=device) for sample in loader]
torch.testing.assert_close(
torch.stack(model._samples[split], dim=0), torch.stack(samples, dim=0)
)


@pytest.mark.parametrize("split", list(Split))
Expand Down

0 comments on commit a35765f

Please sign in to comment.