Skip to content

Commit

Permalink
update functions for working resnet optim
Browse files Browse the repository at this point in the history
  • Loading branch information
arouxel-laas committed Mar 9, 2024
1 parent 71cf609 commit 660c7ad
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
3 changes: 2 additions & 1 deletion simca/CassiSystem_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ def generate_filtering_cube(self):
numpy.ndarray: filtering cube generated according to the optical system & the pattern configuration (R x C x W)
"""
self.filtering_cube = interpolate_data_on_grid_positions_torch(data=self.pattern.unsqueeze(-1).repeat(1, 1, 1, self.wavelengths.shape[0]),

self.filtering_cube = interpolate_data_on_grid_positions_torch(data=self.pattern.unsqueeze(-1).repeat(1, 1, 1, self.wavelengths.shape[0]).to(self.device),
X_init=self.X_coordinates_propagated_coded_aperture,
Y_init=self.Y_coordinates_propagated_coded_aperture,
X_target=self.X_detector_coordinates_grid,
Expand Down
5 changes: 5 additions & 0 deletions simca/functions_acquisition_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def interpolate_data_on_grid_positions_torch(data, X_init, Y_init, X_target, Y_t

#print(data.shape)


# Ensure tensors are on the correct device and data type
device = data.device
dtype = torch.float64 # Using double precision for interpolation calculations

# Ensure tensors are on the correct device and data type
device = data.device
dtype = torch.float64 # Using double precision for interpolation calculations
Expand Down

0 comments on commit 660c7ad

Please sign in to comment.