Skip to content

Commit

Permalink
add "spline" option to continuum
Browse files Browse the repository at this point in the history
  • Loading branch information
AWehrhahn committed Jul 7, 2021
1 parent 96c2bd5 commit 185583c
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 17 deletions.
38 changes: 27 additions & 11 deletions src/pysme/continuum_and_radial_velocity.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from scipy.optimize import least_squares, minimize_scalar, curve_fit
from scipy.signal import correlate, find_peaks
from scipy.interpolate import UnivariateSpline
from scipy.interpolate import splrep, splev

from tqdm import tqdm

from . import util
Expand Down Expand Up @@ -42,8 +44,8 @@ def apply_continuum(wave, smod, cscale, cscale_type, segments):
return smod
for il in segments:
if cscale[il] is not None and not np.all(cscale[il] == 0):
if cscale_type in ["smooth"]:
smod[il] += cscale[il]
if cscale_type in ["spline"]:
smod[il] *= cscale[il]
else:
x = wave[il] - wave[il][0]
smod[il] *= np.polyval(cscale[il], x)
Expand Down Expand Up @@ -368,8 +370,8 @@ def func(rv):

def null_result(nseg, ndeg=0, ctype=None):
vrad, vrad_unc = np.zeros(nseg), np.zeros((nseg, 2))
if ctype in ["smooth"]:
cscale = [np.zeros(ndeg[i]) for i in range(nseg)]
if ctype in ["spline"]:
cscale = [np.ones(ndeg[i]) for i in range(nseg)]
cscale = Iliffe_vector(values=cscale)
cscale_unc = [np.zeros(ndeg[i]) for i in range(nseg)]
cscale_unc = Iliffe_vector(values=cscale_unc)
Expand Down Expand Up @@ -756,25 +758,39 @@ def get_continuum_broadening(sme, segment, x_syn, y_syn, rvel=0, only_mask=False

w = sme.wave[segment]
s = sme.spec[segment]
u = sme.uncs[segment]

# Apply RV correction to the synthetic spectrum
rv_factor = np.sqrt((1 - rvel / c_light) / (1 + rvel / c_light))
wp = w * rv_factor
y = np.interp(wp, x_syn, y_syn)

# and don't forget the telluric spectrum if available
if sme.telluric is not None:
tell = sme.telluric[segment]
y *= tell

# Apply the bpm to all arrays
if only_mask:
m = sme.mask_cont[segment]
else:
m = sme.mask_good[segment]

wm, sm, ym = w[m], s[m], y[m]
wm, sm, ym, um = w[m], s[m], y[m], u[m]

# Fit the spline like in the polynomial
# so that synth * spline = obs
func = lambda p: ym * splev(wm, (p[:l1], p[l1:], 3)) - sm
# We use splrep to find the intial guess for the number of knots and their
# positions
t, c, k = splrep(wm, sm, k=3, w=1 / um, s=len(wm))
l1, l2 = len(t), len(c)
res = least_squares(func, x0=[*t, *c])
t, c = res.x[:l1], res.x[l1:]
coef = splev(w, (t, c, k))

sf = UnivariateSpline(wm, sm, w=1 / np.sqrt(sm))(w)
yf = UnivariateSpline(wm, ym, w=1 / np.sqrt(ym))(w)
coef = -yf + sf
# sf = UnivariateSpline(wm, sm, w=1 / np.sqrt(sm))(w)
# yf = UnivariateSpline(wm, ym, w=1 / np.sqrt(ym))(w)
# coef = -yf + sf

return coef

Expand Down Expand Up @@ -864,7 +880,7 @@ def match_rv_continuum(sme, segments, x_syn, y_syn):
vrad[s] = determine_radial_velocity(
sme, s, cscale[s], [x_syn[s] for s in s], [y_syn[s] for s in s]
)
elif sme.cscale_type in ["smooth"]:
elif sme.cscale_type in ["spline"]:
for s in segments:
# We only use the continuum mask for the continuum fit,
# we need the lines for the radial velocity
Expand All @@ -891,7 +907,7 @@ def match_rv_continuum(sme, segments, x_syn, y_syn):
for seg in segments:
mask &= select != seg
vrad[mask] = sme.vrad[mask]
if sme.cscale_type == "smooth":
if sme.cscale_type in ["spline"]:
for i in range(len(mask)):
if mask[i] and len(cscale[i]) == len(sme.cscale[i]):
cscale[i] = sme.cscale[i]
Expand Down
14 changes: 9 additions & 5 deletions src/pysme/sme.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class SME_Structure(Parameters):
* "linear": First order polynomial, i.e. approximate continuum by a straight line
* "quadratic": Second order polynomial, i.e. approximate continuum by a quadratic polynomial
"""),
("cscale_type", "match+mask", lowercase(oneof("mcmc", "mask", "match", "match+mask", "smooth")), this,
("cscale_type", "match+mask", lowercase(oneof("mcmc", "mask", "match", "match+mask", "spline")), this,
"""str: Flag that determines the algorithm to determine the continuum
This is used in combination with cscale_flag, which determines the degree of the fit, if any.
Expand Down Expand Up @@ -409,7 +409,7 @@ def _cscale(self):
The x coordinates of each polynomial are chosen so that x = 0, at the first wavelength point,
i.e. x is shifted by wave[segment][0]
"""
if self.cscale_type == "smooth":
if self.cscale_type == "spline":
return self.__cscale

nseg = self.nseg if self.nseg is not None else 1
Expand Down Expand Up @@ -448,7 +448,7 @@ def _cscale(self):

@_cscale.setter
def _cscale(self, value):
if self.cscale_type == "smooth":
if self.cscale_type in ["spline"]:
if not isinstance(value, Iliffe_vector):
self.__cscale = (
Iliffe_vector(values=value) if value is not None else None
Expand Down Expand Up @@ -558,7 +558,7 @@ def mask_cont(self):
@property
def cscale_degree(self):
"""int: Polynomial degree of the continuum as determined by cscale_flag """
if self.cscale_type in ["smooth"]:
if self.cscale_type in ["spline"]:
return [np.count_nonzero(mg) for mg in self.mask_good]
else:
if self.cscale_flag == "constant":
Expand Down Expand Up @@ -623,7 +623,7 @@ def __convert_cscale__(self):
elif self.cscale_flag == "constant":
self.cscale = np.sqrt(1 / self.cscale)

def import_mask(self, other):
def import_mask(self, other, keep_bpm=False):
"""
Import the mask of another sme structure and apply it to this one
Conversion is based on the wavelength
Expand Down Expand Up @@ -660,6 +660,10 @@ def import_mask(self, other):
w = self.wave[seg] * rv_factor
cm = np.interp(w, wave, cont_mask) > 0.5
lm = np.interp(w, wave, line_mask) > 0.5
if keep_bpm:
bpm = self.mask_bad[seg]
cm[bpm] = False
lm[bpm] = False
self.mask[seg][cm] = self.mask_values["continuum"]
self.mask[seg][lm] = self.mask_values["line"]
self.mask[seg][~(cm | lm)] = self.mask_values["bad"]
Expand Down
2 changes: 1 addition & 1 deletion src/pysme/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ def synthesize_spectrum(
sme.synth[s] = smod[s]
sme.cont[s] = cmod[s]

if sme.cscale_type == "smooth":
if sme.cscale_type == "spline":
sme.cscale = cscale
sme.cscale_unc = cscale_unc
elif sme.cscale_flag not in ["fix", "none"]:
Expand Down

0 comments on commit 185583c

Please sign in to comment.