Skip to content

Commit

Permalink
impl tuning/ct
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 16, 2024
1 parent 75079f8 commit bb219d7
Show file tree
Hide file tree
Showing 4 changed files with 367 additions and 4 deletions.
2 changes: 1 addition & 1 deletion hrdae/conf/experiment/dataloader/basic.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
batch_size: 64
batch_size: 2
train_val_ratio: 0.8
# # mnist
# transform_order_train:
Expand Down
9 changes: 8 additions & 1 deletion hrdae/dataloaders/datasets/ct.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
@dataclass
class CTDatasetOption(DatasetOption):
root: Path = MISSING
slice_index: list[int] = MISSING
threshold: float = 0.1
min_occupancy: float = 0.2
in_memory: bool = False
Expand Down Expand Up @@ -44,9 +45,15 @@ def __call__(self, x: Tensor) -> Tensor:
def create_ct_dataset(
opt: CTDatasetOption, transform: Transform, is_train: bool
) -> Dataset:
slice_indexer: Callable[[Tensor], Tensor]
if len(opt.slice_index) == 0:
slice_indexer = BasicSliceIndexer(opt.threshold, opt.min_occupancy)
else:
def slice_indexer(_: Tensor) -> Tensor:
return tensor(opt.slice_index, dtype=int64)
return CT(
root=opt.root,
slice_indexer=BasicSliceIndexer(opt.threshold, opt.min_occupancy),
slice_indexer=slice_indexer,
transform=transform,
in_memory=opt.in_memory,
is_train=is_train,
Expand Down
5 changes: 3 additions & 2 deletions hrdae/models/networks/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ class TCN1dOption(RNN1dOption):
@dataclass
class TCN2dOption(RNN2dOption):
num_layers: int = 3
image_size: tuple[int, int] = (64, 64)
image_size: list[int] = field(default_factory=lambda: [64, 64])
kernel_size: int = 3
dropout: float = 0.0

Expand All @@ -231,8 +231,9 @@ def create_tcn1d(latent_dim: int, opt: TCN1dOption) -> RNN1d:


def create_tcn2d(latent_dim: int, opt: TCN2dOption) -> RNN2d:
image_size = (opt.image_size[0], opt.image_size[1])
return TCN2d(
latent_dim, opt.num_layers, opt.image_size, opt.kernel_size, opt.dropout
latent_dim, opt.num_layers, image_size, opt.kernel_size, opt.dropout
)


Expand Down
Loading

0 comments on commit bb219d7

Please sign in to comment.