Skip to content

Commit

Permalink
fixing conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
arouxel-laas committed Mar 8, 2024
2 parents 88ca7a4 + 057d66d commit 074a221
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 15 deletions.
17 changes: 9 additions & 8 deletions optimization_modules.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self, model_name,log_dir="tb_logs"):
else:
self.reconstruction_model.to('cpu') """
#self.reconstruction_model = EmptyModule()

self.loss_fn = nn.MSELoss()
self.ssim_loss = SSIM(window_size=11, size_average=True)

Expand Down Expand Up @@ -92,7 +93,7 @@ def forward(self, x):
# self.acquired_image1 = self._normalize_data_by_itself(self.acquired_image1)
# acquired_cubes = self.acquired_image1.unsqueeze(1).repeat((1, 28, 1, 1)).float().to(self.device) # b x W x R x C

filtering_cubes = subsample(filtering_cube, np.linspace(450, 650, filtering_cube.shape[-1]), np.linspace(450, 650, 28)).permute((0, 3, 1, 2))
filtering_cubes = subsample(filtering_cube, torch.linspace(450, 650, filtering_cube.shape[-1]), torch.linspace(450, 650, 28)).permute((0, 3, 1, 2)).float().to(self.device)

if self.model_name == "birnat":
# acquisition = self.acquired_image1.unsqueeze(1)
Expand Down Expand Up @@ -234,14 +235,14 @@ def plot_spectral_filter(self,ref_hyperspectral_cube,recontructed_hyperspectral_
pix_j_col_value = np.random.randint(0,x)

pix_j_ref = ref_hyperspectral_cube[i, pix_j_row_value,pix_j_col_value,:].cpu().detach().numpy()
pixe_j_reconstructed = recontructed_hyperspectral_cube[i, pix_j_row_value,pix_j_col_value,:].cpu().detach().numpy()
axs[i].plot(pixe_j_reconstructed, label="pix reconstructed" + str(j),c=colors[j])
pix_j_reconstructed = recontructed_hyperspectral_cube[i, pix_j_row_value,pix_j_col_value,:].cpu().detach().numpy()
axs[i].plot(pix_j_reconstructed, label="pix reconstructed" + str(j),c=colors[j])
axs[i].plot(pix_j_ref, label="pix" + str(j), linestyle='--',c=colors[j])

axs[i].set_title(f"Reconstruction quality")

axs[i].set_xlabel("Wavelength index")
axs[i].set_ylabel("pxie values")
axs[i].set_ylabel("pix values")
axs[i].grid(True)

plt.legend()
Expand All @@ -264,12 +265,12 @@ def plot_spectral_filter(self,ref_hyperspectral_cube,recontructed_hyperspectral_

def subsample(input, origin_sampling, target_sampling):
[bs, row, col, nC] = input.shape
output = torch.zeros(bs, row, col, len(target_sampling))
indices = torch.zeros(len(target_sampling), dtype=torch.int)
for i in range(len(target_sampling)):
sample = target_sampling[i]
idx = np.abs(origin_sampling-sample).argmin()
output[:,:,:,i] = input[:,:,:,idx]
return output
idx = torch.abs(origin_sampling-sample).argmin()
indices[i] = idx
return input[:,:,:,indices]

def expand_mask_3d(mask_batch):
if len(mask_batch.shape)==3:
Expand Down
22 changes: 15 additions & 7 deletions training_simca_reconstruction.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from optimization_modules import JointReconstructionModule_V1
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import torch


# data_dir = "./datasets_reconstruction/"
data_dir = "/local/users/ademaio/lpaillet/mst_datasets/cave_1024_28"

datamodule = CubesDataModule(data_dir, batch_size=2, num_workers=3)
data_dir = "./datasets_reconstruction/cave_1024_28"
datamodule = CubesDataModule(data_dir, batch_size=2, num_workers=1)


name = "testing_simca_reconstruction"
model_name = "birnat"
Expand All @@ -35,9 +36,16 @@

reconstruction_module = JointReconstructionModule_V1(model_name,log_dir=log_dir+'/'+ name)

trainer = pl.Trainer( logger=logger,
accelerator="cpu",
max_epochs=500,
log_every_n_steps=1)

if torch.cuda.is_available():
trainer = pl.Trainer( logger=logger,
accelerator="gpu",
max_epochs=500,
log_every_n_steps=1)
else:
trainer = pl.Trainer( logger=logger,
accelerator="cpu",
max_epochs=500,
log_every_n_steps=1)

trainer.fit(reconstruction_module, datamodule)

0 comments on commit 074a221

Please sign in to comment.