|
| 1 | +import jax |
| 2 | +import torch |
| 3 | + |
| 4 | +import jax.numpy as jnp |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +from time import time |
| 8 | + |
| 9 | +from meent import call_mee |
| 10 | + |
| 11 | + |
def load_setting():
    """Build the simulation settings shared by the JAX and Torch runs.

    Returns:
        dict: keyword arguments for ``call_mee`` — polarization, ambient
        refractive indices, incidence angles, Fourier truncation orders,
        geometry, and the (real-valued) permittivity cell.
    """
    deg = np.pi / 180  # degrees -> radians

    # Permittivity distribution of the single 1120-thick layer on a 5x5 grid.
    # The values were generated as complex numbers; only the real part is kept.
    cell = np.array([[
        [2.58941352 + 0.47745679j, 4.17771602 + 0.88991205j, 2.04255624 + 2.23670125j,
         2.50478974 + 2.05242759j, 3.32747593 + 2.3854387j],
        [2.80118605 + 0.53053715j, 4.46498861 + 0.10812571j, 3.99377545 + 1.0441131j,
         3.10728537 + 0.6637353j, 4.74697849 + 0.62841253j],
        [3.80944424 + 2.25899274j, 3.70371553 + 1.32586402j, 3.8011133 + 1.49939415j,
         3.14797238 + 2.91158289j, 4.3085404 + 2.44344691j],
        [2.22510179 + 2.86017146j, 2.36613053 + 2.82270351j, 4.5087168 + 0.2035904j,
         3.15559949 + 2.55311298j, 4.29394604 + 0.98362617j],
        [3.31324163 + 2.77590131j, 2.11744834 + 1.65894674j, 3.59347907 + 1.28895345j,
         3.85713467 + 1.90714056j, 2.93805426 + 2.63385392j],
    ]])

    return {
        'pol': 1,              # 0: TE, 1: TM
        'n_top': 1,            # n_incidence
        'n_bot': 1,            # n_transmission
        'theta': 0 * deg,
        'phi': 0 * deg,
        'fto': [5, 5],
        'wavelength': 900,
        'period': [1000, 1000],
        'ucell': cell.real,
        'thickness': [1120],
        'device': 0,
        'type_complex': 0,
    }
| 53 | + |
| 54 | + |
def optimize_jax(setting):
    """Compare AD and central-difference gradients with the JAX backend.

    Builds a meent solver from ``setting`` (backend 1 == JAX), differentiates
    the central transmission diffraction efficiency with respect to the
    permittivity cell both by automatic differentiation and by numerical
    central differences, prints both gradients and their difference norm.

    Args:
        setting: dict of ``call_mee`` keyword arguments; must contain 'ucell'.

    Returns:
        tuple: (t_ad, t_nume) — wall-clock seconds for the AD gradient and
        for the numerical gradient.
    """
    ucell = setting['ucell']

    mee = call_mee(backend=1, **setting)

    @jax.jit
    def loss_fn(ucell):
        # Loss: central (middle-order) transmission diffraction efficiency.
        mee.ucell = ucell
        de_ti = mee.conv_solve().res.de_ti
        return de_ti[de_ti.shape[0] // 2, de_ti.shape[1] // 2]

    def grad_numerical(ucell, delta):
        # Central finite difference of loss_fn w.r.t. every cell entry.
        # Reuses the single jitted loss_fn; it assigns mee.ucell itself,
        # so no extra mee.ucell bookkeeping is needed here.
        grad_arr = jnp.zeros(ucell.shape, dtype=ucell.dtype)

        for layer in range(ucell.shape[0]):
            for r in range(ucell.shape[1]):
                for c in range(ucell.shape[2]):
                    ucell_delta_m = ucell.copy()
                    ucell_delta_m[layer, r, c] -= delta
                    loss_m = loss_fn(ucell_delta_m)

                    ucell_delta_p = ucell.copy()
                    ucell_delta_p[layer, r, c] += delta
                    loss_p = loss_fn(ucell_delta_p)

                    grad_numeric = (loss_p - loss_m) / (2 * delta)
                    grad_arr = grad_arr.at[layer, r, c].set(grad_numeric)

        return grad_arr

    jax.grad(loss_fn)(ucell)  # Dry run for jit compilation. This is to make time comparison fair.
    t0 = time()
    grad_ad = jax.grad(loss_fn)(ucell)
    t_ad = time() - t0
    print('JAX grad_ad:\n', grad_ad)
    t0 = time()
    grad_nume = grad_numerical(ucell, 1E-6)
    t_nume = time() - t0
    print('JAX grad_numeric:\n', grad_nume)
    print('JAX norm of difference: ', jnp.linalg.norm(grad_nume - grad_ad) / grad_nume.size)
    return t_ad, t_nume
| 111 | + |
| 112 | + |
def optimize_torch(setting):
    """Compare AD and central-difference gradients with the Torch backend.

    Builds a meent solver from ``setting`` (backend 2 == PyTorch),
    differentiates the central transmission diffraction efficiency with
    respect to the permittivity cell by backpropagation and by numerical
    central differences, prints both gradients and their difference norm.

    Args:
        setting: dict of ``call_mee`` keyword arguments.

    Returns:
        tuple: (t_ad, t_nume) — wall-clock seconds for the AD gradient and
        for the numerical gradient.
    """
    mee = call_mee(backend=2, **setting)

    mee.ucell.requires_grad = True

    t0 = time()
    de_ti = mee.conv_solve().res.de_ti

    # Loss: central (middle-order) transmission diffraction efficiency.
    loss = de_ti[de_ti.shape[0] // 2, de_ti.shape[1] // 2]

    loss.backward()
    grad_ad = mee.ucell.grad
    t_ad = time() - t0

    def grad_numerical(ucell, delta):
        # Central finite difference of the same loss w.r.t. every cell entry.
        ucell.requires_grad = False  # perturbed solves need no autograd graph
        grad_arr = torch.zeros(ucell.shape, dtype=ucell.dtype)

        for layer in range(ucell.shape[0]):
            for r in range(ucell.shape[1]):
                for c in range(ucell.shape[2]):
                    ucell_delta_m = ucell.clone().detach()
                    ucell_delta_m[layer, r, c] -= delta
                    mee.ucell = ucell_delta_m
                    de_ti_delta_m = mee.conv_solve().res.de_ti

                    ucell_delta_p = ucell.clone().detach()
                    ucell_delta_p[layer, r, c] += delta
                    mee.ucell = ucell_delta_p
                    de_ti_delta_p = mee.conv_solve().res.de_ti

                    cy, cx = np.array(de_ti_delta_p.shape) // 2
                    grad_numeric = (de_ti_delta_p[cy, cx] - de_ti_delta_m[cy, cx]) / (2 * delta)
                    grad_arr[layer, r, c] = grad_numeric

        return grad_arr

    t0 = time()
    grad_nume = grad_numerical(mee.ucell, 1E-6)
    t_nume = time() - t0

    print('Torch grad_ad:\n', grad_ad)
    print('Torch grad_numeric:\n', grad_nume)
    print('torch.norm: ', torch.linalg.norm(grad_nume - grad_ad) / grad_nume.numel())
    return t_ad, t_nume
| 161 | + |
| 162 | + |
if __name__ == '__main__':
    # Run the same gradient comparison on both backends and report timings.
    config = load_setting()

    print('JaxMeent')
    jax_ad_time, jax_num_time = optimize_jax(config)
    print('TorchMeent')
    torch_ad_time, torch_num_time = optimize_torch(config)

    print(f'Time for Backprop, JAX, AD: {jax_ad_time} s, Numerical: {jax_num_time} s')
    print(f'Time for Backprop, Torch, AD: {torch_ad_time} s, Numerical: {torch_num_time} s')
0 commit comments