Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
lpaillet-laas committed Mar 8, 2024
1 parent fff0471 commit 5043f0c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
18 changes: 11 additions & 7 deletions optimization_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@ def __init__(self, model_name):
def on_validation_start(self,stage=None):
print("---VALIDATION START---")
self.config = "simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi.yml"
self.shift_bool = True
self.shift_bool = False
if self.shift_bool:
self.crop_value_left = 8
self.crop_value_right = 8
self.crop_value_up = 8
self.crop_value_down = 8
else:
self.crop_value_left = 8
self.crop_value_right = 8
self.crop_value_up = 8
self.crop_value_down = 8
config_system = load_yaml_config(self.config)
self.config_patterns = load_yaml_config("simca/configs/pattern.yml")
self.cassi_system = CassiSystemOptim(system_config=config_system)
Expand Down Expand Up @@ -76,18 +80,18 @@ def forward(self, x):
acquired_cubes = torch.flip(acquired_cubes, dims=(2,3)) # -1 magnification
fig, ax = plt.subplots(1, 2)
plt.title(f"true cube cropped vs measurement")
ax[0].imshow(hyperspectral_cube[0, self.crop_value_left:-self.crop_value_right, self.crop_value_left:-self.crop_value_right, 0].cpu().detach().numpy())
ax[0].imshow(hyperspectral_cube[0, self.crop_value_up:-self.crop_value_down, self.crop_value_left:-self.crop_value_right, 0].cpu().detach().numpy())
ax[1].imshow(acquired_cubes[0, 0, :, :].cpu().detach().numpy())
plt.show()

reconstructed_cube = self.reconstruction_model(acquired_cubes, filtering_cubes)
else:
mask_3d = expand_mask_3d(pattern.flip(dims=(1, 2))).float().to(self.device)
shifted_image = self.shift_back(acquired_image1.flip(dims=(1, 2)), displacement_in_pix).float().to(self.device)
mask_3d = expand_mask_3d(self.cassi_system.pattern_crop.flip(dims=(1, 2))[:, self.crop_value_up:-self.crop_value_down, self.crop_value_left:-self.crop_value_right]).float().to(self.device)

fig,ax = plt.subplots(1,2)
plt.title(f"true cube cropped vs measurement")
ax[0].imshow(hyperspectral_cube[0, 8:-8, self.crop_value_left:-self.crop_value_right, 0].cpu().detach().numpy())
ax[0].imshow(hyperspectral_cube[0, self.crop_value_up:-self.crop_value_down, self.crop_value_left:-self.crop_value_right, 0].cpu().detach().numpy())
ax[1].imshow(shifted_image[0, 0, :, :].cpu().detach().numpy())
plt.show()

Expand Down Expand Up @@ -156,10 +160,10 @@ def _common_step(self, batch, batch_idx):

hyperspectral_cube, wavelengths = batch
#hyperspectral_cube = hyperspectral_cube.permute(0, 3, 2, 1)
hyperspectral_cube = hyperspectral_cube[:,:, self.crop_value_left:-self.crop_value_right, self.crop_value_left:-self.crop_value_right]
hyperspectral_cube = hyperspectral_cube[:,:, self.crop_value_up:-self.crop_value_down, self.crop_value_left:-self.crop_value_right]

fig, ax = plt.subplots(1, 2)
plt.title(f"cubes")
plt.title(f"true cube vs reconstructed cube")
ax[0].imshow(hyperspectral_cube[0, 0, :, :].cpu().detach().numpy())
ax[1].imshow(y_hat[0, 0, :, :].cpu().detach().numpy())
plt.show()
Expand All @@ -180,7 +184,7 @@ def shift_back(self, inputs, d): # input [bs,256,310], [bs, 28] output [bs, 28
nC = 28
d = d[0]
d -= d.min()
self.crop_value_right += int(np.round(d.max()))
self.crop_value_right = 8+int(np.round(d.max()))
output = torch.zeros(bs, nC, row, col - int(np.round(d.max()))).float().to(self.device)
for i in range(nC):
shift = int(np.round(d[i]))
Expand Down
7 changes: 1 addition & 6 deletions simca/CassiSystem_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,9 @@ def image_acquisition(self, hyperspectral_cube, pattern,wavelengths,use_psf=Fals

pattern_crop = crop_center_3D(pattern, scene.shape[2], scene.shape[1]).to(self.device)

self.pattern_crop_dataset = interpolate_data_on_grid_positions_torch(pattern_crop.unsqueeze(-1),
self.X_coordinates_propagated_coded_aperture,
self.Y_coordinates_propagated_coded_aperture,
self.X_detector_coordinates_grid,
self.Y_detector_coordinates_grid)
self.pattern_crop = pattern_crop

pattern_crop = pattern_crop.unsqueeze(-1).repeat(1, 1, 1, scene.size(-1))


#print(scene.get_device())
#print(pattern_crop.get_device())
Expand Down

0 comments on commit 5043f0c

Please sign in to comment.