Skip to content

Commit

Permalink
Merge pull request #464 from sh-tada/trans_tdiff_tada
Browse files Browse the repository at this point in the history
Fix opachord.py to avoid nan in the gradient of ArtTransPure model
  • Loading branch information
HajimeKawahara committed Jan 15, 2024
2 parents feaa201 + 410180b commit 974a489
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 8 deletions.
14 changes: 6 additions & 8 deletions src/exojax/spec/opachord.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ def chord_geometric_matrix_lower(height, radius_lower):
"""
radius_upper = radius_lower + height
fac_left = jnp.sqrt(radius_upper[None, :]**2 - radius_lower[:, None]**2)
fac_right = jnp.sqrt(radius_lower[None, :]**2 - radius_lower[:, None]**2)
fac_left = jnp.sqrt(jnp.tril(radius_upper[None, :]**2 - radius_lower[:, None]**2))
fac_right = jnp.sqrt(jnp.tril(radius_lower[None, :]**2 - radius_lower[:, None]**2, k=-1))
raw_matrix = 2.0 * (fac_left - fac_right) / height
return jnp.tril(raw_matrix)
return raw_matrix


@jit
Expand All @@ -50,12 +50,10 @@ def chord_geometric_matrix(height, radius_lower):
radius_upper = radius_lower + height
radius_midpoint = radius_lower + height / 2.0

fac_left = radius_upper[None, :]**2 - radius_midpoint[:, None]**2
fac_right = radius_lower[None, :]**2 - radius_midpoint[:, None]**2
deep_element_correction = radius_lower**2 - radius_midpoint**2
fac_right = fac_right - jnp.diag(deep_element_correction)
fac_left = jnp.tril(radius_upper[None, :]**2 - radius_midpoint[:, None]**2)
fac_right = jnp.tril(radius_lower[None, :]**2 - radius_midpoint[:, None]**2, k=-1)
raw_matrix = 2.0 * (jnp.sqrt(fac_left) - jnp.sqrt(fac_right)) / height
return jnp.tril(raw_matrix)
return raw_matrix

@jit
def chord_optical_depth(chord_geometric_matrix, dtau):
Expand Down
79 changes: 79 additions & 0 deletions tests/integration/unittests_long/transmission/transmission_grad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import jax
from jax.config import config
import pandas as pd
import numpy as np
import jax.numpy as jnp
from exojax.utils.grids import wavenumber_grid
from exojax.spec.opacalc import OpaPremodit
from exojax.spec.atmrt import ArtTransPure
from exojax.utils.constants import RJ, Rs
from exojax.spec.api import MdbHitran
from exojax.utils.astrofunc import gravity_jupiter

from exojax.spec.unitconvert import wav2nu
from exojax.spec.specop import SopRotation
from exojax.spec.specop import SopInstProfile
from exojax.utils.instfunc import resolution_to_gaussian_std

config.update("jax_enable_x64", True)
# config.update("jax_debug_nans", True)


def read_data(filename):
dat = pd.read_csv(filename, delimiter=" ")
wav = dat["Wavelength[um]"]
mask = (wav > 2.25) & (wav < 2.6)
return wav[mask], dat["Rp/Rs"][mask]


# Read data
filename = "/home/kawahara/exojax/tests/integration/comparison/transmission/spectrum/CO100percent_500K.dat"
wav, rprs = read_data(filename)
inst_nus = wav2nu(np.array(wav), "um")

# Model
Nx = 3000
nu_grid, wav, res = wavenumber_grid(22900.0, 26000.0, Nx, unit="AA", xsmode="modit")

art = ArtTransPure(pressure_top=1.0e-15, pressure_btm=1.0e1, nlayer=100)
art.change_temperature_range(490.0, 510.0)

mdb = MdbHitran("CO", nu_grid, gpu_transfer=True, isotope=1)
opa = OpaPremodit(
mdb=mdb,
nu_grid=nu_grid,
auto_trange=[490, 510],
dit_grid_resolution=1,
)

sop_inst = SopInstProfile(nu_grid, res, vrmax=100.0)


def model(params):
mmr_CO, mu_fid, T_fid, gravity_btm, radius_btm, RV = params

Tarr = T_fid * np.ones_like(art.pressure)
mmr_arr = art.constant_mmr_profile(mmr_CO)

mmw = mu_fid * np.ones_like(art.pressure)
gravity = art.gravity_profile(Tarr, mmw, radius_btm, gravity_btm)

xsmatrix = opa.xsmatrix(Tarr, art.pressure)
dtau = art.opacity_profile_xs(xsmatrix, mmr_arr, opa.mdb.molmass, gravity)

Rp2 = art.run(dtau, Tarr, mmw, radius_btm, gravity_btm)

Rp2_sample = sop_inst.sampling(Rp2, RV, inst_nus)
return jnp.sqrt(Rp2_sample)


def objective(params):
return jnp.sum((np.array(rprs[::-1]) - model(params)) ** 2)


# Gradient
grad = jax.grad(objective)
params = np.array([1, 28.00863, 500, gravity_jupiter(1.0, 1.0), RJ, 0])
print()
print("Parameters: mmr_CO, mu_fid, T_fid, gravity_btm, radius_btm, RV")
print("Gradient", grad(params))

0 comments on commit 974a489

Please sign in to comment.