Skip to content

Commit

Permalink
add least squares fit parameters to the radial velocity and continuum…
Browse files Browse the repository at this point in the history
… fits
  • Loading branch information
AWehrhahn committed Oct 12, 2022
1 parent d3e173f commit 7300efd
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 31 deletions.
145 changes: 121 additions & 24 deletions src/pysme/continuum_and_radial_velocity.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ def __init__(self):
self.cscale_type = "none"
pass

def __call__(self, sme, x_syn, y_syn, segments, rvel=0):
def __call__(self, sme, x_syn, y_syn, segments, rvel=0, least_squares_kwargs=None):
if least_squares_kwargs is None:
least_squares_kwargs = {}
raise NotImplementedError

def apply(self, wave, smod, cwave, cscale, segments):
Expand All @@ -45,7 +47,15 @@ def __init__(self):
super().__init__()
self.cscale_tyoe = "mask"

def __call__(self, sme, x_syn, y_syn, segments, rvel=0):
def __call__(
self,
sme,
x_syn,
y_syn,
segments,
rvel=0,
least_squares_kwargs=None,
):
"""
Fit a polynomial to the spectrum points marked as continuum
The degree of the polynomial fit is determined by sme.cscale_flag
Expand All @@ -66,6 +76,9 @@ def __call__(self, sme, x_syn, y_syn, segments, rvel=0):
if segments < 0:
return sme.cscale

if least_squares_kwargs is None:
least_squares_kwargs = {}

if "spec" not in sme or "wave" not in sme:
# If there is no observation, we have no continuum scale
warnings.warn("Missing data for continuum fit")
Expand Down Expand Up @@ -117,7 +130,7 @@ def __call__(self, sme, x_syn, y_syn, segments, rvel=0):
try:
func = lambda coef: (np.polyval(coef, x) - y) / u
c0 = np.polyfit(x, y, deg=ndeg)
res = least_squares(func, x0=c0)
res = least_squares(func, x0=c0, **least_squares_kwargs)
cscale = res.x
except TypeError:
warnings.warn("Could not fit continuum, set continuum mask?")
Expand Down Expand Up @@ -188,7 +201,15 @@ def __init__(self):
super().__init__()
self.cscale_tyoe = "mcmc"

def __call__(self, sme, x_syn, y_syn, segments, rvel=0):
def __call__(
self,
sme,
x_syn,
y_syn,
segments,
rvel=0,
least_squares_kwargs=None,
):
"""
Fits both radial velocity and continuum level simultaneously
by comparing the synthetic spectrum to the observation
Expand Down Expand Up @@ -223,6 +244,9 @@ def __call__(self, sme, x_syn, y_syn, segments, rvel=0):
segments = [segments]
nseg = len(segments)

if least_squares_kwargs is None:
least_squares_kwargs = {}

if sme.cscale_flag in ["none", "fix"] and sme.vrad_flag in [
"none",
"fix",
Expand Down Expand Up @@ -514,7 +538,15 @@ def __init__(self):
self.top_factor = 500_000
self.bottom_factor = 1

def __call__(self, sme, x_syn, y_syn, segments, rvel=0):
def __call__(
self,
sme,
x_syn,
y_syn,
segments,
rvel=0,
least_squares_kwargs=None,
):
"""
Fit a continuum when no continuum points exist
Expand All @@ -541,6 +573,10 @@ def __call__(self, sme, x_syn, y_syn, segments, rvel=0):
return [1]
elif sme.cscale_flag == "fix":
return sme.cscale[segments]

if least_squares_kwargs is None:
least_squares_kwargs = {}

# else
x = sme.wave[segments]
y = sme.spec[segments]
Expand Down Expand Up @@ -580,7 +616,11 @@ def func(p):
return resid

try:
res = least_squares(func, x0=p0, method="lm", x_scale="jac")
res = least_squares(
func,
x0=p0,
**least_squares_kwargs,
)
popt = res.x
except (RuntimeError, ValueError) as ex:
logger.warning("Failed to determine the continuum: " + ex.msg)
Expand Down Expand Up @@ -618,7 +658,15 @@ def __init__(self):
self.cscale_type = "spline"
self.mask = False

def __call__(self, sme, x_syn, y_syn, segments, rvel):
def __call__(
self,
sme,
x_syn,
y_syn,
segments,
rvel,
least_squares_kwargs=None,
):
if sme.cscale_flag in ["none"]:
return np.ones(len(sme.spec[segments]))
elif sme.cscale_flag in ["fix"]:
Expand All @@ -627,6 +675,10 @@ def __call__(self, sme, x_syn, y_syn, segments, rvel):
else:
return np.ones(len(sme.spec[segments]))

if least_squares_kwargs is None:
least_squares_kwargs = {}
least_squares_kwargs.setdefault("f_scale", 0.01)

w = sme.wave[segments]
s = sme.spec[segments]
u = sme.uncs[segments]
Expand Down Expand Up @@ -672,7 +724,11 @@ def __call__(self, sme, x_syn, y_syn, segments, rvel):
t = np.linspace(wm[m2].min(), wm[m2].max(), tlen)[1:-1]
t, c, k = splrep(wm[m2], sm[m2] / ym[m2], w=1 / um[m2], k=3, t=t)
# Then get a real fit using the function
res = least_squares(func, x0=c, loss="soft_l1", method="trf", f_scale=0.01)
res = least_squares(
func,
x0=c,
**least_squares_kwargs,
)
# And finally evaluate the continuum
c = res.x
coef = splev(w, (t, c, k))
Expand Down Expand Up @@ -806,8 +862,8 @@ def determine_radial_velocity(
y_syn,
segment,
cscale=None,
rv_bounds=(-100, 100),
whole=False,
least_squares_kwargs=None,
):
"""
Calculate radial velocity by using cross correlation and
Expand Down Expand Up @@ -838,6 +894,17 @@ def determine_radial_velocity(
or None if no observation is present
"""

if least_squares_kwargs is None:
least_squares_kwargs = {}
least_squares_kwargs.setdefault("bounds", (-100, 100))
least_squares_kwargs.setdefault("jac", "3-point")
least_squares_kwargs.setdefault("loss", "soft_l1")
least_squares_kwargs.setdefault("method", "dogbox")
least_squares_kwargs.setdefault("x_scale", "jac")
least_squares_kwargs.setdefault("ftol", 1e-8)
least_squares_kwargs.setdefault("xtol", 1e-8)
least_squares_kwargs.setdefault("gtol", 1e-8)

if "spec" not in sme or "wave" not in sme:
# No observation no radial velocity
warnings.warn("Missing data for radial velocity determination")
Expand Down Expand Up @@ -912,7 +979,8 @@ def determine_radial_velocity(
mask &= u_obs != 0

# Widen the mask by roughly the amount expected from the rv_bounds
rv = max(rv_bounds)
bounds = least_squares_kwargs["bounds"]
rv = max(bounds)
rv_factor = np.sqrt((1 + rv / c_light) / (1 - rv / c_light))
# mask_wider = mask.copy()
if sme.vrad_flag == "each":
Expand Down Expand Up @@ -942,7 +1010,7 @@ def determine_radial_velocity(
# Get a first rough estimate from cross correlation
if sme.vrad_flag == "each":
x_shift, corr = cross_correlate_segment(
x_obs, y_obs, x_syn, y_syn, mask, mask_wider, rv_bounds
x_obs, y_obs, x_syn, y_syn, mask, mask_wider, bounds
)
else:
# If using several segments we run the cross correlation for each
Expand All @@ -958,11 +1026,11 @@ def determine_radial_velocity(
y_syn[i],
mask[i],
mask_wider[i],
rv_bounds,
bounds,
)

n_min = min(len(s) for s in shift)
x_shift = np.linspace(rv_bounds[0], rv_bounds[1], n_min)
x_shift = np.linspace(bounds[0], bounds[1], n_min)
corrs_interp = [np.interp(x_shift, s, c) for s, c in zip(shift, corr)]
corrs_interp = np.array(corrs_interp)
corr = np.sum(corrs_interp, axis=0)
Expand Down Expand Up @@ -991,25 +1059,24 @@ def determine_radial_velocity(
# Apply mask
y_obs[~mask] = 0
y_syn[~mask_wider] = 0
u_obs = u_obs.copy()
u_obs[~mask] = 1

# Then minimize the least squares for a better fit
# as cross correlation can only find
def func(rv):
rv_factor = np.sqrt((1 - rv / c_light) / (1 + rv / c_light))
shifted = interpolator(x_obs * rv_factor)
resid = (y_obs - shifted * tell) / u_obs
resid = np.nan_to_num(resid, copy=False)
resid = np.nan_to_num(resid, copy=False, nan=0, posinf=0, neginf=0)
return resid

interpolator = lambda x: np.interp(x, x_syn, y_syn)
try:
res = least_squares(
func, x0=rvel, loss="soft_l1", bounds=rv_bounds, jac="3-point"
)
res = least_squares(func, x0=rvel, **least_squares_kwargs)
rvel = res.x[0]
except ValueError:
logger.warning(f"Could not determine radial velocity for segment {segment}")
rvel = 0
return rvel


Expand Down Expand Up @@ -1068,6 +1135,17 @@ def match_rv_continuum(sme, segments, x_syn, y_syn):
else:
continuum_normalization = sme.cscale_type

least_squares_kwargs = dict(
bounds=sme.vrad_bounds,
loss=sme.vrad_loss,
method=sme.vrad_method,
jac=sme.vrad_jac,
x_scale=sme.vrad_xscale,
ftol=sme.vrad_ftol,
xtol=sme.vrad_xtol,
gtol=sme.vrad_gtol,
)

if sme.vrad_flag == "none":
pass
elif sme.vrad_flag == "fix":
Expand All @@ -1076,26 +1154,25 @@ def match_rv_continuum(sme, segments, x_syn, y_syn):
for s in segments:
# We only use the continuum mask for the continuum fit,
# we need the lines for the radial velocity
rv_bounds = (-sme.vrad_limit, sme.vrad_limit)
vrad[s] = radial_velocity(
sme,
x_syn[s],
y_syn[s],
s,
cscale[s],
rv_bounds=rv_bounds,
whole=False,
least_squares_kwargs=least_squares_kwargs,
)
elif sme.vrad_flag == "whole":
s = segments
rv_bounds = (-sme.vrad_limit, sme.vrad_limit)
vrad[s] = radial_velocity(
sme,
[x_syn[s] for s in s],
[y_syn[s] for s in s],
s,
cscale[s],
whole=True,
rv_bounds=rv_bounds,
least_squares_kwargs=least_squares_kwargs,
)
else:
raise ValueError
Expand All @@ -1115,11 +1192,31 @@ def match_rv_continuum(sme, segments, x_syn, y_syn):
vrad_unc[s],
cscale[s],
cscale_unc[s],
) = continuum_normalization(sme, x_syn[s], y_syn[s], s, rvel=vrad[s])
) = continuum_normalization(
sme,
x_syn[s],
y_syn[s],
s,
rvel=vrad[s],
)
else:
for s in segments:
cscale[s] = continuum_normalization(
sme, x_syn[s], y_syn[s], s, rvel=vrad[s]
sme,
x_syn[s],
y_syn[s],
s,
rvel=vrad[s],
least_squares_kwargs=dict(
bounds=sme.cscale_bounds,
loss=sme.cscale_loss,
method=sme.cscale_method,
jac=sme.cscale_jac,
x_scale=sme.cscale_xscale,
ftol=sme.cscale_ftol,
xtol=sme.cscale_xtol,
gtol=sme.cscale_gtol,
),
)

# Keep values from unused segments
Expand Down
Loading

0 comments on commit 7300efd

Please sign in to comment.