Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
TheEimer committed Dec 20, 2024
1 parent f3308f2 commit 83f9a57
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 14 deletions.
12 changes: 6 additions & 6 deletions examples/ablation_paths.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 204,
"execution_count": 209,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -87,7 +87,7 @@
"3 3 100 5.931323 10.0 5.0"
]
},
"execution_count": 204,
"execution_count": 209,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -102,7 +102,7 @@
},
{
"cell_type": "code",
"execution_count": 205,
"execution_count": 210,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -143,7 +143,7 @@
},
{
"cell_type": "code",
"execution_count": 206,
"execution_count": 211,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -223,7 +223,7 @@
},
{
"cell_type": "code",
"execution_count": 207,
"execution_count": 212,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -255,7 +255,7 @@
},
{
"cell_type": "code",
"execution_count": 208,
"execution_count": 213,
"metadata": {},
"outputs": [
{
Expand Down
8 changes: 4 additions & 4 deletions hydra_plugins/hyper_analysis/ablation_path_sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,18 @@ def __init__(
performance_key="performance",
config_key="config_id",
variation_key="env",
run_source=True
run_source=True,
) -> None:
"""Initialize the optimizer."""
assert (data_path is not None and variation is not None) or (
source_config is not None and target_config is not None
), "Either data_path and variation or source_config and target_config must be provided."
if source_config is not None and target_config is not None:
self.source_config = configspace.get_default_configuration()
for k in source_config.keys():
for k in source_config:
self.source_config[k] = source_config[k]
self.target_config = configspace.get_default_configuration()
for k in target_config.keys():
for k in target_config:
self.target_config[k] = target_config[k]
else:
df = load_data(data_path, performance_key, config_key, variation_key) # noqa: PD901
Expand All @@ -60,7 +60,7 @@ def __init__(
self.returns = []
self.recompute_diffs = False
self.configs = self.get_configs()
self.configs = self.configs
self.configs = self.configs
if run_source:
self.configs += [self.source_config]

Expand Down
10 changes: 7 additions & 3 deletions hydra_plugins/hyper_analysis/grid_sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from __future__ import annotations

import numpy as np
from ConfigSpace.hyperparameters import CategoricalHyperparameter

from hydra_plugins.hypersweeper import Info
from ConfigSpace.hyperparameters import CategoricalHyperparameter


class Grid:
Expand Down Expand Up @@ -34,7 +34,9 @@ def __init__(
for hp in configspace:
if isinstance(configspace[hp], CategoricalHyperparameter):
self.hp_values[hp] = np.linspace(0, len(configspace[hp].choices), configs_per_hp)
self.hp_values[hp] = [configspace[hp].choices[min(int(v), len(configspace[hp].choices)-1)] for v in self.hp_values[hp]]
self.hp_values[hp] = [
configspace[hp].choices[min(int(v), len(configspace[hp].choices) - 1)] for v in self.hp_values[hp]
]
else:
self.hp_values[hp] = np.linspace(configspace[hp].lower, configspace[hp].upper, configs_per_hp)
print(f"HP values in grid: {self.hp_values}")
Expand All @@ -43,7 +45,9 @@ def __init__(
def reset_indices(self, i):
"""Increment last index and pass on overflow to previous one."""
self.config_indices[list(self.config_indices.keys())[i]] += 1
if self.config_indices[list(self.config_indices.keys())[i]] >= len(self.hp_values[list(self.config_indices.keys())[i]]):
if self.config_indices[list(self.config_indices.keys())[i]] >= len(
self.hp_values[list(self.config_indices.keys())[i]]
):
self.config_indices[list(self.config_indices.keys())[i]] = 0
try:
self.reset_indices(i - 1)
Expand Down
1 change: 0 additions & 1 deletion hydra_plugins/hypersweeper/hypersweeper_sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,6 @@ def run(self, verbose=False):
t = 0
terminate = False
while t < self.max_parallel and not terminate and not trial_termination and not budget_termination:

try:
info, terminate = self.optimizer.ask()
except Exception as e: # noqa: BLE001
Expand Down

0 comments on commit 83f9a57

Please sign in to comment.