Skip to content

Commit

Permalink
add backups for the fitparameters during iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
AWehrhahn committed Jul 8, 2021
1 parent 185583c commit 0bcf6fc
Showing 1 changed file with 65 additions and 38 deletions.
103 changes: 65 additions & 38 deletions src/pysme/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,28 @@
And also determines the best fit parameters
"""

import json
import logging
import warnings
from os.path import splitext
import contextlib
import sys
import builtins

import numpy as np
from tqdm import tqdm
from scipy.constants import speed_of_light
from scipy.optimize import OptimizeWarning, least_squares, curve_fit
from scipy.optimize import OptimizeWarning, least_squares
from scipy.optimize._numdiff import approx_derivative
from scipy.special import erf
from scipy.stats import gennorm, norm
from scipy.stats import norm
from tqdm import tqdm

from . import __file_ending__, broadening
from . import __file_ending__
from .abund import Abund
from .atmosphere.atmosphere import AtmosphereError
from .atmosphere.savfile import SavFile
from .atmosphere.krzfile import KrzFile
from .config import Config
from .continuum_and_radial_velocity import match_rv_continuum
from .atmosphere.savfile import SavFile
from .large_file_storage import setup_lfs
from .iliffe_vector import Iliffe_vector
from .nlte import DirectAccessFile
from .sme_synth import SME_DLL
from .uncertainties import uncertainties
from .util import print_to_log
from .synthesize import Synthesizer
from .nlte import DirectAccessFile
from .util import print_to_log

logger = logging.getLogger(__name__)

Expand All @@ -40,7 +33,7 @@


class SME_Solver:
def __init__(self, filename=None):
def __init__(self, filename=None, restore=False):
self.dll = SME_DLL()
self.config, self.lfs_atmo, self.lfs_nlte = setup_lfs()
self.synthesizer = Synthesizer(
Expand All @@ -56,16 +49,46 @@ def __init__(self, filename=None):
self.parameter_names = []
self.update_linelist = False
self._latest_residual = None
self.restore = restore

# For displaying the progressbars
self.fig = None
self.progressbar = None
self.progressbar_jacobian = None

@property
def nparam(self):
return len(self.parameter_names)

def restore_func(self, sme):
fname = self.filename.rsplit(".", 1)[0]
fname = f"{fname}_iter.json"
try:
with open(fname) as f:
data = json.load(f)
# The keys are string, but we want the max in int, so we need to convert back and forth
iteration = str(max([int(i) for i in data.keys()]))
for fp in self.parameter_names:
sme[fp] = data[iteration].get(fp, sme[fp])
logger.warning(f"Restoring existing backup data from {fname}")
except:
pass
return sme

def backup(self, sme):
fname = self.filename.rsplit(".", 1)[0]
fname = f"{fname}_iter.json"
try:
with open(fname) as f:
data = json.load(f)
except:
data = {}
data[self.iteration] = {fp: sme[fp] for fp in self.parameter_names}
try:
with open(fname, "w") as f:
json.dump(data, f)
except:
pass

def __residuals(
self, param, sme, spec, uncs, mask, segments="all", isJacobian=False, **_
):
Expand Down Expand Up @@ -104,7 +127,7 @@ def __residuals(
residuals of the synthetic spectrum
"""
update = not isJacobian
save = not isJacobian
save = not isJacobian and self.filename is not None
reuse_wavelength_grid = isJacobian
radial_velocity_mode = "robust" if not isJacobian else "fast"

Expand All @@ -128,15 +151,6 @@ def __residuals(
logger.debug(ae)
return np.inf

# Also save intermediary results, because we can
if save and self.filename is not None:
if self.filename.endswith(__file_ending__):
fname = self.filename[:-4]
else:
fname = self.filename
fname = f"{fname}_tmp{__file_ending__}"
sme.save(fname)

segments = Synthesizer.check_segments(sme, segments)

# Get the correct results for the comparison
Expand Down Expand Up @@ -168,15 +182,11 @@ def __residuals(
self._latest_residual = resid
self.iteration += 1
logger.debug("%s", {n: v for n, v in zip(self.parameter_names, param)})
# Plot
# if fig is not None:
# wave = sme2.wave
# try:
# fig.add(wave, synth, f"Iteration {self.iteration}")
# except AttributeError:
# warnings.warn(f"Figure {repr(fig)} doesn't have a 'add' method")
# except Exception as e:
# warnings.warn(f"Error during Plotting: {e.message}")
# Store progress (async)

# Also save intermediary results, because we can
if save:
self.backup(sme)

return resid

Expand Down Expand Up @@ -561,6 +571,18 @@ def solve(self, sme, param_names=None, segments="all", bounds=None):
assert "wave" in sme, "SME Structure has no wavelength"
assert "spec" in sme, "SME Structure has no observation"

if self.restore and self.filename is not None:
fname = self.filename.rsplit(".", 1)[0]
fname = f"{fname}_iter.json"
try:
with open(fname) as f:
data = json.load(f)
for fp in param_names:
sme[fp] = data[fp]
logger.warning(f"Restoring existing backup data from {fname}")
except:
pass

if "uncs" not in sme:
sme.uncs = np.ones(sme.spec.size)
logger.warning("SME Structure has no uncertainties, using all ones instead")
Expand Down Expand Up @@ -596,6 +618,9 @@ def solve(self, sme, param_names=None, segments="all", bounds=None):
"Initial values are incompatible with the bounds, clipping initial values"
)
p0 = np.clip(p0, bounds[0], bounds[1])
# Restore backup
if self.restore:
sme = self.restore_func(sme)

# Get constant data from sme structure
sme.mask[segments][sme.uncs[segments] == 0] = sme.mask_values["bad"]
Expand Down Expand Up @@ -674,6 +699,8 @@ def solve(self, sme, param_names=None, segments="all", bounds=None):
return sme


def solve(sme, param_names=None, segments="all", filename=None, **kwargs):
solver = SME_Solver(filename=filename)
def solve(
sme, param_names=None, segments="all", filename=None, restore=False, **kwargs
):
solver = SME_Solver(filename=filename, restore=restore)
return solver.solve(sme, param_names, segments, **kwargs)

0 comments on commit 0bcf6fc

Please sign in to comment.