Skip to content

Commit

Permalink
Moved warnings from gpytorch and botorch to the logfile
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikOrm committed Aug 29, 2023
1 parent 3da5453 commit 15c12f9
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 22 deletions.
43 changes: 31 additions & 12 deletions hypermapper/bo/models/gpbotorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ def fit(
Returns:
- Hyperparameters of the model or None if the model is not fitted.
"""

warnings.filterwarnings(
"ignore", category=gpytorch.utils.warnings.GPInputWarning
)

mll = ExactMarginalLogLikelihood(self.likelihood, self)
if settings["multistart_hyperparameter_optimization"]:
worst_log_likelihood = np.inf
Expand Down Expand Up @@ -107,10 +112,11 @@ def fit(
sample_point[3],
)

warnings.filterwarnings(
"ignore", category=gpytorch.utils.warnings.GPInputWarning
)
fit_gpytorch_mll(mll)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
fit_gpytorch_mll(mll)
for warning in w:
sys.stdout.write_to_logfile(f"{str(warning.message)}\n")
self.train(), self.likelihood.train()

mll_val = mll(self(*self.train_inputs), self.train_targets)
Expand All @@ -126,8 +132,8 @@ def fit(
if mll_val < worst_log_likelihood:
worst_log_likelihood = mll_val
except Exception as e:
print(f"Warning: failed to fit in iteration {i}")
print(e)
sys.stdout.write_to_logfile(f"Warning: failed to fit in iteration {i}\n")
sys.stdout.write_to_logfile(f"{e}\n")

if best_GP is None:
sys.stdout.write_to_logfile(
Expand All @@ -148,10 +154,14 @@ def fit(
else:
mll = ExactMarginalLogLikelihood(self.likelihood, self)
try:
fit_gpytorch_mll(mll)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
fit_gpytorch_mll(mll)
for warning in w:
sys.stdout.write_to_logfile(f"{str(warning.message)}\n")
except Exception as e:
print("Warning: Failed to fit model.")
print(e)
sys.stdout.write_to_logfile("Warning: Failed to fit model.\n")
sys.stdout.write_to_logfile(f"{e}\n")
self._backup_fit(mll)

sys.stdout.write_to_logfile(
Expand Down Expand Up @@ -248,12 +258,21 @@ def fit(
- settings:
- previous_hyperparameters: hyperparameters from previous iterations
"""

warnings.filterwarnings(
"ignore", category=gpytorch.utils.warnings.GPInputWarning
)

mll = ExactMarginalLogLikelihood(self.likelihood, self)
try:
fit_gpytorch_mll(mll)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
fit_gpytorch_mll(mll)
for warning in w:
sys.stdout.write_to_logfile(f"{str(warning.message)}\n")
except Exception as e:
print("Warning: Failed to fit model.")
print(e)
sys.stdout.write_to_logfile("Warning: Failed to fit model.\n")
sys.stdout.write_to_logfile(f"{e}\n")
self._backup_fit(mll)

sys.stdout.write_to_logfile(
Expand Down
29 changes: 19 additions & 10 deletions hypermapper/bo/models/gpgpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ def fit(
Returns:
- Hyperparameters of the model or None if the model is not fitted.
"""
warnings.filterwarnings(
"ignore", category=gpytorch.utils.warnings.GPInputWarning
)

mll = ExactMarginalLogLikelihood(self.likelihood, self)
if settings["multistart_hyperparameter_optimization"]:
worst_log_likelihood = np.inf
Expand Down Expand Up @@ -197,10 +201,11 @@ def fit(
sample_point[3],
)

warnings.filterwarnings(
"ignore", category=gpytorch.utils.warnings.GPInputWarning
)
fit_gpytorch_mll(mll)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
fit_gpytorch_mll(mll)
for warning in w:
sys.stdout.write_to_logfile(f"{str(warning.message)}\n")
self.train(), self.likelihood.train()

mll_val = mll(self(*self.train_inputs), self.train_targets)
Expand All @@ -216,11 +221,11 @@ def fit(
if mll_val < worst_log_likelihood:
worst_log_likelihood = mll_val
except Exception as e:
print(f"Warning: failed to fit in iteration {i}")
print(e)
sys.stdout.write_to_logfile(f"Warning: failed to fit in iteration {i}\n"
f"{e}\n")

if best_GP is None:
print(
sys.stdout.write_to_logfile(
f"Failed to fit the GP hyperparameters in all of the {settings['multistart_hyperparameter_optimization_iterations']} iterations."
)
sys.stdout.write_to_logfile(
Expand All @@ -241,10 +246,14 @@ def fit(
else:
mll = ExactMarginalLogLikelihood(self.likelihood, self)
try:
fit_gpytorch_mll(mll)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
fit_gpytorch_mll(mll)
for warning in w:
sys.stdout.write_to_logfile(f"{str(warning.message)}\n")
except Exception as e:
print(f"Warning: failed to fit model.")
print(e)
sys.stdout.write_to_logfile(f"Warning: failed to fit model.\n")
sys.stdout.write_to_logfile(f"{e}\n")
self._backup_fit(mll)

sys.stdout.write_to_logfile(
Expand Down

0 comments on commit 15c12f9

Please sign in to comment.