diff --git a/src/pysme/solve.py b/src/pysme/solve.py index 75c8951c..c0374a4c 100644 --- a/src/pysme/solve.py +++ b/src/pysme/solve.py @@ -3,35 +3,28 @@ And also determines the best fit parameters """ +import json import logging import warnings from os.path import splitext -import contextlib -import sys -import builtins import numpy as np -from tqdm import tqdm from scipy.constants import speed_of_light -from scipy.optimize import OptimizeWarning, least_squares, curve_fit +from scipy.optimize import OptimizeWarning, least_squares from scipy.optimize._numdiff import approx_derivative -from scipy.special import erf -from scipy.stats import gennorm, norm +from scipy.stats import norm +from tqdm import tqdm -from . import __file_ending__, broadening +from . import __file_ending__ from .abund import Abund from .atmosphere.atmosphere import AtmosphereError -from .atmosphere.savfile import SavFile from .atmosphere.krzfile import KrzFile -from .config import Config -from .continuum_and_radial_velocity import match_rv_continuum +from .atmosphere.savfile import SavFile from .large_file_storage import setup_lfs -from .iliffe_vector import Iliffe_vector +from .nlte import DirectAccessFile from .sme_synth import SME_DLL -from .uncertainties import uncertainties -from .util import print_to_log from .synthesize import Synthesizer -from .nlte import DirectAccessFile +from .util import print_to_log logger = logging.getLogger(__name__) @@ -40,7 +33,7 @@ class SME_Solver: - def __init__(self, filename=None): + def __init__(self, filename=None, restore=False): self.dll = SME_DLL() self.config, self.lfs_atmo, self.lfs_nlte = setup_lfs() self.synthesizer = Synthesizer( @@ -56,9 +49,9 @@ def __init__(self, filename=None): self.parameter_names = [] self.update_linelist = False self._latest_residual = None + self.restore = restore # For displaying the progressbars - self.fig = None self.progressbar = None self.progressbar_jacobian = None @@ -66,6 +59,36 @@ def __init__(self, filename=None): def nparam(self): return len(self.parameter_names) + def restore_func(self, sme): + fname = self.filename.rsplit(".", 1)[0] + fname = f"{fname}_iter.json" + try: + with open(fname) as f: + data = json.load(f) + # The keys are string, but we want the max in int, so we need to convert back and forth + iteration = str(max([int(i) for i in data.keys()])) + for fp in self.parameter_names: + sme[fp] = data[iteration].get(fp, sme[fp]) + logger.warning(f"Restoring existing backup data from {fname}") + except: + pass + return sme + + def backup(self, sme): + fname = self.filename.rsplit(".", 1)[0] + fname = f"{fname}_iter.json" + try: + with open(fname) as f: + data = json.load(f) + except: + data = {} + data[self.iteration] = {fp: sme[fp] for fp in self.parameter_names} + try: + with open(fname, "w") as f: + json.dump(data, f) + except: + pass + def __residuals( self, param, sme, spec, uncs, mask, segments="all", isJacobian=False, **_ ): @@ -104,7 +127,7 @@ def __residuals( residuals of the synthetic spectrum """ update = not isJacobian - save = not isJacobian + save = not isJacobian and self.filename is not None reuse_wavelength_grid = isJacobian radial_velocity_mode = "robust" if not isJacobian else "fast" @@ -128,15 +151,6 @@ def __residuals( logger.debug(ae) return np.inf - # Also save intermediary results, because we can - if save and self.filename is not None: - if self.filename.endswith(__file_ending__): - fname = self.filename[:-4] - else: - fname = self.filename - fname = f"{fname}_tmp{__file_ending__}" - sme.save(fname) - segments = Synthesizer.check_segments(sme, segments) # Get the correct results for the comparison @@ -168,15 +182,11 @@ def __residuals( self._latest_residual = resid self.iteration += 1 logger.debug("%s", {n: v for n, v in zip(self.parameter_names, param)}) - # Plot - # if fig is not None: - # wave = sme2.wave - # try: - # fig.add(wave, synth, f"Iteration {self.iteration}") - # except AttributeError: - # warnings.warn(f"Figure {repr(fig)} doesn't have a 'add' method") - # except Exception as e: - # warnings.warn(f"Error during Plotting: {e.message}") + # Store progress (async) + + # Also save intermediary results, because we can + if save: + self.backup(sme) return resid @@ -561,6 +571,18 @@ def solve(self, sme, param_names=None, segments="all", bounds=None): assert "wave" in sme, "SME Structure has no wavelength" assert "spec" in sme, "SME Structure has no observation" + if self.restore and self.filename is not None: + fname = self.filename.rsplit(".", 1)[0] + fname = f"{fname}_iter.json" + try: + with open(fname) as f: + data = json.load(f) + for fp in param_names: + sme[fp] = data[fp] + logger.warning(f"Restoring existing backup data from {fname}") + except: + pass + if "uncs" not in sme: sme.uncs = np.ones(sme.spec.size) logger.warning("SME Structure has no uncertainties, using all ones instead") @@ -596,6 +618,9 @@ def solve(self, sme, param_names=None, segments="all", bounds=None): "Initial values are incompatible with the bounds, clipping initial values" ) p0 = np.clip(p0, bounds[0], bounds[1]) + # Restore backup + if self.restore: + sme = self.restore_func(sme) # Get constant data from sme structure sme.mask[segments][sme.uncs[segments] == 0] = sme.mask_values["bad"] @@ -674,6 +699,8 @@ def solve(self, sme, param_names=None, segments="all", bounds=None): return sme -def solve(sme, param_names=None, segments="all", filename=None, **kwargs): - solver = SME_Solver(filename=filename) +def solve( + sme, param_names=None, segments="all", filename=None, restore=False, **kwargs +): + solver = SME_Solver(filename=filename, restore=restore) return solver.solve(sme, param_names, segments, **kwargs)