Skip to content

Commit

Permalink
update training structure
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 15, 2024
1 parent d612fe0 commit 1df966d
Show file tree
Hide file tree
Showing 13 changed files with 34 additions and 28 deletions.
11 changes: 9 additions & 2 deletions hrdae/__main__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
from datetime import datetime

import hydra
from omegaconf import DictConfig, OmegaConf

from .dataloaders import create_dataloader
from .models import create_model
from .option import Option, TrainExpOption, process_options, save_options
from .option import Option, TrainExpOption, save_options


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig):
opt: Option = OmegaConf.to_object(cfg) # type: ignore
opt = process_options(opt)
opt.experiment.result_dir = (
opt.experiment.result_dir
/ opt.experiment.dataloader.__class__.__name__
/ opt.experiment.model.__class__.__name__
/ datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)
save_options(opt, opt.experiment.result_dir)
if (
isinstance(opt.experiment, TrainExpOption)
Expand Down
1 change: 0 additions & 1 deletion hrdae/conf/config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
task_name: mmnist-hrdae2d-guided1d
defaults:
- config_schema
- experiment: train
Expand Down
2 changes: 1 addition & 1 deletion hrdae/conf/experiment/dataloader/dataset/moving_mnist.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
root: data
slice_index: [32]
slice_index: [8, 16, 24, 32, 40, 48, 56]
defaults:
- /config/experiment/dataloader/dataset/moving_mnist@_here_
2 changes: 1 addition & 1 deletion hrdae/conf/experiment/model/network/hrdae2d.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ activation: sigmoid
debug_show_dim: false
defaults:
- /config/experiment/model/network/hrdae2d@_here_
- motion_encoder: guided1d
- motion_encoder: rnn1d
4 changes: 2 additions & 2 deletions hrdae/conf/experiment/model/network/motion_encoder/rnn1d.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
in_channels: 1
in_channels: 7
conv_params:
- kernel_size: [3]
stride: [2]
Expand All @@ -14,4 +14,4 @@ conv_params:
output_padding: [1]
defaults:
- /config/experiment/model/network/motion_encoder/rnn1d@_here_
- rnn: tcn1d
- rnn: conv_lstm1d
2 changes: 1 addition & 1 deletion hrdae/conf/experiment/model/scheduler/onecyclelr.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
max_lr: 0.001
max_lr: 0.1
defaults:
- /config/experiment/model/scheduler/onecyclelr@_here_
6 changes: 2 additions & 4 deletions hrdae/conf/experiment/model/vr.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
loss_coef:
wmse: 0.5
pjc2d: 0.5
wmse: 1
phase: "0"
pred_diff: false
pred_diff: true
defaults:
- /config/experiment/model/vr@_here_
- [email protected]: wmse
- [email protected]: pjc2d
- network: hrdae2d
- optimizer: adam
- scheduler: onecyclelr
1 change: 0 additions & 1 deletion hrdae/conf/experiment/test.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
run_name: 2024-06-16-0158
result_dir: results
defaults:
- /config/experiment/test@_here_
Expand Down
3 changes: 1 addition & 2 deletions hrdae/conf/experiment/train.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
run_name: 2024-06-16-0158
result_dir: results
n_epoch: 30
n_epoch: 100
debug: false
defaults:
- /config/experiment/train@_here_
Expand Down
9 changes: 8 additions & 1 deletion hrdae/models/basic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ def train(
n_epoch: int,
result_dir: Path,
debug: bool,
) -> None:
) -> float:
max_iter = None
if debug:
max_iter = 5

self.network.to(self.device)

least_val_loss = float("inf")

for epoch in range(n_epoch):
self.network.train()
running_loss = 0.0
Expand Down Expand Up @@ -105,6 +107,9 @@ def train(
avg_val_loss = total_val_loss / len(val_loader)
print(f"Epoch: {epoch+1}, Val Loss: {avg_val_loss:.6f}")

if avg_val_loss < least_val_loss:
least_val_loss = avg_val_loss

if epoch % 10 == 0:
data = next(iter(val_loader))

Expand All @@ -120,6 +125,8 @@ def train(
result_dir / "logs" / "reconstructed",
)

return least_val_loss


def create_basic_model(
opt: BasicModelOption,
Expand Down
2 changes: 1 addition & 1 deletion hrdae/models/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ def train(
n_epoch: int,
result_dir: Path,
debug: bool,
) -> None:
) -> float:
pass
9 changes: 8 additions & 1 deletion hrdae/models/vr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,15 @@ def train(
n_epoch: int,
result_dir: Path,
debug: bool,
) -> None:
) -> float:
max_iter = None
if debug:
max_iter = 5

self.network.to(self.device)

least_val_loss = float("inf")

for epoch in range(n_epoch):
self.network.train()
running_loss = 0.0
Expand Down Expand Up @@ -126,6 +128,9 @@ def train(
avg_val_loss = total_val_loss / len(val_loader)
print(f"Epoch: {epoch+1}, Val Loss: {avg_val_loss:.6f}")

if avg_val_loss < least_val_loss:
least_val_loss = avg_val_loss

if epoch % 10 == 0:
data = next(iter(val_loader))

Expand All @@ -147,6 +152,8 @@ def train(
result_dir / "logs" / "reconstructed",
)

return least_val_loss


def create_vr_model(
opt: VRModelOption,
Expand Down
10 changes: 0 additions & 10 deletions hrdae/option.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

@dataclass
class ExpOption:
run_name: str = MISSING
result_dir: Path = field(default_factory=Path)
debug: bool = False

Expand All @@ -29,18 +28,9 @@ class TestExpOption(ExpOption):

@dataclass
class Option:
task_name: str = MISSING

experiment: ExpOption = MISSING


def process_options(opt: Option) -> Option:
opt.experiment.result_dir = (
opt.experiment.result_dir / opt.task_name / opt.experiment.run_name
)
return opt


def save_options(opt: Option, save_dir: Path) -> None:
config = OmegaConf.structured(opt)
save_dir.mkdir(parents=True, exist_ok=True)
Expand Down

0 comments on commit 1df966d

Please sign in to comment.