Skip to content

Commit e7a2980

Browse files
authored
Merge pull request #74 from kc-ml2/DEV/main
Dev/main
2 parents 33a7fe3 + cc28149 commit e7a2980

File tree

7 files changed

+59
-30
lines changed

7 files changed

+59
-30
lines changed

meent/on_jax/emsolver/_base.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -278,13 +278,15 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all):
278278
elif self.pol == 1:
279279
E_conv_i = jnp.linalg.inv(E_conv)
280280
B = Kx @ E_conv_i @ Kx - jnp.eye(E_conv.shape[0]).astype(self.type_complex)
281-
o_E_conv_i = jnp.linalg.inv(o_E_conv)
282-
eigenvalues, W = eig(o_E_conv_i @ B, type_complex=self.type_complex, perturbation=self.perturbation,
281+
# o_E_conv_i = jnp.linalg.inv(o_E_conv)
282+
283+
eigenvalues, W = eig(E_conv @ B, type_complex=self.type_complex, perturbation=self.perturbation,
283284
device=self.device)
284285
eigenvalues += 0j # to get positive square root
285286
q = eigenvalues ** 0.5
286287
Q = jnp.diag(q)
287-
V = o_E_conv @ W @ Q
288+
# V = o_E_conv @ W @ Q
289+
V = E_conv_i @ W @ Q
288290

289291
else:
290292
raise ValueError
@@ -345,11 +347,14 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all):
345347
for layer_index in range(count)[::-1]:
346348

347349
E_conv = E_conv_all[layer_index]
348-
o_E_conv = o_E_conv_all[layer_index]
350+
# o_E_conv = o_E_conv_all[layer_index]
351+
o_E_conv = None
352+
349353
d = self.thickness[layer_index]
350354

351355
E_conv_i = jnp.linalg.inv(E_conv)
352-
o_E_conv_i = jnp.linalg.inv(o_E_conv)
356+
# o_E_conv_i = jnp.linalg.inv(o_E_conv)
357+
o_E_conv_i = None
353358

354359
if self.algo == 'TMM':
355360
big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2 \
@@ -418,11 +423,14 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all):
418423
# From the last layer
419424
for layer_index in range(count)[::-1]:
420425
E_conv = E_conv_all[layer_index]
421-
o_E_conv = o_E_conv_all[layer_index]
426+
# o_E_conv = o_E_conv_all[layer_index]
427+
o_E_conv = None
428+
422429
d = self.thickness[layer_index]
423430

424431
E_conv_i = jnp.linalg.inv(E_conv)
425-
o_E_conv_i = jnp.linalg.inv(o_E_conv)
432+
# o_E_conv_i = jnp.linalg.inv(o_E_conv)
433+
o_E_conv_i = None
426434

427435
if self.algo == 'TMM':
428436
W, V, q = transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv,

meent/on_jax/emsolver/transfer_method.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d, varph
135135
B_i = jnp.linalg.inv(B)
136136

137137
to_decompose_W_1 = (ky/k0) ** 2 * I + A
138-
to_decompose_W_2 = (ky/k0) ** 2 * I + B @ o_E_conv_i
138+
# to_decompose_W_2 = (ky/k0) ** 2 * I + B @ o_E_conv_i
139+
to_decompose_W_2 = (ky/k0) ** 2 * I + B @ E_conv
139140

140141
eigenvalues_1, W_1 = eig(to_decompose_W_1, type_complex=type_complex, perturbation=perturbation, device=device)
141142
eigenvalues_2, W_2 = eig(to_decompose_W_2, type_complex=type_complex, perturbation=perturbation, device=device)

meent/on_numpy/emsolver/_base.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,8 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all):
226226
# From the last layer
227227
for layer_index in range(count)[::-1]:
228228
E_conv = E_conv_all[layer_index]
229-
o_E_conv = o_E_conv_all[layer_index]
229+
# o_E_conv = o_E_conv_all[layer_index]
230+
230231
d = self.thickness[layer_index]
231232

232233
if self.pol == 0:
@@ -241,13 +242,15 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all):
241242
elif self.pol == 1:
242243
E_conv_i = np.linalg.inv(E_conv)
243244
B = Kx @ E_conv_i @ Kx - np.eye(E_conv.shape[0], dtype=self.type_complex)
244-
o_E_conv_i = np.linalg.inv(o_E_conv)
245+
# o_E_conv_i = np.linalg.inv(o_E_conv)
245246

246-
eigenvalues, W = np.linalg.eig(o_E_conv_i @ B)
247+
eigenvalues, W = np.linalg.eig(E_conv @ B)
247248
eigenvalues += 0j # to get positive square root
248249
q = eigenvalues ** 0.5
249250
Q = np.diag(q)
250-
V = o_E_conv @ W @ Q
251+
# V = o_E_conv @ W @ Q
252+
V = E_conv_i @ W @ Q
253+
251254
else:
252255
raise ValueError
253256

@@ -305,11 +308,14 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all):
305308
for layer_index in range(count)[::-1]:
306309

307310
E_conv = E_conv_all[layer_index]
308-
o_E_conv = o_E_conv_all[layer_index]
311+
# o_E_conv = o_E_conv_all[layer_index]
312+
o_E_conv = None
313+
309314
d = self.thickness[layer_index]
310315

311316
E_conv_i = np.linalg.inv(E_conv)
312-
o_E_conv_i = np.linalg.inv(o_E_conv)
317+
# o_E_conv_i = np.linalg.inv(o_E_conv)
318+
o_E_conv_i = None
313319

314320
if self.algo == 'TMM':
315321
big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2 \
@@ -375,11 +381,14 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all):
375381
# From the last layer
376382
for layer_index in range(count)[::-1]:
377383
E_conv = E_conv_all[layer_index]
378-
o_E_conv = o_E_conv_all[layer_index]
384+
# o_E_conv = o_E_conv_all[layer_index]
385+
o_E_conv = None
386+
379387
d = self.thickness[layer_index]
380388

381389
E_conv_i = np.linalg.inv(E_conv)
382-
o_E_conv_i = np.linalg.inv(o_E_conv)
390+
# o_E_conv_i = np.linalg.inv(o_E_conv)
391+
o_E_conv_i = None
383392

384393
if self.algo == 'TMM':
385394
W, V, q = transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=self.type_complex)
@@ -393,7 +402,7 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all):
393402
self.layer_info_list.append(layer_info)
394403

395404
elif self.algo == 'SMM':
396-
W, V, q = scattering_2d_wv(Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i)
405+
W, V, q = scattering_2d_wv(ff_xy, Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i)
397406
A, B, Sl_dict, Sg_matrix, Sg = scattering_2d_2(W, Wg, V, Vg, d, k0, Sg, q)
398407
else:
399408
raise ValueError
@@ -405,7 +414,7 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all):
405414
self.T1 = big_T1
406415

407416
elif self.algo == 'SMM':
408-
de_ri, de_ti = scattering_2d_3(Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, self.n_I,
417+
de_ri, de_ti = scattering_2d_3(ff_xy, Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, self.n_I,
409418
self.pol, self.theta, self.phi, self.fourier_order)
410419
else:
411420
raise ValueError

meent/on_numpy/emsolver/transfer_method.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d, varph
121121
B_i = np.linalg.inv(B)
122122

123123
to_decompose_W_1 = (ky/k0) ** 2 * I + A
124-
to_decompose_W_2 = (ky/k0) ** 2 * I + B @ o_E_conv_i
124+
# to_decompose_W_2 = (ky/k0) ** 2 * I + B @ o_E_conv_i
125+
to_decompose_W_2 = (ky/k0) ** 2 * I + B @ E_conv
125126

126127
eigenvalues_1, W_1 = np.linalg.eig(to_decompose_W_1)
127128
eigenvalues_2, W_2 = np.linalg.eig(to_decompose_W_2)

meent/on_torch/emsolver/_base.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,8 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all):
274274
for layer_index in range(count)[::-1]:
275275

276276
E_conv = E_conv_all[layer_index]
277-
o_E_conv = o_E_conv_all[layer_index]
277+
# o_E_conv = o_E_conv_all[layer_index]
278+
278279
d = self.thickness[layer_index]
279280

280281
if self.pol == 0:
@@ -289,13 +290,14 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all):
289290
elif self.pol == 1:
290291
E_conv_i = torch.linalg.inv(E_conv)
291292
B = Kx @ E_conv_i @ Kx - torch.eye(E_conv.shape[0], device=self.device, dtype=self.type_complex)
292-
o_E_conv_i = torch.linalg.inv(o_E_conv)
293+
# o_E_conv_i = torch.linalg.inv(o_E_conv)
293294

294295
Eig.perturbation = self.perturbation
295-
eigenvalues, W = Eig.apply(o_E_conv_i @ B)
296+
eigenvalues, W = Eig.apply(E_conv @ B)
296297
q = eigenvalues ** 0.5
297298
Q = torch.diag(q)
298-
V = o_E_conv @ W @ Q
299+
# V = o_E_conv @ W @ Q
300+
V = E_conv_i @ W @ Q
299301

300302
else:
301303
raise ValueError
@@ -355,11 +357,14 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all):
355357
for layer_index in range(count)[::-1]:
356358

357359
E_conv = E_conv_all[layer_index]
358-
o_E_conv = o_E_conv_all[layer_index]
360+
# o_E_conv = o_E_conv_all[layer_index]
361+
o_E_conv = None
362+
359363
d = self.thickness[layer_index]
360364

361365
E_conv_i = torch.linalg.inv(E_conv)
362-
o_E_conv_i = torch.linalg.inv(o_E_conv)
366+
# o_E_conv_i = torch.linalg.inv(o_E_conv)
367+
o_E_conv_i = None
363368

364369
if self.algo == 'TMM':
365370
big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2\
@@ -429,10 +434,14 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all):
429434
for layer_index in range(count)[::-1]:
430435

431436
E_conv = E_conv_all[layer_index]
432-
o_E_conv = o_E_conv_all[layer_index]
437+
# o_E_conv = o_E_conv_all[layer_index]
438+
o_E_conv = None
439+
433440
d = self.thickness[layer_index]
441+
434442
E_conv_i = torch.linalg.inv(E_conv)
435-
o_E_conv_i = torch.linalg.inv(o_E_conv)
443+
# o_E_conv_i = torch.linalg.inv(o_E_conv)
444+
o_E_conv_i = None
436445

437446
if self.algo == 'TMM':
438447
W, V, q = transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv,
@@ -447,7 +456,7 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all):
447456
self.layer_info_list.append(layer_info)
448457

449458
elif self.algo == 'SMM':
450-
W, V, LAMBDA = scattering_2d_wv(Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i)
459+
W, V, LAMBDA = scattering_2d_wv(ff_xy, Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i)
451460
A, B, Sl_dict, Sg_matrix, Sg = scattering_2d_2(W, Wg, V, Vg, d, k0, Sg, LAMBDA)
452461
else:
453462
raise ValueError

meent/on_torch/emsolver/transfer_method.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_i, o_E_conv_i, ff, d, varphi, bi
134134
B_i = torch.linalg.inv(B)
135135

136136
to_decompose_W_1 = (ky/k0) ** 2 * I + A
137-
to_decompose_W_2 = (ky/k0) ** 2 * I + B @ o_E_conv_i
137+
# to_decompose_W_2 = (ky/k0) ** 2 * I + B @ o_E_conv_i
138+
to_decompose_W_2 = (ky/k0) ** 2 * I + B @ E_conv
138139

139140
Eig.perturbation = perturbation
140141
eigenvalues_1, W_1 = Eig.apply(to_decompose_W_1)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
}
1313
setup(
1414
name='meent',
15-
version='0.9.12',
15+
version='0.9.13',
1616
url='https://github.com/kc-ml2/meent',
1717
author='KC ML2',
1818
author_email='[email protected]',

0 commit comments

Comments
 (0)