permutation fixed
lpaillet-laas committed Mar 8, 2024
1 parent 1b4d259 commit fff0471
Showing 3 changed files with 47 additions and 37 deletions.
78 changes: 46 additions & 32 deletions optimization_modules.py
@@ -25,7 +25,13 @@ def __init__(self, model_name):
def on_validation_start(self,stage=None):
print("---VALIDATION START---")
self.config = "simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi.yml"
# self.config = "simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi_shifted.yml"
self.shift_bool = True
if self.shift_bool:
self.crop_value_left = 8
self.crop_value_right = 8
else:
self.crop_value_left = 8
self.crop_value_right = 8
config_system = load_yaml_config(self.config)
self.config_patterns = load_yaml_config("simca/configs/pattern.yml")
self.cassi_system = CassiSystemOptim(system_config=config_system)
@@ -35,8 +41,12 @@ def forward(self, x):
print("---FORWARD---")

hyperspectral_cube, wavelengths = x
hyperspectral_cube = hyperspectral_cube.permute(0, 3, 2, 1).to(self.device)
hyperspectral_cube = hyperspectral_cube.permute(0, 2, 3, 1).to(self.device)
batch_size, H, W, C = hyperspectral_cube.shape
fig, ax = plt.subplots(1, 1)
plt.title(f"entry cube")
ax.imshow(hyperspectral_cube[0, :, :, 0].cpu().detach().numpy())
plt.show()
# print(f"batch size:{batch_size}")
# generate pattern
pattern = self.cassi_system.generate_2D_pattern(self.config_patterns,nb_of_patterns=batch_size)
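An aside on the permute change above (illustration only, not part of the diff; the [batch, C, H, W] input layout is an assumption inferred from the batch_size, H, W, C unpacking that follows):

import torch

x = torch.zeros(2, 28, 256, 320)    # assumed datamodule output: [batch, C, H, W]
print(x.permute(0, 3, 2, 1).shape)  # torch.Size([2, 320, 256, 28]) -> height and width end up swapped
print(x.permute(0, 2, 3, 1).shape)  # torch.Size([2, 256, 320, 28]) -> [batch, H, W, C], matching the unpacking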
@@ -61,23 +71,25 @@ def forward(self, x):

# process first acquisition with reconstruction model
# TODO : replace by the real reconstruction model
if self.config == "simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi.yml":
if not self.shift_bool:
acquired_cubes = acquired_image1.unsqueeze(1).repeat((1, 28, 1, 1)).float().to(self.device) # b x W x R x C
acquired_cubes = torch.flip(acquired_cubes, dims=(2,3)) # -1 magnification
fig, ax = plt.subplots(1, 2)
plt.title(f"true cube cropped vs measurement")
ax[0].imshow(hyperspectral_cube[0, self.crop_value_left:-self.crop_value_right, self.crop_value_left:-self.crop_value_right, 0].cpu().detach().numpy())
ax[1].imshow(acquired_cubes[0, 0, :, :].cpu().detach().numpy())
plt.show()

reconstructed_cube = self.reconstruction_model(acquired_cubes, filtering_cubes)
else:
mask_3d = expand_mask_3d(pattern.flip(dims=(1, 2))).float().to(self.device)
shifted_image = shift_back(acquired_image1.flip(dims=(1, 2)), displacement_in_pix).float().to(self.device)


for i in range(10):
shifted_image = self.shift_back(acquired_image1.flip(dims=(1, 2)), displacement_in_pix).float().to(self.device)

fig,ax = plt.subplots(1,2)
# plt.title(f"acquired_cubes {i}")
ax[0].imshow(shifted_image[0, i, :, :].cpu().detach().numpy())
ax[1].imshow(filtering_cubes[0, i, :, :].cpu().detach().numpy())
plt.show()
fig,ax = plt.subplots(1,2)
plt.title(f"true cube cropped vs measurement")
ax[0].imshow(hyperspectral_cube[0, 8:-8, self.crop_value_left:-self.crop_value_right, 0].cpu().detach().numpy())
ax[1].imshow(shifted_image[0, 0, :, :].cpu().detach().numpy())
plt.show()

reconstructed_cube = self.reconstruction_model(shifted_image, mask_3d)

@@ -144,11 +156,12 @@ def _common_step(self, batch, batch_idx):

hyperspectral_cube, wavelengths = batch
#hyperspectral_cube = hyperspectral_cube.permute(0, 3, 2, 1)
hyperspectral_cube = hyperspectral_cube[:,:, 8:-8, 8:-8]
hyperspectral_cube = hyperspectral_cube[:,:, self.crop_value_left:-self.crop_value_right, self.crop_value_left:-self.crop_value_right]

fig, ax = plt.subplots(1, 1)
# plt.title(f"acquired_cubes {i}")
ax.imshow(hyperspectral_cube[0, 0, :, :].cpu().detach().numpy())
fig, ax = plt.subplots(1, 2)
plt.title(f"cubes")
ax[0].imshow(hyperspectral_cube[0, 0, :, :].cpu().detach().numpy())
ax[1].imshow(y_hat[0, 0, :, :].cpu().detach().numpy())
plt.show()

#print("y_hat shape", y_hat.shape)
@@ -161,6 +174,23 @@ def _common_step(self, batch, batch_idx):
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=4e-4)
return optimizer

def shift_back(self, inputs, d): # input [bs,256,310], [bs, 28] output [bs, 28, 256, 256]
[bs, row, col] = inputs.shape
nC = 28
d = d[0]
d -= d.min()
self.crop_value_right += int(np.round(d.max()))
output = torch.zeros(bs, nC, row, col - int(np.round(d.max()))).float().to(self.device)
for i in range(nC):
shift = int(np.round(d[i]))
#output[:, i, :, :] = inputs[:, :, step * i:step * i + col - 27 * step] step = 2
# if shift >=0:
# output[:, i, :, :] = inputs[:, :, shift:row+shift]
# else:
# output[:, i, :, :] = inputs[:, :, shift-row:shift]
output[:, i, :, :] = inputs[:, :, shift:shift + col - int(np.round(d.max()))]
return output

def subsample(input, origin_sampling, target_sampling):
[bs, row, col, nC] = input.shape
@@ -178,22 +208,6 @@ def expand_mask_3d(mask_batch):
mask3d = mask_batch.repeat((1, 1, 1, 28))
mask3d = torch.permute(mask3d, (0, 3, 1, 2))
return mask3d

def shift_back(inputs, d): # input [bs,256,310], [bs, 28] output [bs, 28, 256, 256]
[bs, row, col] = inputs.shape
nC = 28
output = torch.zeros(bs, nC, row, row).cuda().float()
d = d[0]
d -= d.min()
for i in range(nC):
shift = int(np.round(d[i]))
#output[:, i, :, :] = inputs[:, :, step * i:step * i + col - 27 * step] step = 2
# if shift >=0:
# output[:, i, :, :] = inputs[:, :, shift:row+shift]
# else:
# output[:, i, :, :] = inputs[:, :, shift-row:shift]
output[:, i, :, :] = inputs[:, :, shift:shift + row]
return output

class EmptyModule(nn.Module):
def __init__(self):
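For reference, a self-contained toy sketch of the shift-back operation introduced as a method above (illustration only, with made-up shapes and integer per-channel displacements; the real method additionally grows self.crop_value_right by round(d.max())):

import torch

def toy_shift_back(inputs, d):
    # inputs: [bs, row, col] dispersed measurement; d: [nC] per-channel displacement in pixels
    bs, row, col = inputs.shape
    nC = d.shape[0]
    d = d - d.min()                           # reference displacements to the least-shifted channel
    width = col - int(round(d.max().item()))  # common width left once the largest shift is undone
    output = torch.zeros(bs, nC, row, width)
    for i in range(nC):
        shift = int(round(d[i].item()))
        output[:, i, :, :] = inputs[:, :, shift:shift + width]
    return output

# Toy example: 4 channels displaced by 0..3 px share an 8-px-wide window after shift-back.
measurement = torch.randn(1, 8, 8 + 3)
cube = toy_shift_back(measurement, torch.tensor([0., 1., 2., 3.]))
print(cube.shape)  # torch.Size([1, 4, 8, 8])

Because the recovered window is narrower than the measurement by round(d.max()) pixels, this appears to be why _common_step crops the ground-truth cube with the updated crop_value_right before comparing it against the reconstruction.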
4 changes: 0 additions & 4 deletions simca/CassiSystem_lightning.py
@@ -308,10 +308,6 @@ def image_acquisition(self, hyperspectral_cube, pattern,wavelengths,use_psf=Fals
self.last_filtered_interpolated_scene = sd_measurement
self.interpolated_scene = scene

for i in range(5):
plt.imshow(sd_measurement[0,:,:,i*8].cpu().numpy())
plt.title("SD measurement")
plt.show()

if dataset_labels is not None:
scene_labels = torch.from_numpy(match_dataset_labels_to_instrument(dataset_labels, self.last_filtered_interpolated_scene))
2 changes: 1 addition & 1 deletion training_simca_reconstruction.py
@@ -6,7 +6,7 @@
data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28"
#data_dir = "/local/users/ademaio/lpaillet/mst_datasets/cave_1024_28"

datamodule = CubesDataModule(data_dir, batch_size=5, num_workers=11)
datamodule = CubesDataModule(data_dir, batch_size=2, num_workers=5)

name = "testing_simca_reconstruction"

