
Commit 310f0b6

Merge pull request #79 from kc-ml2/DEV/main
Dev/main
2 parents 1affa35 + cb9a618 commit 310f0b6


60 files changed: +4244 −3620 lines

QA/1D_grating_in_2D_pattern.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

QA/1d_pattern_in_1dc_and_2d.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
# This demo shows a case with a 1D grating and TM polarization.
# If phi is set to None, the 1D TE/TM formulation is used (no azimuthal rotation, i.e. phi == 0).
# If phi is set to 0, the general 1D-conical / 2D formulation is used instead, which is slower.


import numpy as np
from time import time

from meent import call_mee


def compare():
    backend = 0
    pol = 1  # 0: TE, 1: TM

    n_top = 1  # n_incidence
    n_bot = 1  # n_transmission

    theta = 1E-10  # angle of incidence in radians

    wavelength = 300  # wavelength
    thickness = [460, 22]
    period = [700, 700]
    fto = [100, 0]

    ucell_1d = np.array([
        [
            [1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
        ],
        [
            [1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
        ],
    ])
    ucell_2d = np.array([
        [
            [1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
            [1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
        ],
        [
            [1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
            [1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
        ],
    ])

    mee = call_mee(backend=backend, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, fto=fto,
                   wavelength=wavelength, period=period, thickness=thickness)

    # 1D
    mee.phi = None  # None is the default
    mee.ucell = ucell_1d

    t0_1d = time()
    res = mee.conv_solve().res
    t1_1d = time()
    de_ri1, de_ti1 = res.de_ri, res.de_ti
    print('1D (de_ri, de_ti): ', de_ri1, de_ti1)

    # 1D conical
    mee.phi = 0
    t0_1dc = time()
    res = mee.conv_solve().res
    t1_1dc = time()
    de_ri1c, de_ti1c = res.de_ri, res.de_ti
    print('1Dc (de_ri, de_ti): ', de_ri1c, de_ti1c)

    # 2D
    mee.phi = 0
    t0_2d = time()
    mee.ucell = ucell_2d
    res = mee.conv_solve().res
    t1_2d = time()
    de_ri2, de_ti2 = res.de_ri, res.de_ti
    print('2D (de_ri, de_ti): ', de_ri2, de_ti2)

    print('time for 1D formulation: ', t1_1d - t0_1d, 's')
    print('time for 1Dc formulation: ', t1_1dc - t0_1dc, 's')
    print('time for 2D formulation: ', t1_2d - t0_2d, 's')
    print('Simulation difference between 1D and 1Dc formulation: ',
          np.linalg.norm(de_ri1 - de_ri1c), np.linalg.norm(de_ti1 - de_ti1c))
    print('Simulation difference between 1D and 2D formulation: ',
          np.linalg.norm(de_ri1 - de_ri2), np.linalg.norm(de_ti1 - de_ti2))
    print('Simulation difference between 1Dc and 2D formulation: ',
          np.linalg.norm(de_ri1c - de_ri2), np.linalg.norm(de_ti1c - de_ti2))


if __name__ == '__main__':
    compare()
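
The switch exercised by this demo boils down to the phi attribute. The sketch below isolates it, using only the attributes and calls that appear in the script above (mee is the call_mee(...) instance and ucell_1d the 1D unit cell); it is illustrative, not part of the commit:

# Minimal sketch of the formulation switch demonstrated above (illustrative, not part of the commit).
mee.ucell = ucell_1d
mee.phi = None                 # 1D TE/TM formulation: no azimuthal rotation (phi == 0), fastest path
res_1d = mee.conv_solve().res  # res_1d.de_ri / res_1d.de_ti hold the diffraction efficiencies

mee.phi = 0                    # same geometry, solved with the general 1D-conical / 2D formulation (slower)
res_1dc = mee.conv_solve().res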

QA/autograd_complex_ucell.py renamed to QA/autodiff_raster1.py

Lines changed: 17 additions & 6 deletions
@@ -8,7 +8,7 @@
 import torch
 
 import meent
-from meent.on_torch.optimizer.loss import LossDeflector
+
 
 type_complex = 0
 device = 0
@@ -48,7 +48,19 @@
 
 pois = ['ucell', 'thickness']  # Parameter Of Interests
 forward = jmee.conv_solve
-loss_fn = LossDeflector(x_order=0, y_order=0)
+
+
+class Loss:
+    def __call__(self, meent_result, *args, **kwargs):
+        res_psi, res_te, res_ti = meent_result.res, meent_result.res_te_inc, meent_result.res_tm_inc
+        de_ti = res_psi.de_ti
+        center = [a // 2 for a in de_ti.shape]
+        res = de_ti[center[0], center[1]+1]
+
+        return res
+
+
+loss_fn = Loss()
 
 # case 1: Gradient
 grad_j = jmee.grad(pois, forward, loss_fn)
@@ -58,7 +70,7 @@
 print('thickness gradient:')
 print(grad_j['thickness'])
 
-optimizer = optax.sgd(learning_rate=1e-2)
+optimizer = optax.sgd(learning_rate=1E2)
 t0 = time.time()
 res_j = jmee.fit(pois, forward, loss_fn, optimizer, iteration=iteration)
 print('Time JAX', time.time() - t0)
@@ -74,7 +86,6 @@
 thickness=thickness, type_complex=type_complex, device=device)
 
 forward = tmee.conv_solve
-loss_fn = LossDeflector(x_order=0)  # predefined in meent
 
 grad_t = tmee.grad(pois, forward, loss_fn)
 print('ucell gradient:')
@@ -83,7 +94,7 @@
 print(grad_t['thickness'])
 
 opt_torch = torch.optim.SGD
-opt_options = {'lr': 1E-2}
+opt_options = {'lr': 1E2}
 
 t0 = time.time()
 res_t = tmee.fit(pois, forward, loss_fn, opt_torch, opt_options, iteration=iteration)
@@ -102,6 +113,6 @@
 
 print('End')
 
-# Note that the gradient in JAX is conjugated.
+# Note that the gradient in JAX is the conjugate of PyTorch's.
 # https://github.com/google/jax/issues/4891
 # https://pytorch.org/docs/stable/notes/autograd.html#autograd-for-complex-numbers
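
The comment changed in the last hunk refers to a real convention difference between the two frameworks. A self-contained toy check, independent of meent and not part of this commit (the |z|^2 loss and variable names are illustrative only):

# Toy check of the note above: for a real-valued loss of a complex parameter,
# the gradient from jax.grad is expected to be the complex conjugate of torch's .grad.
import jax
import jax.numpy as jnp
import torch

z0 = 1.5 + 0.5j

g_jax = jax.grad(lambda z: jnp.abs(z) ** 2)(jnp.array(z0))   # real output, complex input

z_t = torch.tensor(z0, requires_grad=True)
(z_t.abs() ** 2).backward()
g_torch = z_t.grad

print(g_jax, g_torch)   # expected: complex conjugates of each other (see the two links above)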

QA/autodiff_raster2.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
import jax
import torch

import jax.numpy as jnp
import numpy as np

from time import time

from meent import call_mee


def load_setting():
    pol = 1  # 0: TE, 1: TM

    n_top = 1  # n_incidence
    n_bot = 1  # n_transmission

    theta = 0 * np.pi / 180
    phi = 0 * np.pi / 180

    wavelength = 900

    fto = [5, 5]

    period = [1000, 1000]
    thickness = [1120]

    ucell = np.array([[[2.58941352 + 0.47745679j, 4.17771602 + 0.88991205j,
                        2.04255624 + 2.23670125j, 2.50478974 + 2.05242759j,
                        3.32747593 + 2.3854387j],
                       [2.80118605 + 0.53053715j, 4.46498861 + 0.10812571j,
                        3.99377545 + 1.0441131j, 3.10728537 + 0.6637353j,
                        4.74697849 + 0.62841253j],
                       [3.80944424 + 2.25899274j, 3.70371553 + 1.32586402j,
                        3.8011133 + 1.49939415j, 3.14797238 + 2.91158289j,
                        4.3085404 + 2.44344691j],
                       [2.22510179 + 2.86017146j, 2.36613053 + 2.82270351j,
                        4.5087168 + 0.2035904j, 3.15559949 + 2.55311298j,
                        4.29394604 + 0.98362617j],
                       [3.31324163 + 2.77590131j, 2.11744834 + 1.65894674j,
                        3.59347907 + 1.28895345j, 3.85713467 + 1.90714056j,
                        2.93805426 + 2.63385392j]]])
    ucell = ucell.real

    type_complex = 0
    device = 0

    setting = {'pol': pol, 'n_top': n_top, 'n_bot': n_bot, 'theta': theta, 'phi': phi, 'fto': fto,
               'wavelength': wavelength, 'period': period, 'ucell': ucell, 'thickness': thickness, 'device': device,
               'type_complex': type_complex}

    return setting


def optimize_jax(setting):
    ucell = setting['ucell']

    mee = call_mee(backend=1, **setting)

    @jax.jit
    def grad_loss(ucell):
        mee.ucell = ucell
        res = mee.conv_solve().res
        de_ri, de_ti = res.de_ri, res.de_ti

        loss = de_ti[de_ti.shape[0] // 2, de_ti.shape[1] // 2]

        return loss

    def grad_numerical(ucell, delta):
        grad_arr = jnp.zeros(ucell.shape, dtype=ucell.dtype)

        @jax.jit
        def compute(ucell):
            mee.ucell = ucell
            result = mee.conv_solve()
            de_ti = result.res.de_ti
            loss = de_ti[de_ti.shape[0] // 2, de_ti.shape[1] // 2]

            return loss

        for layer in range(ucell.shape[0]):
            for r in range(ucell.shape[1]):
                for c in range(ucell.shape[2]):
                    ucell_delta_m = ucell.copy()
                    ucell_delta_m[layer, r, c] -= delta
                    mee.ucell = ucell_delta_m
                    de_ti_delta_m = compute(ucell_delta_m)

                    ucell_delta_p = ucell.copy()
                    ucell_delta_p[layer, r, c] += delta
                    mee.ucell = ucell_delta_p
                    de_ti_delta_p = compute(ucell_delta_p)

                    grad_numeric = (de_ti_delta_p - de_ti_delta_m) / (2 * delta)
                    grad_arr = grad_arr.at[layer, r, c].set(grad_numeric)

        return grad_arr

    jax.grad(grad_loss)(ucell)  # Dry run for jit compilation. This is to make the time comparison fair.
    t0 = time()
    grad_ad = jax.grad(grad_loss)(ucell)
    t_ad = time() - t0
    print('JAX grad_ad:\n', grad_ad)
    t0 = time()
    grad_nume = grad_numerical(ucell, 1E-6)
    t_nume = time() - t0
    print('JAX grad_numeric:\n', grad_nume)
    print('JAX norm of difference: ', jnp.linalg.norm(grad_nume - grad_ad) / grad_nume.size)
    return t_ad, t_nume


def optimize_torch(setting):
    mee = call_mee(backend=2, **setting)

    mee.ucell.requires_grad = True

    t0 = time()
    res = mee.conv_solve().res
    de_ri, de_ti = res.de_ri, res.de_ti

    loss = de_ti[de_ti.shape[0] // 2, de_ti.shape[1] // 2]

    loss.backward()
    grad_ad = mee.ucell.grad
    t_ad = time() - t0

    def grad_numerical(ucell, delta):
        ucell.requires_grad = False
        grad_arr = torch.zeros(ucell.shape, dtype=ucell.dtype)

        for layer in range(ucell.shape[0]):
            for r in range(ucell.shape[1]):
                for c in range(ucell.shape[2]):
                    ucell_delta_m = ucell.clone().detach()
                    ucell_delta_m[layer, r, c] -= delta
                    mee.ucell = ucell_delta_m
                    res = mee.conv_solve().res
                    de_ri_delta_m, de_ti_delta_m = res.de_ri, res.de_ti

                    ucell_delta_p = ucell.clone().detach()
                    ucell_delta_p[layer, r, c] += delta
                    mee.ucell = ucell_delta_p
                    res = mee.conv_solve().res
                    de_ri_delta_p, de_ti_delta_p = res.de_ri, res.de_ti

                    cy, cx = np.array(de_ti_delta_p.shape) // 2
                    grad_numeric = (de_ti_delta_p[cy, cx] - de_ti_delta_m[cy, cx]) / (2 * delta)
                    grad_arr[layer, r, c] = grad_numeric

        return grad_arr

    t0 = time()
    grad_nume = grad_numerical(mee.ucell, 1E-6)
    t_nume = time() - t0

    print('Torch grad_ad:\n', grad_ad)
    print('Torch grad_numeric:\n', grad_nume)
    print('torch.norm: ', torch.linalg.norm(grad_nume - grad_ad) / grad_nume.numel())
    return t_ad, t_nume


if __name__ == '__main__':
    setting = load_setting()

    print('JaxMeent')
    j_t_ad, j_t_nume = optimize_jax(setting)
    print('TorchMeent')
    t_t_ad, t_t_nume = optimize_torch(setting)

    print(f'Time for Backprop, JAX, AD: {j_t_ad} s, Numerical: {j_t_nume} s')
    print(f'Time for Backprop, Torch, AD: {t_t_ad} s, Numerical: {t_t_nume} s')
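
Both grad_numerical helpers above validate the AD gradients with a second-order central difference, (f(x + delta) - f(x - delta)) / (2 * delta) per element. A stripped-down, framework-free version of that check (function and argument names here are illustrative, not from the commit):

import numpy as np

def central_diff_grad(f, x, delta=1e-6):
    # Second-order central difference of a scalar-valued f, element by element.
    x = np.asarray(x, dtype=float)
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        x_p = x.copy(); x_p[idx] += delta
        x_m = x.copy(); x_m[idx] -= delta
        grad[idx] = (f(x_p) - f(x_m)) / (2 * delta)
    return grad

# Sanity check: f(x) = sum(x**2) has gradient 2x.
print(central_diff_grad(lambda v: np.sum(v ** 2), np.array([1.0, 2.0, 3.0])))  # ~[2. 4. 6.]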
