Skip to content

Commit

Permalink
Update search_cv.py
Browse files Browse the repository at this point in the history
en results_ property, was added keys "best_model" and "models"
  • Loading branch information
AlexPascau authored May 18, 2024
1 parent a9095e8 commit 8b90162
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tensorcross/model_selection/search_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ def __init__(
self.results_: ResultsDict = {
"best_score": -np.inf,
"best_params": {},
"best_model" : [],
"val_scores": [],
"params": [],
"models" : []
}

def _run_search( # noqa: PLR0912
Expand Down Expand Up @@ -136,6 +138,7 @@ def _run_search( # noqa: PLR0912

self.results_["val_scores"].append(val_scores)
self.results_["params"].append(grid_combination)
self.results_["models"].append(model)

logger.setLevel(tf_log_level) # Issue 30

Expand All @@ -146,6 +149,7 @@ def _run_search( # noqa: PLR0912
best_run_idx = np.argmin(mean_val_scores)
self.results_["best_score"] = self.results_["val_scores"][best_run_idx]
self.results_["best_params"] = self.results_["params"][best_run_idx]
self.results_["best_model"] = self.results_["models"][best_run_idx]

def summary(self) -> str:
"""Prints the summary of the search to the console.
Expand Down

0 comments on commit 8b90162

Please sign in to comment.