implement aggregate_dim
nnaakkaaii committed Jun 24, 2024
1 parent 2e97688 commit 5450b90
Showing 2 changed files with 29 additions and 1 deletion.
28 changes: 28 additions & 0 deletions hrdae/dataloaders/basic.py
@@ -1,6 +1,8 @@
from dataclasses import dataclass
from typing import Callable

from omegaconf import MISSING
from torch import cat, Tensor
from torch.utils.data import DataLoader, random_split
from torchvision import transforms

@@ -16,6 +18,26 @@ class BasicDataLoaderOption(DataLoaderOption):
    transform_order_train: list[str] = MISSING
    transform_order_val: list[str] = MISSING

    aggregate_dim: int = -1


def get_aggregating_collate_fn(
    aggregate_dim: int,
) -> Callable[[list[dict[str, Tensor]]], dict[str, Tensor]]:
    def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Tensor]:
        old: dict[str, list[Tensor]] = {}
        for item in batch:
            for k_item, v_item in item.items():
                if k_item not in old:
                    old[k_item] = []
                old[k_item].append(v_item)
        new: dict[str, Tensor] = {}
        for k_old, v_old in old.items():
            new[k_old] = cat(v_old, dim=aggregate_dim)
        return new

    return collate_fn


def create_basic_dataloader(
    opt: BasicDataLoaderOption,
@@ -26,6 +48,9 @@ def create_basic_dataloader(
        [create_transform(opt.transform[name]) for name in transform_order]
    )
    dataset = create_dataset(opt.dataset, transform, is_train)
    collate_fn = None
    if opt.aggregate_dim >= 0:
        collate_fn = get_aggregating_collate_fn(opt.aggregate_dim)

    if is_train:
        train_size = int(opt.train_val_ratio * len(dataset))  # type: ignore
@@ -38,11 +63,13 @@ def create_basic_dataloader(
            train_dataset,
            batch_size=opt.batch_size,
            shuffle=is_train,
            collate_fn=collate_fn,
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=opt.batch_size,
            shuffle=is_train,
            collate_fn=collate_fn,
        )
        return train_loader, val_loader

@@ -51,6 +78,7 @@ def create_basic_dataloader(
            dataset,
            batch_size=opt.batch_size,
            shuffle=is_train,
            collate_fn=collate_fn,
        ),
        None,
    )
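
For context (not part of the commit), here is a minimal sketch of how the new collate_fn behaves. Only get_aggregating_collate_fn comes from basic.py above; the dict keys and tensor shapes are illustrative assumptions:

from torch import randn

from hrdae.dataloaders.basic import get_aggregating_collate_fn

# Two hypothetical dataset items, each a dict of tensors with shape (frames, channels).
batch = [
    {"xp": randn(10, 3), "xt": randn(10, 3)},
    {"xp": randn(10, 3), "xt": randn(10, 3)},
]

collate_fn = get_aggregating_collate_fn(aggregate_dim=0)
out = collate_fn(batch)

# Values sharing a key are concatenated along aggregate_dim instead of being
# stacked into a new batch dimension, so out["xp"].shape == (20, 3) rather
# than the (2, 10, 3) that the default collation would produce.

In create_basic_dataloader this path is only taken when aggregate_dim >= 0; with the default of -1, collate_fn stays None and the DataLoader falls back to its default collation.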
2 changes: 1 addition & 1 deletion hrdae/dataloaders/datasets/seq_serialize_wrapper.py
@@ -12,7 +12,7 @@ def __init__(
        self.base = base

    def __len__(self) -> int:
-       return len(self.base)
+       return len(self.base)  # type: ignore

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        x = self.base[index]["xp"]
