resnet training with previous checkpoint
lpaillet-laas committed Mar 9, 2024
1 parent d3eff4b commit e6ec792
Showing 7 changed files with 580 additions and 59 deletions.
7 changes: 7 additions & 0 deletions data_handler.py
@@ -1,4 +1,5 @@
import os
from pytorch_lightning.utilities.types import EVAL_DATALOADERS
import torch
import scipy.io as sio
from torch.utils.data import Dataset, DataLoader
@@ -99,6 +100,12 @@ def test_dataloader(self):
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False)

def predict_dataloader(self):
return DataLoader(self.train_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False)


def arguement_1(x):
(rest of diff not shown)
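The new predict_dataloader mirrors test_dataloader but serves train_ds, so Trainer.predict iterates over the unshuffled training split. A minimal usage sketch; the data module's class name, its constructor arguments, and the model name are hypothetical, since none of them appear in this diff:

import pytorch_lightning as pl
from data_handler import CubesDataModule  # hypothetical class name
from optimization_modules import JointReconstructionModule_V1

datamodule = CubesDataModule(data_dir="./data", batch_size=4, num_workers=4)  # hypothetical args
model = JointReconstructionModule_V1("birnat", reconstruction_checkpoint="recon.ckpt")  # hypothetical values

trainer = pl.Trainer(accelerator="auto", devices=1)
# predict() draws batches from predict_dataloader and calls predict_step on each.
predictions = trainer.predict(model, datamodule=datamodule)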
69 changes: 40 additions & 29 deletions optimization_modules.py
@@ -16,37 +16,31 @@

class JointReconstructionModule_V1(pl.LightningModule):

def __init__(self, model_name,log_dir="tb_logs"):
def __init__(self, model_name,log_dir="tb_logs", reconstruction_checkpoint=None):
super().__init__()

self.model_name = model_name
# TODO : use a real reconstruction module
self.reconstruction_model = model_generator(self.model_name, None)
""" if torch.cuda.is_available():
self.reconstruction_model = self.reconstruction_model.cuda()
else:
self.reconstruction_model.to('cpu') """
#self.reconstruction_model = EmptyModule()
self.reconstruction_model = model_generator(self.model_name, reconstruction_checkpoint)

self.loss_fn = nn.MSELoss()
self.ssim_loss = SSIM(window_size=11, size_average=True)
#self.ssim_loss = SSIM(window_size=11, n_channels=28)
self.ssim_loss = SSIM(window_size=11, n_channels=3)

self.writer = SummaryWriter(log_dir)

def on_validation_start(self,stage=None):
print("---VALIDATION START---")
self.config = "simca/configs/cassi_system_optim_optics_full_triplet_dd_cassi.yml"
self.shift_bool = False
if self.shift_bool:
self.crop_value_left = 8
self.crop_value_right = 8
self.crop_value_up = 8
self.crop_value_down = 8
else:
self.crop_value_left = 8
self.crop_value_right = 8
self.crop_value_up = 8
self.crop_value_down = 8

config_system = load_yaml_config(self.config)
self.config_patterns = load_yaml_config("simca/configs/pattern.yml")
self.cassi_system = CassiSystemOptim(system_config=config_system)
self.cassi_system.propagate_coded_aperture_grid()

def on_predict_start(self,stage=None):
print("---PREDICT START---")
self.config = "simca/configs/cassi_system_optim_optics_full_triplet_dd_cassi.yml"

config_system = load_yaml_config(self.config)
self.config_patterns = load_yaml_config("simca/configs/pattern.yml")
self.cassi_system = CassiSystemOptim(system_config=config_system)
@@ -63,7 +57,7 @@ def _normalize_data_by_itself(self, data):
return normalized_data


def forward(self, x):
def forward(self, x, pattern=None):
print("---FORWARD---")

hyperspectral_cube, wavelengths = x
@@ -76,8 +70,13 @@ def forward(self, x):
# plt.show()
# print(f"batch size:{batch_size}")
# generate pattern
self.pattern = self.cassi_system.generate_2D_pattern(self.config_patterns,nb_of_patterns=batch_size)
self.pattern = self.pattern.to(self.device)

if pattern is None:
self.pattern = self.cassi_system.generate_2D_pattern(self.config_patterns,nb_of_patterns=batch_size)
self.pattern = self.pattern.to(self.device)
else:
self.pattern = pattern.to(self.device)
self.cassi_system.pattern = pattern.to(self.device)

# plt.imshow(pattern[0, :, :].cpu().detach().numpy())
# plt.show()
@@ -120,7 +119,7 @@ def forward(self, x):
def training_step(self, batch, batch_idx):
print("Training step")

loss,reconstructed_cube, ref_cube = self._common_step(batch, batch_idx)
loss, ssim_loss, reconstructed_cube, ref_cube = self._common_step(batch, batch_idx)


output_images = self._convert_output_to_images(self._normalize_image_tensor(self.acquired_image1))
@@ -146,6 +145,14 @@ def training_step(self, batch, batch_idx):
prog_bar=True,
)

self.log_dict(
{ "train_ssim_loss": ssim_loss,
},
on_step=True,
on_epoch=True,
prog_bar=True,
)

return {"loss": loss}

def _normalize_image_tensor(self, tensor):
@@ -158,7 +165,7 @@ def _normalize_image_tensor(self, tensor):
def validation_step(self, batch, batch_idx):

print("Validation step")
loss,reconstructed_cube, ref_cube= self._common_step(batch, batch_idx)
loss, ssim_loss, reconstructed_cube, ref_cube = self._common_step(batch, batch_idx)

self.log_dict(
{ "val_loss": loss,
@@ -172,7 +179,7 @@ def validation_step(self, batch, batch_idx):

def test_step(self, batch, batch_idx):
print("Test step")
loss,reconstructed_cube, ref_cube= self._common_step(batch, batch_idx)
loss, ssim_loss, reconstructed_cube, ref_cube = self._common_step(batch, batch_idx)
self.log_dict(
{ "test_loss": loss,
},
@@ -184,8 +191,10 @@ def test_step(self, batch, batch_idx):

def predict_step(self, batch, batch_idx):
print("Predict step")
loss,reconstructed_cube, ref_cube= self._common_step(batch, batch_idx)
self.log('predict_step', loss,on_step=True, on_epoch=True, prog_bar=True, logger=True)
loss, ssim_loss, reconstructed_cube, ref_cube = self._common_step(batch, batch_idx)
print("Predict loss: ", loss.item())
print("Predict ssim loss: ", ssim_loss)
#self.log('predict_step', loss,on_step=True, on_epoch=True, prog_bar=True, logger=True)
return loss

def _common_step(self, batch, batch_idx):
@@ -206,8 +215,10 @@ def _common_step(self, batch, batch_idx):


loss = torch.sqrt(self.loss_fn(reconstructed_cube, ref_cube))
ssim_loss = self.ssim_loss(torch.clamp(reconstructed_cube.permute(0, 3, 1, 2), 0, 1), ref_cube.permute(0, 3, 1, 2))
#ssim_loss = 0

return loss,reconstructed_cube, ref_cube
return loss, ssim_loss, reconstructed_cube, ref_cube

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=4e-4)
(rest of diff not shown)
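Net effect of the changes to this module: the constructor threads reconstruction_checkpoint through to model_generator, forward accepts an optional precomputed pattern (falling back to sampling one from the CASSI pattern config), and _common_step now returns the SSIM score alongside the RMSE loss. A sketch of driving forward with a fixed pattern outside a Trainer; everything marked hypothetical is not pinned down by this diff:

import torch
from optimization_modules import JointReconstructionModule_V1

model = JointReconstructionModule_V1("birnat", reconstruction_checkpoint="recon.ckpt")  # hypothetical values
model.on_predict_start()  # builds self.cassi_system, which forward relies on
model.eval()

cube = torch.rand(1, 128, 128, 28)                          # hypothetical (B, H, W, C) cube
wavelengths = torch.linspace(450, 650, 28)                  # hypothetical band centers in nm
fixed_pattern = torch.randint(0, 2, (1, 128, 128)).float()  # hypothetical mask shape

with torch.no_grad():
    reconstructed = model((cube, wavelengths), pattern=fixed_pattern)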
50 changes: 41 additions & 9 deletions optimization_modules_with_resnet.py
@@ -32,12 +32,14 @@ def __init__(self, model_name,log_dir="tb_logs",reconstruction_checkpoint=None):
super().__init__()

self.model_name = model_name
# TODO : use a real reconstruction module
self.reconstruction_model = model_generator(self.model_name, pretrained_model_path=reconstruction_checkpoint)
self.reconstruction_model = model_generator(self.model_name, pretrained_model_path=None)
if reconstruction_checkpoint is not None:
#self.reconstruction_model = model_generator(self.model_name, pretrained_model_path=reconstruction_checkpoint)
self.reconstruction_model.load_state_dict(torch.load(reconstruction_checkpoint), strict=False)
self.mask_generation = UnetModel(classes=1,encoder_weights=None,in_channels=1)

self.loss_fn = nn.MSELoss()
self.ssim_loss = SSIM(window_size=11, size_average=True)
self.ssim_loss = SSIM(window_size=11, n_channels=28)

self.writer = SummaryWriter(log_dir)
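Unlike optimization_modules.py, which hands the checkpoint to model_generator, this module builds the reconstruction model bare and then loads the previous run's weights with strict=False, so keys with no counterpart (for instance anything belonging to the new mask-generation U-Net) are skipped instead of raising. A sketch of that loading pattern; the "state_dict" unwrapping assumes a Lightning-style checkpoint and is not shown in the diff:

import torch

state = torch.load("recon.ckpt", map_location="cpu")  # hypothetical path
# Lightning checkpoints nest weights under "state_dict"; a plain
# torch.save of a state_dict does not. Unwrap if needed.
state = state.get("state_dict", state)
# reconstruction_model: the network returned by model_generator(...)
reconstruction_model.load_state_dict(state, strict=False)  # tolerate missing/unexpected keys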

@@ -82,6 +84,8 @@ def forward(self, x):
self.acquired_image1 = self.acquired_image1.flip(1)
self.acquired_image1 = self.acquired_image1.flip(2)
self.acquired_image1 = self.acquired_image1.unsqueeze(1).float()
#self.acquired_image1 = self._normalize_data_by_itself(self.acquired_image1)


self.pattern = self.mask_generation(self.acquired_image1).squeeze(1)
self.pattern = BinarizeFunction.apply(self.pattern)
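BinarizeFunction is referenced here but not defined in this diff. Binarizing a mask while keeping its generator trainable is commonly done with a straight-through estimator; a sketch of that pattern, not necessarily the repository's implementation:

import torch

class BinarizeFunction(torch.autograd.Function):
    # Hard 0/1 threshold in the forward pass, identity gradient in the backward pass.

    @staticmethod
    def forward(ctx, x):
        return (x > 0.5).float()  # binarize at 0.5

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient through unchanged so
        # the mask-generating U-Net still receives a learning signal.
        return grad_output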
@@ -120,7 +124,7 @@ def forward(self, x):
def training_step(self, batch, batch_idx):
print("Training step")

loss,reconstructed_cube, ref_cube = self._common_step(batch, batch_idx)
loss, ssim_loss, reconstructed_cube, ref_cube = self._common_step(batch, batch_idx)



@@ -151,6 +155,14 @@ def training_step(self, batch, batch_idx):
prog_bar=True,
)

self.log_dict(
{ "train_ssim_loss": ssim_loss,
},
on_step=True,
on_epoch=True,
prog_bar=True,
)

return {"loss": loss}

def _normalize_image_tensor(self, tensor):
@@ -163,7 +175,7 @@ def _normalize_image_tensor(self, tensor):
def validation_step(self, batch, batch_idx):

print("Validation step")
loss,reconstructed_cube, ref_cube= self._common_step(batch, batch_idx)
loss, ssim_loss, reconstructed_cube, ref_cube= self._common_step(batch, batch_idx)

self.log_dict(
{ "val_loss": loss,
@@ -173,24 +185,43 @@ def validation_step(self, batch, batch_idx):
prog_bar=True,
)

self.log_dict(
{ "val_ssim_loss": ssim_loss,
},
on_step=True,
on_epoch=True,
prog_bar=True,
)

return {"loss": loss}

def test_step(self, batch, batch_idx):
print("Test step")
loss,reconstructed_cube, ref_cube= self._common_step(batch, batch_idx)
loss, ssim_loss, reconstructed_cube, ref_cube= self._common_step(batch, batch_idx)
self.log_dict(
{ "test_loss": loss,
},
on_step=True,
on_epoch=True,
prog_bar=True,
)

self.log_dict(
{ "test_ssim_loss": ssim_loss,
},
on_step=True,
on_epoch=True,
prog_bar=True,
)

return {"loss": loss}

def predict_step(self, batch, batch_idx):
print("Predict step")
loss,reconstructed_cube, ref_cube= self._common_step(batch, batch_idx)
self.log('predict_step', loss,on_step=True, on_epoch=True, prog_bar=True, logger=True)
loss, ssim_loss, reconstructed_cube, ref_cube= self._common_step(batch, batch_idx)
print("Predict loss: ", loss.item())
print("Predict ssim loss: ", ssim_loss)
#self.log('predict_step', loss,on_step=True, on_epoch=True, prog_bar=True, logger=True)
return loss

def _common_step(self, batch, batch_idx):
@@ -220,9 +251,10 @@ def _common_step(self, batch, batch_idx):
loss2 = torch.sum(torch.abs((total_sum_pattern - total_half_pattern_equal_1)/(self.pattern.shape[1]*self.pattern.shape[2]))**2)
loss = loss1 + loss2

ssim_loss = self.ssim_loss(torch.clamp(reconstructed_cube.permute(0, 3, 1, 2), 0, 1), ref_cube.permute(0, 3, 1, 2))
print(f"loss1 {loss1}")
print(f"loss2 {loss2}")
return loss,reconstructed_cube, ref_cube
return loss, ssim_loss, reconstructed_cube, ref_cube

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=4e-4)
(rest of diff not shown)
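The combined objective in _common_step adds a mask regularizer to the reconstruction term: loss2 is the squared, pixel-normalized gap between each pattern's number of ones and half the pixel count, pushing the binary mask toward a 50% fill rate (this reads total_sum_pattern and total_half_pattern_equal_1, defined outside this hunk, as the per-pattern sum and half the pixel count). An equivalent standalone form under that reading:

import torch

def half_fill_regularizer(pattern: torch.Tensor) -> torch.Tensor:
    # pattern: (B, H, W) binary mask produced by BinarizeFunction.
    n_pixels = pattern.shape[1] * pattern.shape[2]
    ones_per_pattern = pattern.sum(dim=(1, 2))
    half = n_pixels / 2  # target fill: 50% of pixels set to 1
    return torch.sum(torch.abs((ones_per_pattern - half) / n_pixels) ** 2)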
(diffs for the remaining 4 files not shown)
