Skip to content

Commit

Permalink
using dd-cassi now
Browse files Browse the repository at this point in the history
  • Loading branch information
arouxel-laas committed Mar 8, 2024
1 parent 5043f0c commit 633b1d8
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 115 deletions.
225 changes: 142 additions & 83 deletions optimization_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@
from MST.simulation.train_code.architecture import *
from simca import load_yaml_config
import matplotlib.pyplot as plt
import torchvision
import numpy as np

from simca.functions_acquisition import *
from piqa import SSIM
from torch.utils.tensorboard import SummaryWriter
import io
import torchvision.transforms as transforms
from PIL import Image

class JointReconstructionModule_V1(pl.LightningModule):

def __init__(self, model_name):
def __init__(self, model_name,log_dir="tb_logs"):
super().__init__()

# TODO : use a real reconstruction module
Expand All @@ -21,10 +27,13 @@ def __init__(self, model_name):
self.reconstruction_model.to('cpu') """
#self.reconstruction_model = EmptyModule()
self.loss_fn = nn.MSELoss()
self.ssim_loss = SSIM(window_size=11, size_average=True)

self.writer = SummaryWriter(log_dir)

def on_validation_start(self,stage=None):
print("---VALIDATION START---")
self.config = "simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi.yml"
self.config = "simca/configs/cassi_system_optim_optics_full_triplet_dd_cassi.yml"
self.shift_bool = False
if self.shift_bool:
self.crop_value_left = 8
Expand All @@ -41,76 +50,75 @@ def on_validation_start(self,stage=None):
self.cassi_system = CassiSystemOptim(system_config=config_system)
self.cassi_system.propagate_coded_aperture_grid()

def _normalize_data_by_itself(self, data):
# Calculate the mean and std for each batch individually
# Keep dimensions for broadcasting
mean = torch.mean(data, dim=[1, 2], keepdim=True)
std = torch.std(data, dim=[1, 2], keepdim=True)

# Normalize each batch by its mean and std
normalized_data = (data - mean) / std
return normalized_data


def forward(self, x):
print("---FORWARD---")

hyperspectral_cube, wavelengths = x
hyperspectral_cube = hyperspectral_cube.permute(0, 2, 3, 1).to(self.device)
batch_size, H, W, C = hyperspectral_cube.shape
fig, ax = plt.subplots(1, 1)
plt.title(f"entry cube")
ax.imshow(hyperspectral_cube[0, :, :, 0].cpu().detach().numpy())
plt.show()

# fig, ax = plt.subplots(1, 1)
# plt.title(f"entry cube")
# ax.imshow(hyperspectral_cube[0, :, :, 0].cpu().detach().numpy())
# plt.show()
# print(f"batch size:{batch_size}")
# generate pattern
pattern = self.cassi_system.generate_2D_pattern(self.config_patterns,nb_of_patterns=batch_size)
pattern = pattern.to(self.device)
self.pattern = self.cassi_system.generate_2D_pattern(self.config_patterns,nb_of_patterns=batch_size)
self.pattern = self.pattern.to(self.device)

plt.imshow(pattern[0, :, :].cpu().detach().numpy())
plt.show()
# plt.imshow(pattern[0, :, :].cpu().detach().numpy())
# plt.show()

# print(f"pattern_size: {pattern.shape}")

# generate first acquisition with simca
acquired_image1 = self.cassi_system.image_acquisition(hyperspectral_cube, pattern, wavelengths)
filtering_cubes = subsample(self.cassi_system.filtering_cube, np.linspace(450, 650, self.cassi_system.filtering_cube.shape[-1]), np.linspace(450, 650, 28))
filtering_cubes = filtering_cubes.permute(0, 3, 1, 2).float().to(self.device)
filtering_cubes = torch.flip(filtering_cubes, dims=(2,3)) # -1 magnification
displacement_in_pix = self.cassi_system.get_displacement_in_pixels(dataset_wavelengths=wavelengths)
#print("displacement_in_pix", displacement_in_pix)

# # vizualize first image acquisition
# plt.imshow(acquired_image1[0, :, :].cpu().detach().numpy())
# plt.show()

# process first acquisition with reconstruction model
# TODO : replace by the real reconstruction model
if not self.shift_bool:
acquired_cubes = acquired_image1.unsqueeze(1).repeat((1, 28, 1, 1)).float().to(self.device) # b x W x R x C
acquired_cubes = torch.flip(acquired_cubes, dims=(2,3)) # -1 magnification
fig, ax = plt.subplots(1, 2)
plt.title(f"true cube cropped vs measurement")
ax[0].imshow(hyperspectral_cube[0, self.crop_value_up:-self.crop_value_down, self.crop_value_left:-self.crop_value_right, 0].cpu().detach().numpy())
ax[1].imshow(acquired_cubes[0, 0, :, :].cpu().detach().numpy())
plt.show()

reconstructed_cube = self.reconstruction_model(acquired_cubes, filtering_cubes)
else:
shifted_image = self.shift_back(acquired_image1.flip(dims=(1, 2)), displacement_in_pix).float().to(self.device)
mask_3d = expand_mask_3d(self.cassi_system.pattern_crop.flip(dims=(1, 2))[:, self.crop_value_up:-self.crop_value_down, self.crop_value_left:-self.crop_value_right]).float().to(self.device)
filtering_cube = self.cassi_system.generate_filtering_cube().to(self.device)
self.acquired_image1 = self.cassi_system.image_acquisition(hyperspectral_cube, self.pattern, wavelengths).to(self.device)


fig,ax = plt.subplots(1,2)
plt.title(f"true cube cropped vs measurement")
ax[0].imshow(hyperspectral_cube[0, self.crop_value_up:-self.crop_value_down, self.crop_value_left:-self.crop_value_right, 0].cpu().detach().numpy())
ax[1].imshow(shifted_image[0, 0, :, :].cpu().detach().numpy())
plt.show()
self.acquired_image1 = self._normalize_data_by_itself(self.acquired_image1)
acquired_cubes = self.acquired_image1.unsqueeze(1).repeat((1, 28, 1, 1)).float().to(self.device) # b x W x R x C

reconstructed_cube = self.reconstruction_model(shifted_image, mask_3d)
filtering_cubes = subsample(filtering_cube, np.linspace(450, 650, filtering_cube.shape[-1]), np.linspace(450, 650, 28)).permute((0, 3, 1, 2))


reconstructed_cube = self.reconstruction_model(acquired_cubes, filtering_cubes.to(self.device))

#print(acquired_cubes.shape)
#print(filtering_cubes.shape)

#

return reconstructed_cube


def training_step(self, batch, batch_idx):
print("Training step")

loss, y_hat, y = self._common_step(batch, batch_idx)
loss,reconstructed_cube, ref_cube = self._common_step(batch, batch_idx)

hyperspectral_cube, wavelengths = batch
hyperspectral_cube = hyperspectral_cube.permute(0, 2, 3, 1).to(self.device)

output_images = self._convert_output_to_images(self.acquired_image1)
patterns = self._convert_output_to_images(self.pattern)
input_images = self._convert_output_to_images(hyperspectral_cube[:,:,:,0])


self._log_images('train/output_images', output_images, self.global_step)
self._log_images('train/input_images', input_images, self.global_step)
self._log_images('train/patterns', patterns, self.global_step)

spectral_filter_plot = self.plot_spectral_filter(ref_cube,reconstructed_cube)
self.writer.add_image('Spectral Filter', spectral_filter_plot, self.global_step)

self.log_dict(
{ "train_loss": loss,
},
Expand All @@ -119,12 +127,19 @@ def training_step(self, batch, batch_idx):
prog_bar=True,
)

return {"loss": loss, "scores":y_hat, "y":y}
return {"loss": loss}

def _normalize_image_tensor(self, tensor):
# Normalize the tensor to the range [0, 1]
min_val = tensor.min()
max_val = tensor.max()
normalized_tensor = (tensor - min_val) / (max_val - min_val)
return normalized_tensor

def validation_step(self, batch, batch_idx):

print("Validation step")
loss, y_hat, y = self._common_step(batch, batch_idx)
loss,reconstructed_cube, ref_cube= self._common_step(batch, batch_idx)

self.log_dict(
{ "val_loss": loss,
Expand All @@ -134,67 +149,111 @@ def validation_step(self, batch, batch_idx):
prog_bar=True,
)

return {"loss": loss, "scores":y_hat, "y":y}
return {"loss": loss}

def test_step(self, batch, batch_idx):
print("Test step")
loss, y_hat, y = self._common_step(batch, batch_idx)
loss,reconstructed_cube, ref_cube= self._common_step(batch, batch_idx)
self.log_dict(
{ "test_loss": loss,
},
on_step=True,
on_epoch=True,
prog_bar=True,
)
return {"loss": loss, "scores":y_hat, "y":y}
return {"loss": loss}

def predict_step(self, batch, batch_idx):
print("Predict step")
loss, _, _ = self._common_step(batch, batch_idx)
loss,reconstructed_cube, ref_cube= self._common_step(batch, batch_idx)
self.log('predict_step', loss,on_step=True, on_epoch=True, prog_bar=True, logger=True)
return loss

def _common_step(self, batch, batch_idx):

y_hat = self.forward(batch)

reconstructed_cube = self.forward(batch)
hyperspectral_cube, wavelengths = batch
#hyperspectral_cube = hyperspectral_cube.permute(0, 3, 2, 1)
hyperspectral_cube = hyperspectral_cube[:,:, self.crop_value_up:-self.crop_value_down, self.crop_value_left:-self.crop_value_right]

fig, ax = plt.subplots(1, 2)
plt.title(f"true cube vs reconstructed cube")
ax[0].imshow(hyperspectral_cube[0, 0, :, :].cpu().detach().numpy())
ax[1].imshow(y_hat[0, 0, :, :].cpu().detach().numpy())
plt.show()
hyperspectral_cube = hyperspectral_cube.permute(0, 2, 3, 1).to(self.device)
reconstructed_cube = reconstructed_cube.permute(0, 2, 3, 1).to(self.device)
ref_cube = match_dataset_to_instrument(hyperspectral_cube, reconstructed_cube[0, :, :,0])

# fig, ax = plt.subplots(1, 2)
# plt.title(f"true cube vs reconstructed cube")
# ax[0].imshow(hyperspectral_cube[0, :, :, 0].cpu().detach().numpy())
# ax[1].imshow(reconstructed_cube[0, :, :, 0].cpu().detach().numpy())
# plt.show()

#print("y_hat shape", y_hat.shape)
#print("hyperspectral_cube shape", hyperspectral_cube.shape)

loss = torch.sqrt(self.loss_fn(y_hat, hyperspectral_cube))
loss = torch.sqrt(self.loss_fn(reconstructed_cube, ref_cube))

return loss, y_hat, hyperspectral_cube
return loss,reconstructed_cube, ref_cube

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=4e-4)
return optimizer

def shift_back(self, inputs, d): # input [bs,256,310], [bs, 28] output [bs, 28, 256, 256]
[bs, row, col] = inputs.shape
nC = 28
d = d[0]
d -= d.min()
self.crop_value_right = 8+int(np.round(d.max()))
output = torch.zeros(bs, nC, row, col - int(np.round(d.max()))).float().to(self.device)
for i in range(nC):
shift = int(np.round(d[i]))
#output[:, i, :, :] = inputs[:, :, step * i:step * i + col - 27 * step] step = 2
# if shift >=0:
# output[:, i, :, :] = inputs[:, :, shift:row+shift]
# else:
# output[:, i, :, :] = inputs[:, :, shift-row:shift]
output[:, i, :, :] = inputs[:, :, shift:shift + col - int(np.round(d.max()))]
return output

def _log_images(self, tag, images, global_step):
# Convert model output to image grid and log to TensorBoard
img_grid = torchvision.utils.make_grid(images)
self.writer.add_image(tag, img_grid, global_step)

def _convert_output_to_images(self, acquired_images):

acquired_images = acquired_images.unsqueeze(1)

# Create a grid of images for visualization
img_grid = torchvision.utils.make_grid(acquired_images)
return img_grid

def plot_spectral_filter(self,ref_hyperspectral_cube,recontructed_hyperspectral_cube):


batch_size, y,x, lmabda_ = ref_hyperspectral_cube.shape

# Create a figure with subplots arranged horizontally
fig, axs = plt.subplots(1, batch_size, figsize=(batch_size * 5, 4)) # Adjust figure size as needed

# Check if batch_size is 1, axs might not be iterable
if batch_size == 1:
axs = [axs]

# Plot each spectral filter in its own subplot
for i in range(batch_size):
colors = ['b', 'g', 'r']
for j in range(3):
pix_j_row_value = np.random.randint(0,y)
pix_j_col_value = np.random.randint(0,x)

pix_j_ref = ref_hyperspectral_cube[i, pix_j_row_value,pix_j_col_value,:].cpu().detach().numpy()
pixe_j_reconstructed = recontructed_hyperspectral_cube[i, pix_j_row_value,pix_j_col_value,:].cpu().detach().numpy()
axs[i].plot(pixe_j_reconstructed, label="pix reconstructed" + str(j),c=colors[j])
axs[i].plot(pix_j_ref, label="pix" + str(j), linestyle='--',c=colors[j])

axs[i].set_title(f"Reconstruction quality")

axs[i].set_xlabel("Wavelength index")
axs[i].set_ylabel("pxie values")
axs[i].grid(True)

plt.legend()
# Adjust layout
plt.tight_layout()

# Create a buffer to save plot
buf = io.BytesIO()
plt.savefig(buf, format='png')
plt.close(fig)
buf.seek(0)

# Convert PNG buffer to PIL Image
image = Image.open(buf)

# Convert PIL Image to Tensor
image_tensor = transforms.ToTensor()(image)
return image_tensor


def subsample(input, origin_sampling, target_sampling):
[bs, row, col, nC] = input.shape
Expand Down
25 changes: 6 additions & 19 deletions simca/CassiSystem_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def generate_filtering_cube(self):
numpy.ndarray: filtering cube generated according to the optical system & the pattern configuration (R x C x W)
"""
self.filtering_cube = interpolate_data_on_grid_positions_torch(data=self.pattern,
self.filtering_cube = interpolate_data_on_grid_positions_torch(data=self.pattern.unsqueeze(-1).repeat(1, 1, 1, self.wavelengths.shape[0]),
X_init=self.X_coordinates_propagated_coded_aperture,
Y_init=self.Y_coordinates_propagated_coded_aperture,
X_target=self.X_detector_coordinates_grid,
Expand Down Expand Up @@ -193,24 +193,8 @@ def image_acquisition(self, hyperspectral_cube, pattern,wavelengths,use_psf=Fals
"""
self.wavelengths= self.wavelengths.to(self.device)



#
# plt.plot(hyperspectral_cube[0,0,0,:].cpu().numpy())
# plt.title("Original spectrum")
# plt.show()
# print("cube shape: ", hyperspectral_cube.shape)
# print("wavelengths shape: ", wavelengths.shape)
# print("self.wavelengths shape: ", self.wavelengths.shape)
dataset = self.interpolate_dataset_along_wavelengths_torch(hyperspectral_cube, wavelengths,self.wavelengths, chunck_size)


# plt.plot(dataset[0,0,0,:].cpu().numpy())
# plt.title("Interpolated spectrum")
# plt.show()



if dataset is None:
return None

Expand All @@ -226,8 +210,10 @@ def image_acquisition(self, hyperspectral_cube, pattern,wavelengths,use_psf=Fals
except:
return print("Please generate filtering cube first")

scene = match_dataset_to_instrument(dataset, self.filtering_cube)
scene = torch.from_numpy(match_dataset_to_instrument(dataset, self.filtering_cube)).to(self.device) if isinstance(scene, np.ndarray) else scene.to(self.device)

scene = match_dataset_to_instrument(dataset, self.filtering_cube[0,:,:,0])

# scene = torch.from_numpy(match_dataset_to_instrument(dataset, self.filtering_cube)).to(self.device) if isinstance(scene, np.ndarray) else scene.to(self.device)

measurement_in_3D = generate_dd_measurement_torch(scene, self.filtering_cube, chunck_size)

Expand All @@ -250,6 +236,7 @@ def image_acquisition(self, hyperspectral_cube, pattern,wavelengths,use_psf=Fals
# print("dataset shape: ", dataset.shape)
# print("X coded shape: ", X_coded_aper_coordinates_crop.shape)


scene = match_dataset_to_instrument(dataset, X_coded_aper_coordinates_crop)

pattern_crop = crop_center_3D(pattern, scene.shape[2], scene.shape[1]).to(self.device)
Expand Down
Loading

0 comments on commit 633b1d8

Please sign in to comment.