Skip to content

Commit

Permalink
adding plots and imshow
Browse files Browse the repository at this point in the history
  • Loading branch information
arouxel-laas committed Mar 8, 2024
1 parent a4bee18 commit 1b4d259
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions optimization_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, model_name):
def on_validation_start(self,stage=None):
print("---VALIDATION START---")
self.config = "simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi.yml"
#self.config = "simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi_shifted.yml"
# self.config = "simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi_shifted.yml"
config_system = load_yaml_config(self.config)
self.config_patterns = load_yaml_config("simca/configs/pattern.yml")
self.cassi_system = CassiSystemOptim(system_config=config_system)
Expand All @@ -41,30 +41,44 @@ def forward(self, x):
# generate pattern
pattern = self.cassi_system.generate_2D_pattern(self.config_patterns,nb_of_patterns=batch_size)
pattern = pattern.to(self.device)

plt.imshow(pattern[0, :, :].cpu().detach().numpy())
plt.show()

# print(f"pattern_size: {pattern.shape}")

# generate first acquisition with simca

acquired_image1 = self.cassi_system.image_acquisition(hyperspectral_cube, pattern, wavelengths)
filtering_cubes = subsample(self.cassi_system.filtering_cube, np.linspace(450, 650, self.cassi_system.filtering_cube.shape[-1]), np.linspace(450, 650, 28))
filtering_cubes = filtering_cubes.permute(0, 3, 1, 2).float().to(self.device)
filtering_cubes = torch.flip(filtering_cubes, dims=(2,3)) # -1 magnification
displacement_in_pix = self.cassi_system.get_displacement_in_pixels(dataset_wavelengths=wavelengths)
#print("displacement_in_pix", displacement_in_pix)

# vizualize first image acquisition
plt.imshow(acquired_image1[0, :, :].cpu().detach().numpy())
plt.show()
# # vizualize first image acquisition
# plt.imshow(acquired_image1[0, :, :].cpu().detach().numpy())
# plt.show()

# process first acquisition with reconstruction model
# TODO : replace by the real reconstruction model
if self.config == "simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi.yml":
acquired_cubes = acquired_image1.unsqueeze(1).repeat((1, 28, 1, 1)).float().to(self.device) # b x W x R x C
acquired_cubes = torch.flip(acquired_cubes, dims=(2,3)) # -1 magnification

reconstructed_cube = self.reconstruction_model(acquired_cubes, filtering_cubes)
else:
mask_3d = expand_mask_3d(pattern).float().to(self.device)
shifted_image = shift_back(acquired_image1, displacement_in_pix).float().to(self.device)
mask_3d = expand_mask_3d(pattern.flip(dims=(1, 2))).float().to(self.device)
shifted_image = shift_back(acquired_image1.flip(dims=(1, 2)), displacement_in_pix).float().to(self.device)


for i in range(10):

fig,ax = plt.subplots(1,2)
# plt.title(f"acquired_cubes {i}")
ax[0].imshow(shifted_image[0, i, :, :].cpu().detach().numpy())
ax[1].imshow(filtering_cubes[0, i, :, :].cpu().detach().numpy())
plt.show()

reconstructed_cube = self.reconstruction_model(shifted_image, mask_3d)


Expand Down Expand Up @@ -132,6 +146,11 @@ def _common_step(self, batch, batch_idx):
#hyperspectral_cube = hyperspectral_cube.permute(0, 3, 2, 1)
hyperspectral_cube = hyperspectral_cube[:,:, 8:-8, 8:-8]

fig, ax = plt.subplots(1, 1)
# plt.title(f"acquired_cubes {i}")
ax.imshow(hyperspectral_cube[0, 0, :, :].cpu().detach().numpy())
plt.show()

#print("y_hat shape", y_hat.shape)
#print("hyperspectral_cube shape", hyperspectral_cube.shape)

Expand Down

0 comments on commit 1b4d259

Please sign in to comment.