Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ML4Science 2023: Trainable Masks #110

Open
wants to merge 132 commits into
base: main
Choose a base branch
from
Open
Changes from 6 commits
Commits
Show all changes
132 commits
Select commit Hold shift + click to select a range
72d0f58
add psf.tiff file
aelalamy42 Nov 17, 2023
96e91b4
Merge branch 'LCAV:main' into main
aelalamy42 Nov 18, 2023
f639321
Merge branch 'main' of github.com:aelalamy42/LenslessPiCam_ML4Science…
aelalamy42 Nov 20, 2023
e68531c
First test for using torch
Ghita2002 Nov 21, 2023
521bba1
Merge branch 'main' of https://github.com/aelalamy42/LenslessPiCam_ML…
Ghita2002 Nov 21, 2023
e0d76bd
Merge branch 'main' of https://github.com/aelalamy42/LenslessPiCam_ML…
Nov 21, 2023
c2646a5
create_mask into torch
Ghita2002 Nov 21, 2023
331e327
phase_retrieval in torch
aelalamy42 Nov 21, 2023
c229522
Merge branch 'Work_Ghita' of https://github.com/aelalamy42/LenslessPi…
Ghita2002 Nov 21, 2023
5e12425
added the height varying mask
Seif-Wessam Nov 23, 2023
356d6ba
MultiLensArray initialization
Ghita2002 Nov 25, 2023
388daf1
Merge branch 'Work_Ghita' of https://github.com/aelalamy42/LenslessPi…
Ghita2002 Nov 25, 2023
515fd11
Adding comments to mask.py
Ghita2002 Nov 25, 2023
0f2d9d1
first half of the MLA
aelalamy42 Nov 25, 2023
eeb9399
implement overlapping check
aelalamy42 Nov 25, 2023
9ce0426
Multilens kind of done and tested
aelalamy42 Nov 26, 2023
b8151a9
Merge pull request #1 from aelalamy42/Work_Ahmed
aelalamy42 Nov 27, 2023
12c63f2
height varying tested
aelalamy42 Nov 27, 2023
7b2a1a4
Update mask.py
aelalamy42 Nov 27, 2023
1dc82fc
Multilens patched
aelalamy42 Nov 27, 2023
2d4d5fc
Merge pull request #2 from aelalamy42/aelalamy42-patch-1
aelalamy42 Nov 27, 2023
7b85737
Merge branch 'main' into travail_seif
aelalamy42 Nov 27, 2023
160c802
Merge pull request #3 from aelalamy42/travail_seif
aelalamy42 Nov 27, 2023
0951bd9
added the torch implementation (as a 2nd option) to the HeightVarying…
Seif-Wessam Nov 27, 2023
228ca0d
set is_torch to TRUE for the height_varying mask
Seif-Wessam Nov 28, 2023
af1818c
added the test for the torch case in test_masks.py
Seif-Wessam Nov 28, 2023
d882653
Merge branch 'main' into Work_Ghita
Ghita2002 Nov 29, 2023
2bd322c
from numpy to torch
Ghita2002 Nov 29, 2023
51e6966
edits to the pytorch implementation
Seif-Wessam Nov 30, 2023
9254864
added the device to HeightVarying
Seif-Wessam Nov 30, 2023
2a92bbb
added the device
Seif-Wessam Nov 30, 2023
7209fee
fix phase of multilens
aelalamy42 Nov 30, 2023
613dc07
added is_Torch to the constructor and to compute_psf
Seif-Wessam Nov 30, 2023
af26708
changed the compute_psf function to include is_torch
Seif-Wessam Nov 30, 2023
def924c
some tests
Ghita2002 Nov 30, 2023
d2684bc
fixed the heightVarying with torch
Seif-Wessam Nov 30, 2023
26acdd1
changed the testing code to have is_Torch=True
Seif-Wessam Nov 30, 2023
de51f5c
corrected the code for testing the torch part of HeightVaryingMask
Seif-Wessam Dec 1, 2023
8c4b49b
improvement to the test.
Seif-Wessam Dec 2, 2023
2cfaa7f
fix phase of multilens
aelalamy42 Dec 4, 2023
fffdc77
merge main
aelalamy42 Dec 4, 2023
394201b
torch implementation of height varying
aelalamy42 Dec 4, 2023
b5aa4c0
Merge pull request #4 from aelalamy42/travail_seif
aelalamy42 Dec 4, 2023
e0dfb26
merge main
aelalamy42 Dec 4, 2023
0796d58
merge main
aelalamy42 Dec 4, 2023
e174292
merge main
aelalamy42 Dec 4, 2023
9282404
not working for now
aelalamy42 Dec 4, 2023
6ee4be4
add comments
aelalamy42 Dec 4, 2023
3780cd1
finish the torch implementation of MLA
aelalamy42 Dec 5, 2023
5717b7c
Merge branch 'main' of github.com:aelalamy42/LenslessPiCam_ML4Science…
aelalamy42 Dec 5, 2023
f2c3c29
Start interface for trainable coded aperture.
ebezzam Dec 6, 2023
cd76a3c
Update trainable mask interface.
ebezzam Dec 6, 2023
21f0d08
beginning of multilens in trainable
aelalamy42 Dec 6, 2023
bc15721
Merge pull request #5 from aelalamy42/Work_Ghita
aelalamy42 Dec 6, 2023
73bb229
Improve trainable mask API.
ebezzam Dec 6, 2023
5e1f8cc
Fix MURA.
ebezzam Dec 6, 2023
2ceca9d
merge main
aelalamy42 Dec 7, 2023
ee71705
Merge branch 'main' of github.com:aelalamy42/LenslessPiCam_ML4Science…
aelalamy42 Dec 7, 2023
922c866
matching names
aelalamy42 Dec 7, 2023
6ae7f21
Merge pull request #7 from aelalamy42/LCAV-trainable_amplitude_mask
aelalamy42 Dec 7, 2023
177a6e9
beginning of trainable multilensArray
aelalamy42 Dec 7, 2023
218ade8
project fonction, TrainableMultiLensArray constru
Ghita2002 Dec 7, 2023
a7847b8
update project
aelalamy42 Dec 7, 2023
ab3f233
Merge pull request #8 from aelalamy42/Work_Ghita
aelalamy42 Dec 7, 2023
3ed5fee
added TrainableHeightVarying in lensless/hardware/trainable_mask.py
Seif-Wessam Dec 7, 2023
6b66d4f
Merge branch 'main' of github.com:aelalamy42/LenslessPiCam_ML4Science…
aelalamy42 Dec 10, 2023
89e73de
train_multilens_array.yaml
aelalamy42 Dec 10, 2023
72636a9
Merge branch 'main' of github.com:aelalamy42/LenslessPiCam_ML4Science…
aelalamy42 Dec 10, 2023
54e0ab2
cleaner code
aelalamy42 Dec 10, 2023
37b4b11
Merge pull request #9 from aelalamy42/Work_Ghita
aelalamy42 Dec 10, 2023
b13a6f6
minor typo
aelalamy42 Dec 10, 2023
5292ea8
minor change in TrainableHeightVarying
Seif-Wessam Dec 10, 2023
85518e5
minor typo
aelalamy42 Dec 10, 2023
8adea38
changes to TrainableHeightVarying
Seif-Wessam Dec 10, 2023
100777c
update mask too
aelalamy42 Dec 10, 2023
f0b631a
debug
aelalamy42 Dec 10, 2023
b73f406
debug
aelalamy42 Dec 10, 2023
ffa42fd
debug
aelalamy42 Dec 10, 2023
1eb99d8
debug
aelalamy42 Dec 10, 2023
aa3d71d
change to TrainableHeightVaryingMask
Seif-Wessam Dec 10, 2023
f5df5a7
limit number of files
aelalamy42 Dec 10, 2023
529e3ad
typo fix
Seif-Wessam Dec 11, 2023
4d78a55
little changes
aelalamy42 Dec 11, 2023
26bc878
Merge branch 'main' of github.com:aelalamy42/LenslessPiCam_ML4Science…
aelalamy42 Dec 11, 2023
177de56
changesé
aelalamy42 Dec 11, 2023
52eab65
Merge pull request #10 from aelalamy42/travail_seif
aelalamy42 Dec 11, 2023
603ce44
add yaml file for height varying
aelalamy42 Dec 11, 2023
61651cf
Fix coded aperture training (fashion mnist).
ebezzam Dec 13, 2023
db6244d
few changes
aelalamy42 Dec 14, 2023
4444859
Merge pull request #11 from LCAV/trainable_amplitude_mask
aelalamy42 Dec 14, 2023
a374156
Set coded aperture optimization to grayscale.
ebezzam Dec 15, 2023
3d7c3b8
Merge pull request #12 from LCAV/trainable_amplitude_mask
aelalamy42 Dec 15, 2023
3751448
minor changes
aelalamy42 Dec 15, 2023
6205c5c
merge main
aelalamy42 Dec 15, 2023
359218d
push new changes
aelalamy42 Dec 15, 2023
ce8c504
torch.no_grad and other changes
aelalamy42 Dec 15, 2023
21323dd
new optimization of multilens heightmap computation
aelalamy42 Dec 16, 2023
54535c9
trying on google colab
aelalamy42 Dec 16, 2023
f475fa2
trying on google colab
aelalamy42 Dec 16, 2023
ebf159c
commit back
aelalamy42 Dec 16, 2023
fc18147
check that on GPU it works
aelalamy42 Dec 18, 2023
6fb1ef3
small changes
aelalamy42 Dec 18, 2023
e850991
small debug
aelalamy42 Dec 18, 2023
1932e35
small debug
aelalamy42 Dec 18, 2023
29f0126
Correctly set torch device.
ebezzam Dec 18, 2023
d2cc322
Move prep trainable mask into package.
ebezzam Dec 18, 2023
321154f
Set wavelength and optimizer param through config.
ebezzam Dec 18, 2023
28356c3
Merge branch 'trainable_amplitude_mask' of github.com:LCAV/LenslessPi…
aelalamy42 Dec 18, 2023
73befd0
Merge
aelalamy42 Dec 18, 2023
10c2179
changes
aelalamy42 Dec 18, 2023
951c1ae
changes
aelalamy42 Dec 18, 2023
8bb09f4
changes
aelalamy42 Dec 18, 2023
cf6be5c
changes
aelalamy42 Dec 18, 2023
5896496
changes
aelalamy42 Dec 18, 2023
953bcdc
changes
aelalamy42 Dec 18, 2023
807dff3
changes
aelalamy42 Dec 18, 2023
99b00d0
changes
aelalamy42 Dec 18, 2023
9be48e7
changes
aelalamy42 Dec 18, 2023
dd3cc2c
changes
aelalamy42 Dec 18, 2023
a5607eb
changes
aelalamy42 Dec 18, 2023
f256f5c
changes
aelalamy42 Dec 18, 2023
b0a3f64
changes
aelalamy42 Dec 18, 2023
ba8975c
changes
aelalamy42 Dec 18, 2023
db5e7a9
changes
aelalamy42 Dec 18, 2023
20fefa2
changes
aelalamy42 Dec 18, 2023
d155b70
changes
aelalamy42 Dec 18, 2023
f4f13d6
changes
aelalamy42 Dec 18, 2023
8c35575
changes
aelalamy42 Dec 18, 2023
9bdf313
changes
aelalamy42 Dec 18, 2023
c242ebd
debugging session
aelalamy42 Dec 19, 2023
b28802e
Update README.rst
aelalamy42 Dec 20, 2023
63485c0
Merge pull request #14 from aelalamy42/aelalamy42-patch-2
aelalamy42 Dec 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added data/psf.tiff
Binary file not shown.
122 changes: 122 additions & 0 deletions lensless/hardware/mask.py
Original file line number Diff line number Diff line change
@@ -32,6 +32,7 @@
from waveprop.noise import add_shot_noise
from lensless.hardware.sensor import VirtualSensor
from lensless.utils.image import resize
from matplotlib import pyplot as plt

try:
import torch
@@ -321,6 +322,125 @@ def simulate(self, obj, snr_db=20):

return meas

class MultiLensArray(Mask):
"""
Multi-lens array mask.
"""
def __init__(
self, N = None, radius = None, loc = None, refractive_index = 1.2, design_wv=532e-9, seed = 0, min_height=1e-3, **kwargs
):
"""
Multi-lens array mask constructor.

Parameters
----------
N: int
Number of lenses
radius: array_like
Radius of the lenses (m)
loc: array_like of tuples
Location of the lenses (m)
refractive_index: float
Refractive index of the mask substrate. Default is 1.2.
wavelength: float
seed: int
Seed for the random number generator. Default is 0.
min_height: float
Minimum height of the lenses (m). Default is 1e-3.
"""
self.N = N
self.radius = radius
self.loc = loc
self.refractive_index = refractive_index
self.wavelength = design_wv
self.seed = seed
self.min_height = min_height

if self.radius is not None:
assert self.loc is not None
assert len(self.radius) == len(self.loc)
self.N = len(self.radius)
circles = np.array([(self.loc[i][0], self.loc[i][1], self.radius[i]) for i in range(self.N)])
assert MultiLensArray.no_circle_overlap(circles)
else:
assert self.N is not None
np.random.seed(self.seed)
self.radius = np.random.uniform(self.min_height, 20, self.N) #TODO: check if it is the right way to do it
assert self.N == len(self.radius)
super().__init__(**kwargs)



@staticmethod
def no_circle_overlap(circles):
"""Check if any circle in the list overlaps with another."""
for i in range(len(circles)):
if MultiLensArray.does_circle_overlap(circles[i+1:], circles[i][0], circles[i][1], circles[i][2]):
return False
return True

@staticmethod
def does_circle_overlap(circles, x, y, r):
"""Check if a circle overlaps with any in the list."""
for (cx, cy, cr) in circles:
if np.sqrt((x - cx)**2 + (y - cy)**2) < r/2 + cr/2:
return True
return False

@staticmethod
def place_spheres_on_plane(width, height, radius, max_attempts=1000):
"""Try to place circles on a 2D plane."""
placed_circles = []
radius_sorted = sorted(radius, reverse=True) # Place larger circles first

for r in radius_sorted:
placed = False
for _ in range(max_attempts):
x = np.random.uniform(r, width - r)
y = np.random.uniform(r, height - r)

if not MultiLensArray.does_circle_overlap(placed_circles, x, y, r):
placed_circles.append((x, y, r))
placed = True
print(f"Placed circle with rad {r}, and center ({x}, {y})")
break

if not placed:
print(f"Failed to place circle with rad {r}")
continue

placed_circles = np.array(placed_circles)
circles = placed_circles[:, :2]
radius = placed_circles[:, 2]
return circles, radius

def create_mask(self):
self.loc, self.radius = MultiLensArray.place_spheres_on_plane(self.resolution[0], self.resolution[1], self.radius)
height = self.create_height_map(self.radius, self.loc)
phi = (height * (self.refractive_index - 1) * 2 * np.pi / self.wavelength) #% (2 * np.pi) ? Makes it have some noisy values instead of a continuous sphere
fig, ax = plt.subplots()
im = ax.imshow(phi, cmap="gray")
fig.colorbar(im, ax=ax, shrink=0.5, aspect=5)
plt.show()
self.mask = np.exp(1j * phi)
print(np.angle(self.mask[51, 321]), " ", phi[51, 321] - 2*np.pi)

def create_height_map(self, radius, locs):
height = np.full((self.resolution[0], self.resolution[1]), self.min_height)
for x in range(height.shape[0]):
for y in range(height.shape[1]):
height[x, y] += self.lens_contribution(radius, locs, x + 0.5, y + 0.5)
assert np.all(height >= self.min_height)
return height

def lens_contribution(self, radius, locs, x, y):
contribution = 0
for idx, loc in enumerate(locs):
if (x-loc[0])**2 + (y-loc[1])**2 < radius[idx]**2:
contribution = np.sqrt(radius[idx]**2 - (x-loc[0])**2 - (y-loc[1])**2)
return contribution
return contribution


class PhaseContour(Mask):
"""
@@ -354,6 +474,7 @@ def __init__(
self.refractive_index = refractive_index
self.n_iter = n_iter
self.design_wv = design_wv


super().__init__(**kwargs)

@@ -410,6 +531,7 @@ def phase_retrieval(target_psf, wv, d1, dz, n=1.2, n_iter=10, height_map=False):
n_iter: int
Number of iterations. Default value is 10.
"""

M_p = np.sqrt(target_psf)

if hasattr(d1, "__len__"):
17 changes: 15 additions & 2 deletions test/test_masks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
from lensless.hardware.mask import CodedAperture, PhaseContour, FresnelZoneAperture
from lensless.hardware.mask import CodedAperture, PhaseContour, FresnelZoneAperture, MultiLensArray
from lensless.eval.metric import mse, psnr, ssim
from waveprop.fresnel import fresnel_conv
from matplotlib import pyplot as plt


resolution = np.array([380, 507])
@@ -75,7 +76,6 @@ def test_classmethod():
assert np.all(mask1.mask.shape == resolution)
desired_psf_shape = np.array(tuple(resolution) + (len(mask1.psf_wavelength),))
assert np.all(mask1.psf.shape == desired_psf_shape)

mask2 = PhaseContour.from_sensor(
sensor_name="rpi_hq", downsample=downsample, distance_sensor=dz
)
@@ -90,6 +90,19 @@ def test_classmethod():
desired_psf_shape = np.array(tuple(resolution) + (len(mask3.psf_wavelength),))
assert np.all(mask3.psf.shape == desired_psf_shape)

mask4 = MultiLensArray.from_sensor(
sensor_name="rpi_hq", downsample=downsample, distance_sensor=dz, N=10#radius=np.array([10, 25]), loc=np.array([[10.1, 11.3], [56.5, 89.2]])
)
assert np.all(mask4.mask.shape == resolution)
desired_psf_shape = np.array(tuple(resolution) + (len(mask4.psf_wavelength),))
assert np.all(mask3.psf.shape == desired_psf_shape)

z = np.abs(np.angle(mask4.mask))
assert np.all(z > 0)
fig, ax = plt.subplots()
#im = ax.imshow(z, cmap="gray")
#fig.colorbar(im, ax=ax, shrink=0.5, aspect=5)
#plt.show()

if __name__ == "__main__":
test_flatcam()