Skip to content

Commit

Permalink
Merge pull request #176 from automl/load_labeled_fix
Browse files Browse the repository at this point in the history
Fix for loading labeled architecture in optimizers
  • Loading branch information
gurizab authored Sep 21, 2023
2 parents dfa2e67 + 927efe5 commit fc40ee5
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 19 deletions.
5 changes: 4 additions & 1 deletion naslib/optimizers/discrete/bananas/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def __init__(self, config, zc_api=None):
config.search, 'zc_names') else None
self.zc_only = config.search.zc_only if hasattr(
config.search, 'zc_only') else False

self.load_labeled = config.search.load_labeled if hasattr(
config.search, 'load_labeled') else False

def adapt_search_space(self, search_space, scope=None, dataset_api=None):
assert (
Expand Down Expand Up @@ -119,7 +122,7 @@ def _sample_new_model(self):
model = torch.nn.Module()
model.arch = self.search_space.clone()
model.arch.sample_random_architecture(
dataset_api=self.dataset_api, load_labeled=self.use_zc_api)
dataset_api=self.dataset_api, load_labeled=self.load_labeled)
model.arch_hash = model.arch.get_hash()

if self.search_space.instantiate_model == True:
Expand Down
4 changes: 3 additions & 1 deletion naslib/optimizers/discrete/npenas/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def __init__(self, config, zc_api=None):
config.search, 'zc_names') else None
self.zc_only = config.search.zc_only if hasattr(
config.search, 'zc_only') else False
self.load_labeled = config.search.load_labeled if hasattr(
config.search, 'load_labeled') else False

def adapt_search_space(self, search_space, scope=None, dataset_api=None):
assert (
Expand Down Expand Up @@ -119,7 +121,7 @@ def _sample_new_model(self):
model = torch.nn.Module()
model.arch = self.search_space.clone()
model.arch.sample_random_architecture(
dataset_api=self.dataset_api, load_labeled=self.use_zc_api)
dataset_api=self.dataset_api, load_labeled=self.load_labeled)
model.arch_hash = model.arch.get_hash()

if self.search_space.instantiate_model == True:
Expand Down
77 changes: 60 additions & 17 deletions naslib/runners/zc/zc_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,66 @@ cutout_prob: 1.0
dataset: cifar10
out_dir: run
predictor: fisher
search_space: nasbench101 #nasbench201 #nasbench301
seed: 0
search_space: nasbench201 #nasbench101 #nasbench301
test_size: 200
train_size: 400
zc_ensemble: true
zc_only: true
zc_names:
- params
- flops
- jacov
- plain
- grasp
- snip
- fisher
- grad_norm
- epe_nas
- synflow
- l2_norm
optimizer: npenas
train_portion: 0.7
train_portion: 0.7
seed: 0

search:
# for bohb
seed: 0
budgets: 50000000
checkpoint_freq: 1000
fidelity: 108

# for all optimizers
epochs: 10

# for bananas and npenas, choose one predictor
# out of the 16 model-based predictors
predictor_type: var_sparse_gp

# number of initial architectures
num_init: 10

# NPENAS
k: 10
num_ensemble: 3
acq_fn_type: its
acq_fn_optimization: mutation
encoding_type: adjacency_one_hot
num_arches_to_mutate: 1
max_mutations: 1
num_candidates: 50

# jacov data loader
batch_size: 256
data_size: 25000
cutout: False
cutout_length: 16
cutout_prob: 1.0
train_portion: 0.7

# other params
debug_predictor: False
sample_size: 10
population_size: 30

# zc parameters
use_zc_api: False
zc_ensemble: true
zc_names:
- params
- flops
- jacov
- plain
- grasp
- snip
- fisher
- grad_norm
- epe_nas
- synflow
- l2_norm
zc_only: true

0 comments on commit fc40ee5

Please sign in to comment.