-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c251054
commit 974609b
Showing
12 changed files
with
2,204 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import torch | ||
import numpy as np | ||
import torch.nn as nn | ||
from torch.autograd import Variable | ||
|
||
|
||
def _concat(xs): | ||
return torch.cat([x.view(-1) for x in xs]) | ||
|
||
|
||
class Architect(object): | ||
|
||
def __init__(self, model, args): | ||
self.network_momentum = args.momentum | ||
self.network_weight_decay = args.weight_decay | ||
self.model = model | ||
self.optimizer = torch.optim.Adam(self.model.arch_parameters(), | ||
lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay) | ||
|
||
def _compute_unrolled_model(self, input, target, eta, network_optimizer): | ||
loss = self.model._loss(input, target) | ||
theta = _concat(self.model.parameters()).data | ||
try: | ||
moment = _concat(network_optimizer.state[v]['momentum_buffer'] for v in self.model.parameters()).mul_(self.network_momentum) | ||
except: | ||
moment = torch.zeros_like(theta) | ||
dtheta = _concat(torch.autograd.grad(loss, self.model.parameters())).data + self.network_weight_decay*theta | ||
unrolled_model = self._construct_model_from_theta(theta.sub(eta, moment+dtheta)) | ||
return unrolled_model | ||
|
||
def step(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer, unrolled): | ||
self.optimizer.zero_grad() | ||
if unrolled: | ||
self._backward_step_unrolled(input_train, target_train, input_valid, target_valid, eta, network_optimizer) | ||
else: | ||
self._backward_step(input_valid, target_valid) | ||
self.optimizer.step() | ||
|
||
def _backward_step(self, input_valid, target_valid): | ||
loss = self.model._loss(input_valid, target_valid) | ||
loss.backward() | ||
|
||
def _backward_step_unrolled(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer): | ||
unrolled_model = self._compute_unrolled_model(input_train, target_train, eta, network_optimizer) | ||
unrolled_loss = unrolled_model._loss(input_valid, target_valid) | ||
|
||
unrolled_loss.backward() | ||
dalpha = [v.grad for v in unrolled_model.arch_parameters()] | ||
vector = [v.grad.data for v in unrolled_model.parameters()] | ||
implicit_grads = self._hessian_vector_product(vector, input_train, target_train) | ||
|
||
for g, ig in zip(dalpha, implicit_grads): | ||
g.data.sub_(eta, ig.data) | ||
|
||
for v, g in zip(self.model.arch_parameters(), dalpha): | ||
if v.grad is None: | ||
v.grad = Variable(g.data) | ||
else: | ||
v.grad.data.copy_(g.data) | ||
|
||
def _construct_model_from_theta(self, theta): | ||
model_new = self.model.new() | ||
model_dict = self.model.state_dict() | ||
|
||
params, offset = {}, 0 | ||
for k, v in self.model.named_parameters(): | ||
v_length = np.prod(v.size()) | ||
params[k] = theta[offset: offset+v_length].view(v.size()) | ||
offset += v_length | ||
|
||
assert offset == len(theta) | ||
model_dict.update(params) | ||
model_new.load_state_dict(model_dict) | ||
return model_new.cuda() | ||
|
||
def _hessian_vector_product(self, vector, input, target, r=1e-2): | ||
R = r / _concat(vector).norm() | ||
for p, v in zip(self.model.parameters(), vector): | ||
p.data.add_(R, v) | ||
loss = self.model._loss(input, target) | ||
grads_p = torch.autograd.grad(loss, self.model.arch_parameters()) | ||
|
||
for p, v in zip(self.model.parameters(), vector): | ||
p.data.sub_(2*R, v) | ||
loss = self.model._loss(input, target) | ||
grads_n = torch.autograd.grad(loss, self.model.arch_parameters()) | ||
|
||
for p, v in zip(self.model.parameters(), vector): | ||
p.data.add_(R, v) | ||
|
||
return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
from collections import namedtuple | ||
|
||
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') | ||
|
||
PRIMITIVES = [ | ||
'none', | ||
'max_pool_3x3', | ||
'avg_pool_3x3', | ||
'skip_connect', | ||
# "MB_3x3_3", | ||
# "MB_3x3_3_HS", | ||
# "MB_3x3_6", | ||
# "MB_3x3_6_HS", | ||
# "MB_5x5_3", | ||
# "MB_5x5_3_HS", | ||
# "MB_5x5_6", | ||
# "MB_5x5_6_HS", | ||
# "MB_7x7_3", | ||
# "MB_7x7_3_HS", | ||
# "MB_7x7_6", | ||
# "MB_7x7_6_HS", | ||
'sep_conv_3x3', | ||
'sep_conv_5x5', | ||
'dil_conv_3x3', | ||
'dil_conv_5x5' | ||
] | ||
|
||
NASNet = Genotype( | ||
normal = [ | ||
('sep_conv_5x5', 1), | ||
('sep_conv_3x3', 0), | ||
('sep_conv_5x5', 0), | ||
('sep_conv_3x3', 0), | ||
('avg_pool_3x3', 1), | ||
('skip_connect', 0), | ||
('avg_pool_3x3', 0), | ||
('avg_pool_3x3', 0), | ||
('sep_conv_3x3', 1), | ||
('skip_connect', 1), | ||
], | ||
normal_concat = [2, 3, 4, 5, 6], | ||
reduce = [ | ||
('sep_conv_5x5', 1), | ||
('sep_conv_7x7', 0), | ||
('max_pool_3x3', 1), | ||
('sep_conv_7x7', 0), | ||
('avg_pool_3x3', 1), | ||
('sep_conv_5x5', 0), | ||
('skip_connect', 3), | ||
('avg_pool_3x3', 2), | ||
('sep_conv_3x3', 2), | ||
('max_pool_3x3', 1), | ||
], | ||
reduce_concat = [4, 5, 6], | ||
) | ||
|
||
AmoebaNet = Genotype( | ||
normal = [ | ||
('avg_pool_3x3', 0), | ||
('max_pool_3x3', 1), | ||
('sep_conv_3x3', 0), | ||
('sep_conv_5x5', 2), | ||
('sep_conv_3x3', 0), | ||
('avg_pool_3x3', 3), | ||
('sep_conv_3x3', 1), | ||
('skip_connect', 1), | ||
('skip_connect', 0), | ||
('avg_pool_3x3', 1), | ||
], | ||
normal_concat = [4, 5, 6], | ||
reduce = [ | ||
('avg_pool_3x3', 0), | ||
('sep_conv_3x3', 1), | ||
('max_pool_3x3', 0), | ||
('sep_conv_7x7', 2), | ||
('sep_conv_7x7', 0), | ||
('avg_pool_3x3', 1), | ||
('max_pool_3x3', 0), | ||
('max_pool_3x3', 1), | ||
('conv_7x1_1x7', 0), | ||
('sep_conv_3x3', 5), | ||
], | ||
reduce_concat = [3, 4, 6] | ||
) | ||
|
||
DARTS_V1 = Genotype(normal=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 0), ('sep_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 0), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('avg_pool_3x3', 0)], reduce_concat=[2, 3, 4, 5]) | ||
DARTS_V2 = Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)], reduce_concat=[2, 3, 4, 5]) | ||
|
||
|
||
PC_DARTS_cifar = Genotype(normal=[('sep_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 0), ('dil_conv_3x3', 1), ('sep_conv_5x5', 0), ('sep_conv_3x3', 1), ('avg_pool_3x3', 0), ('dil_conv_3x3', 1)], normal_concat=range(2, 6), reduce=[('sep_conv_5x5', 1), ('max_pool_3x3', 0), ('sep_conv_5x5', 1), ('sep_conv_5x5', 2), ('sep_conv_3x3', 0), ('sep_conv_3x3', 3), ('sep_conv_3x3', 1), ('sep_conv_3x3', 2)], reduce_concat=range(2, 6)) | ||
PC_DARTS_image = Genotype(normal=[('skip_connect', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 0), ('skip_connect', 1), ('sep_conv_3x3', 1), ('sep_conv_3x3', 3), ('sep_conv_3x3', 1), ('dil_conv_5x5', 4)], normal_concat=range(2, 6), reduce=[('sep_conv_3x3', 0), ('skip_connect', 1), ('dil_conv_5x5', 2), ('max_pool_3x3', 1), ('sep_conv_3x3', 2), ('sep_conv_3x3', 1), ('sep_conv_5x5', 0), ('sep_conv_3x3', 3)], reduce_concat=range(2, 6)) | ||
|
||
WAPC_DARTS_cifar1 = [1,1, 2, 7, 7, 8, 5, 8, 6, 1, 8, 1, 8, 1, 7, 8, 8, 8, 7, 8, 8, 8, 1, 1, 8, 8, 8, 8, 8, 8, 8, 1, 1, 1, 1, 8, 8, 8, 8, 1, 1, 1, 1, 1, 4] | ||
WAPC_DARTS_cifar2 =[9, 7, 7, 9, 5, 1, 5, 9, 9, 5, 9, 5, 9, 9, 9, 9, 9, 9, 7, 9, 9, 9, 9, 7, 9, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] | ||
WAPC_DARTS_cifar3 =[9, 9, 7, 9, 5, 8, 9, 9, 6, 5, 9, 7, 7, 7, 7, 9, 7, 9, 9, 9, 9, 9, 9, 9, 9, 1, 9, 1, 9, 1, 9, 9, 9, 9, 9, 1, 1, 1, 1, 9, 1, 1, 1, 1, 7] | ||
WAPC_DARTS_cifar4 =[9, 1, 8, 9, 9, 9, 9, 9, 7, 9, 9, 9, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]#search-EXP-20190924-112226 95.5 | ||
WAPC_DARTS_cifar6 =[2, 7, 9, 8, 8, 9, 7, 9, 9, 9, 9, 9, 8, 9, 9, 9, 9, 9, 9, 9, 8, 9, 8, 9, 9]#search-EXP-20190926-090819 95.47 | ||
WAPC_DARTS_cifar7 =[1, 7, 8, 5, 9, 9, 7, 9, 9, 9, 1, 8, 1, 9, 5, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]#search-EXP-20190925-232116 | ||
WAPC_DARTS_cifar5 =[9, 7, 7, 9, 5, 1, 5, 8, 8, 9, 7, 6, 7, 9, 9, 1, 9, 1, 1, 9, 9, 9, 9, 9, 9, 9, 8, 1, 9, 9, 9, 9, 9, 9, 9] | ||
WAPC_DARTS_cifar8 =[9, 5, 7, 5, 5, 9, 9, 9, 9, 9, 9, 9, 8, 9, 9, 9, 9, 9, 7, 9, 9, 9, 9, 9, 9]#search-EXP-20190927-004035 95.7 | ||
WAPC_DARTS_cifar9 =[1, 4, 5, 6, 6, 6, 6, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 6, 6, 6, 6, 6, 6]#t=6 search-EXP-20190928-171533 95.5 | ||
|
||
WAPC_nodes=[[0, 1], [0, 2], [0, 2], [0, 4], [0, 1], [0, 1], [0, 2], [0, 4], [0, 1], [0, 2], [0, 1], [0, 1], [0, 1], [0, 1], [1, 2], [0, 1], [0, 1], [0, 1], [0, 3], [1, 3], [0, 1], [1, 2], [0, 1], [0, 4], [0, 1], [1, 2], [0, 1], [2, 4], [0, 1], [0, 1], [1, 2], [3, 4]] #96,48 channels 96.5 | ||
WAPC_DARTS_cifar_yuanshi=[4, 1, 5, 4, 5, 4, 5, 1, 4, 4, 4, 4, 4, 4, 3, 5, 3, 1, 5, 1, 5, 4, 1, 5, 4, 4, 6, 5, 4, 5, 5, 5, 1, 1, 1, 7, 7, 1, 5, 4, 7, 1, 5, 1, 4, 4, 4, 5, 5, 7, 5, 7, 4, 5, 5, 5, 1, 5, 5, 5, 4, 5, 4, 1, 1, 7, 7, 1, 1, 1, 5, 5, 5, 4, 5, 5, 5, 7, 7, 4, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 5, 5, 5, 5, 5, 5, 5, 7, 5, 7, 5, 5, 7, 7, 1] | ||
WAPC_nodes2=[[0, 1], [0, 2], [0, 2], [1, 3], [0, 1], [0, 1], [1, 2], [1, 4], [0, 1], [0, 1], [0, 1], [0, 3], [0, 1], [0, 1], [0, 3], [0, 1], [0, 1], [0, 1], [0, 1], [0, 3], [0, 1], [0, 1], [0, 1], [0, 2], [0, 1], [0, 2], [0, 1], [1, 2], [0, 1], [0, 1], [0, 2], [0, 1], [0, 1], [0, 1], [0, 1], [3, 4]]# 96.54 channels36 96.92 600epochs | ||
WAPC_DARTS_cifar_yuanshi2=[1, 3, 1, 1, 4, 4, 1, 4, 5, 5, 5, 4, 6, 5, 4, 5, 4, 4, 4, 4, 4, 4, 4, 1, 4, 4, 4, 4, 1, 5, 1, 4, 1, 4, 1, 2, 1, 1, 5, 1, 4, 1, 4, 1, 1, 4, 4, 1, 6, 7, 7, 5, 5, 7, 7, 1, 1, 5, 5, 4, 7, 5, 5, 1, 5, 5, 5, 7, 7, 1, 7, 7, 5, 7, 7, 7, 7, 7, 7, 7, 4, 7, 5, 1, 7, 5, 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 1, 7, 7, 7, 7, 7, 7, 7, 1, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1, 1, 7, 1, 1, 1, 5, 7, 1, 1, 1] | ||
WAPC_nodes3=[[0, 1], [0, 1], [0, 2], [0, 3], [0, 1], [0, 1], [1, 2], [0, 1], [0, 1], [0, 1], [0, 1], [0, 4], [0, 1], [0, 2], [0, 1], [1, 4], [0, 1], [0, 1], [1, 2], [1, 3], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 2], [0, 4], [0, 1], [0, 1], [1, 3], [0, 1], [0, 1], [0, 1], [0, 1], [0, 4], [0, 1], [0, 1], [1, 2], [3, 4]]#96.59 | ||
WAPC_DARTS_cifar_yuanshi3=[3, 3, 1, 4, 1, 3, 5, 7, 4, 1, 1, 1, 1, 1, 1, 5, 3, 4, 5, 5, 5, 1, 4, 1, 4, 4, 1, 4, 4, 5, 5, 4, 5, 4, 5, 4, 4, 4, 5, 5, 1, 4, 1, 1, 5, 4, 1, 1, 5, 5, 6, 5, 5, 5, 1, 1, 5, 4, 1, 1, 4, 4, 4, 4, 5, 5, 4, 5, 4, 4, 4, 1, 4, 6, 7, 1, 1, 1, 1, 5, 5, 7, 1, 7, 5, 1, 4, 5, 7, 5, 5, 7, 7, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 1, 7, 7, 5, 7, 1, 5, 5, 1, 1, 5, 5, 7, 1, 1] | ||
WAPC_nodes4=[[0, 1], [0, 1], [0, 2], [0, 1], [0, 1], [0, 1], [0, 3], [0, 2], [0, 1], [1, 2], [0, 3], [1, 2], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 2], [0, 3], [0, 1], [0, 1], [0, 1], [1, 4], [0, 1], [0, 1], [0, 3], [0, 4], [0, 1], [0, 1], [1, 2], [0, 3], [0, 1], [0, 1], [0, 1], [1, 4], [0, 1], [0, 1], [2, 3], [3, 4]] | ||
WAPC_DARTS_cifar_yuanshi4=[4, 1, 5, 4, 1, 1, 6, 3, 5, 1, 1, 5, 5, 1, 3, 5, 4, 4, 1, 3, 4, 4, 4, 4, 4, 5, 4, 5, 1, 4, 3, 4, 4, 4, 2, 4, 1, 5, 4, 4, 5, 5, 5, 1, 5, 5, 4, 1, 5, 7, 4, 4, 1, 5, 1, 1, 5, 4, 4, 4, 4, 1, 6, 1, 1, 1, 1, 5, 1, 4, 1, 1, 4, 1, 5, 1, 1, 7, 7, 4, 5, 4, 5, 1, 5, 5, 5, 1, 7, 5, 5, 7, 7, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 7, 1, 7, 7, 7, 7, 7, 5, 7, 1, 7, 7, 7, 7, 7, 7, 1, 7, 1, 5, 5, 1, 5, 5, 7, 7, 5, 7, 1] | ||
#epoch 60 batch 160 | ||
WAPC_nodes5=[[0, 1], [0, 1], [0, 2], [3, 4], [0, 1], [0, 1], [0, 1], [0, 2], [0, 1], [0, 2], [0, 3], [2, 4], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [1, 2], [1, 2], [0, 1], [0, 1], [0, 2], [0, 1], [0, 1], [0, 2], [0, 3], [0, 4], [0, 1], [0, 1], [1, 2], [0, 1], [0, 1], [1, 2], [0, 1], [1, 4], [0, 1], [1, 2], [2, 3], [2, 4]] | ||
WAPC_DARTS_cifar_yuanshi5=[4, 1, 4, 4, 4, 3, 1, 3, 1, 5, 4, 4, 1, 1, 5, 5, 3, 1, 5, 3, 4, 5, 5, 5, 4, 4, 4, 7, 1, 1, 4, 4, 4, 4, 4, 4, 4, 2, 5, 5, 5, 4, 5, 1, 5, 5, 1, 1, 5, 7, 5, 3, 5, 7, 5, 5, 5, 4, 5, 5, 4, 7, 6, 6, 7, 4, 4, 4, 5, 1, 1, 1, 1, 1, 1, 4, 4, 1, 1, 1, 2, 1, 7, 5, 5, 5, 5, 5, 7, 2, 5, 7, 7, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 7, 5, 7, 7, 1, 7, 7, 7, 1, 7, 5, 1, 7, 7, 7, 7, 7, 1, 5, 5, 7, 1, 1, 5, 7, 5, 1, 5, 7, 1, 1, 1] | ||
#epoch 60 batch 160 | ||
WAPC_nodes6=[[0, 1], [0, 2], [0, 2], [1, 4], [0, 1], [0, 1], [0, 2], [0, 1], [0, 1], [1, 2], [0, 3], [0, 2], [0, 1], [0, 1], [0, 1], [1, 4], [0, 1], [0, 1], [0, 2], [1, 3], [0, 1], [0, 1], [1, 2], [1, 2], [0, 1], [0, 1], [0, 2], [0, 1], [0, 1], [1, 2], [1, 3], [1, 3], [0, 1], [1, 2], [2, 3], [3, 4]]#96.769 600epochs | ||
WAPC_DARTS_cifar_yuanshi6=[1, 2, 3, 4, 5, 7, 1, 4, 1, 4, 4, 7, 1, 5, 1, 5, 4, 4, 4, 4, 4, 5, 5, 5, 4, 4, 4, 5, 4, 1, 3, 1, 1, 1, 1, 5, 1, 7, 1, 1, 7, 1, 5, 2, 1, 4, 1, 5, 5, 1, 4, 4, 5, 1, 1, 1, 1, 4, 5, 5, 7, 4, 4, 1, 4, 4, 4, 5, 4, 6, 1, 7, 4, 1, 4, 5, 1, 1, 1, 4, 4, 5, 4, 1, 5, 5, 5, 5, 7, 5, 5, 7, 7, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 1, 7, 5, 5, 1, 5, 5, 1, 7, 5, 7, 7, 1, 1] | ||
#cifar100 layers10 81.519 | ||
WAPC_nodes7=[[0, 1], [0, 2], [0, 2], [0, 4], [0, 1], [0, 1], [0, 2], [0, 4], [0, 1], [1, 2], [0, 2], [2, 4], [0, 1], [0, 2], [0, 1], [0, 1], [0, 1], [0, 1], [1, 2], [0, 1], [0, 1], [0, 1], [1, 2], [0, 4], [0, 1], [0, 1], [0, 3], [0, 4], [0, 1], [0, 1], [1, 2], [1, 4], [0, 1], [0, 1], [0, 2], [1, 4], [0, 1], [0, 1], [0, 1], [2, 3]] | ||
WAPC_DARTS_cifar_yuanshi7=[1, 1, 5, 4, 5, 1, 4, 3, 4, 4, 3, 5, 5, 4, 3, 4, 3, 5, 1, 6, 5, 3, 3, 1, 5, 4, 1, 1, 4, 4, 4, 1, 4, 4, 5, 4, 4, 5, 4, 5, 1, 1, 5, 7, 4, 1, 7, 1, 1, 7, 7, 4, 5, 1, 1, 1, 5, 4, 1, 1, 4, 5, 4, 7, 7, 5, 1, 1, 7, 6, 2, 5, 1, 4, 4, 4, 1, 1, 4, 1, 4, 4, 1, 1, 5, 5, 5, 7, 7, 1, 5, 1, 1, 5, 5, 7, 7, 1, 1, 7, 7, 7, 7, 5, 7, 5, 7, 7, 7, 7, 7, 1, 1, 7, 7, 7, 7, 1, 7, 7, 1, 7, 7, 7, 7, 7, 5, 5, 7, 7, 7, 5, 7, 5, 7, 5, 5, 1, 1, 1] | ||
#cifar100 layers10 wd 3e-4 | ||
WAPC_nodes8=[[0, 1], [0, 1], [0, 2], [0, 3], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 3], [1, 3], [0, 1], [0, 2], [1, 3], [0, 1], [0, 1], [0, 1], [0, 2], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 2], [2, 3], [0, 4], [0, 1], [0, 1], [1, 3], [0, 1], [0, 1], [0, 1], [0, 2], [1, 4], [0, 1], [0, 1], [0, 2], [2, 3]] | ||
WAPC_DARTS_cifar_yuanshi8=[3, 3, 5, 4, 1, 3, 1, 3, 4, 3, 3, 1, 5, 6, 3, 4, 4, 4, 5, 5, 4, 5, 4, 5, 5, 4, 5, 5, 4, 3, 4, 4, 1, 1, 4, 2, 1, 4, 4, 4, 5, 5, 1, 1, 2, 5, 1, 5, 5, 7, 4, 5, 5, 1, 7, 1, 7, 4, 4, 1, 1, 1, 4, 1, 1, 4, 5, 5, 5, 7, 1, 7, 4, 5, 4, 4, 1, 1, 1, 4, 4, 4, 4, 1, 5, 1, 5, 5, 1, 5, 5, 1, 7, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 5, 7, 7, 5, 7, 7, 1, 7, 7, 4, 7, 7, 7, 7, 7, 5, 7, 7, 7, 1, 5, 5, 7, 5, 5, 5, 5, 5, 5, 5, 5, 1, 1, 5] | ||
|
||
|
||
PCDARTS = WAPC_DARTS_cifar_yuanshi8 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 | ||
import tensorflow as tf | ||
import numpy as np | ||
import scipy.misc | ||
from datetime import datetime | ||
|
||
try: | ||
from StringIO import StringIO # Python 2.7 | ||
except ImportError: | ||
from io import BytesIO # Python 3.x | ||
|
||
|
||
class Logger(object): | ||
|
||
def __init__(self, log_dir): | ||
"""Create a summary writer logging to log_dir.""" | ||
TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now()) | ||
log_dir=log_dir+TIMESTAMP | ||
self.writer = tf.summary.FileWriter(log_dir) | ||
|
||
def scalar_summary(self, tag, value, step): | ||
"""Log a scalar variable.""" | ||
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) | ||
self.writer.add_summary(summary, step) | ||
|
||
def image_summary(self, tag, images, step): | ||
"""Log a list of images.""" | ||
|
||
img_summaries = [] | ||
for i, img in enumerate(images): | ||
# Write the image to a string | ||
try: | ||
s = StringIO() | ||
except: | ||
s = BytesIO() | ||
scipy.misc.toimage(img).save(s, format="png") | ||
|
||
# Create an Image object | ||
img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), | ||
height=img.shape[0], | ||
width=img.shape[1]) | ||
# Create a Summary value | ||
img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) | ||
|
||
# Create and write Summary | ||
summary = tf.Summary(value=img_summaries) | ||
self.writer.add_summary(summary, step) | ||
|
||
def histo_summary(self, tag, values, step, bins=1000): | ||
"""Log a histogram of the tensor of values.""" | ||
|
||
# Create a histogram using numpy | ||
counts, bin_edges = np.histogram(values, bins=bins) | ||
|
||
# Fill the fields of the histogram proto | ||
hist = tf.HistogramProto() | ||
hist.min = float(np.min(values)) | ||
hist.max = float(np.max(values)) | ||
hist.num = int(np.prod(values.shape)) | ||
hist.sum = float(np.sum(values)) | ||
hist.sum_squares = float(np.sum(values ** 2)) | ||
|
||
# Drop the start of the first bin | ||
bin_edges = bin_edges[1:] | ||
|
||
# Add bin edges and counts | ||
for edge in bin_edges: | ||
hist.bucket_limit.append(edge) | ||
for c in counts: | ||
hist.bucket.append(c) | ||
|
||
# Create and write Summary | ||
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) | ||
self.writer.add_summary(summary, step) | ||
self.writer.flush() |
Oops, something went wrong.