Skip to content

Commit

Permalink
adding training with resnet and reconstruction
Browse files Browse the repository at this point in the history
  • Loading branch information
arouxel-laas committed Mar 9, 2024
1 parent b33bb7c commit 35fc173
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
6 changes: 3 additions & 3 deletions optimization_modules_with_resnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,15 +239,15 @@ def _common_step(self, batch, batch_idx):

loss1 = torch.sqrt(self.loss_fn(reconstructed_cube, ref_cube))
loss2 = torch.sum(torch.abs((total_sum_pattern - total_half_pattern_equal_1)/(self.pattern.shape[1]*self.pattern.shape[2]))**2)
loss = loss1 + loss2
loss = loss1

ssim_loss = self.ssim_loss(torch.clamp(reconstructed_cube.permute(0, 3, 1, 2), 0, 1), ref_cube.permute(0, 3, 1, 2))
print(f"loss1 {loss1}")
print(f"loss2 {loss2}")
return loss, ssim_loss, reconstructed_cube, ref_cube

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=4e-4)
optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
return { "optimizer":optimizer,
"lr_scheduler":{
"scheduler":torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 500, eta_min=1e-6),
Expand Down Expand Up @@ -382,4 +382,4 @@ def forward(ctx, input):
@staticmethod
def backward(ctx, grad_output):
# For backward pass, just pass the gradients through unchanged
return grad_output
return grad_output
13 changes: 8 additions & 5 deletions training_simca_reconstruction_with_resnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,20 @@

# data_dir = "./datasets_reconstruction/"
data_dir = "/local/users/ademaio/lpaillet/mst_datasets/cave_1024_28"
data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28"
# data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28"
datamodule = CubesDataModule(data_dir, batch_size=4, num_workers=5)

datetime_ = datetime.datetime.now().strftime('%y-%m-%d_%Hh%M')

name = "testing_simca_reconstruction_full"
model_name = "dauhst_9"
reconstruction_checkpoint = "/home/lpaillet/Documents/simca/tb_logs/testing_simca_reconstruction/version_24/checkpoints/epoch=499-step=18000.ckpt"
reconstruction_checkpoint = "./checkpoints/epoch=499-step=18000.ckpt"
resnet_checkpoint = None

log_dir = 'tb_logs'

train = True
retrain_recons = False
retrain_recons = True

logger = TensorBoardLogger(log_dir, name=name)

Expand All @@ -49,12 +49,15 @@
sub_module = JointReconstructionModule_V1(model_name, log_dir)
sub_module.load_state_dict(checkpoint["state_dict"])


resnet_checkpoint = "./checkpoints/best-checkpoint_resnet_only_24-03-09_18h05.ckpt"

if not retrain_recons or not train:
sub_module.eval()

reconstruction_module = JointReconstructionModule_V3(sub_module,
log_dir=log_dir+'/'+ name,
reconstruction_checkpoint = reconstruction_checkpoint)
resnet_checkpoint=resnet_checkpoint)


if torch.cuda.is_available():
Expand All @@ -73,4 +76,4 @@
if train:
trainer.fit(reconstruction_module, datamodule)
else:
trainer.predict(reconstruction_module, datamodule, ckpt_path=resnet_checkpoint)
trainer.predict(reconstruction_module, datamodule, ckpt_path=resnet_checkpoint)

0 comments on commit 35fc173

Please sign in to comment.