Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
Christian-lyc authored Apr 16, 2020
1 parent c251054 commit 974609b
Show file tree
Hide file tree
Showing 12 changed files with 2,204 additions and 0 deletions.
92 changes: 92 additions & 0 deletions architect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import torch
import numpy as np
import torch.nn as nn
from torch.autograd import Variable


def _concat(xs):
return torch.cat([x.view(-1) for x in xs])


class Architect(object):

def __init__(self, model, args):
self.network_momentum = args.momentum
self.network_weight_decay = args.weight_decay
self.model = model
self.optimizer = torch.optim.Adam(self.model.arch_parameters(),
lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)

def _compute_unrolled_model(self, input, target, eta, network_optimizer):
loss = self.model._loss(input, target)
theta = _concat(self.model.parameters()).data
try:
moment = _concat(network_optimizer.state[v]['momentum_buffer'] for v in self.model.parameters()).mul_(self.network_momentum)
except:
moment = torch.zeros_like(theta)
dtheta = _concat(torch.autograd.grad(loss, self.model.parameters())).data + self.network_weight_decay*theta
unrolled_model = self._construct_model_from_theta(theta.sub(eta, moment+dtheta))
return unrolled_model

def step(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer, unrolled):
self.optimizer.zero_grad()
if unrolled:
self._backward_step_unrolled(input_train, target_train, input_valid, target_valid, eta, network_optimizer)
else:
self._backward_step(input_valid, target_valid)
self.optimizer.step()

def _backward_step(self, input_valid, target_valid):
loss = self.model._loss(input_valid, target_valid)
loss.backward()

def _backward_step_unrolled(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer):
unrolled_model = self._compute_unrolled_model(input_train, target_train, eta, network_optimizer)
unrolled_loss = unrolled_model._loss(input_valid, target_valid)

unrolled_loss.backward()
dalpha = [v.grad for v in unrolled_model.arch_parameters()]
vector = [v.grad.data for v in unrolled_model.parameters()]
implicit_grads = self._hessian_vector_product(vector, input_train, target_train)

for g, ig in zip(dalpha, implicit_grads):
g.data.sub_(eta, ig.data)

for v, g in zip(self.model.arch_parameters(), dalpha):
if v.grad is None:
v.grad = Variable(g.data)
else:
v.grad.data.copy_(g.data)

def _construct_model_from_theta(self, theta):
model_new = self.model.new()
model_dict = self.model.state_dict()

params, offset = {}, 0
for k, v in self.model.named_parameters():
v_length = np.prod(v.size())
params[k] = theta[offset: offset+v_length].view(v.size())
offset += v_length

assert offset == len(theta)
model_dict.update(params)
model_new.load_state_dict(model_dict)
return model_new.cuda()

def _hessian_vector_product(self, vector, input, target, r=1e-2):
R = r / _concat(vector).norm()
for p, v in zip(self.model.parameters(), vector):
p.data.add_(R, v)
loss = self.model._loss(input, target)
grads_p = torch.autograd.grad(loss, self.model.arch_parameters())

for p, v in zip(self.model.parameters(), vector):
p.data.sub_(2*R, v)
loss = self.model._loss(input, target)
grads_n = torch.autograd.grad(loss, self.model.arch_parameters())

for p, v in zip(self.model.parameters(), vector):
p.data.add_(R, v)

return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)]

126 changes: 126 additions & 0 deletions genotypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

PRIMITIVES = [
'none',
'max_pool_3x3',
'avg_pool_3x3',
'skip_connect',
# "MB_3x3_3",
# "MB_3x3_3_HS",
# "MB_3x3_6",
# "MB_3x3_6_HS",
# "MB_5x5_3",
# "MB_5x5_3_HS",
# "MB_5x5_6",
# "MB_5x5_6_HS",
# "MB_7x7_3",
# "MB_7x7_3_HS",
# "MB_7x7_6",
# "MB_7x7_6_HS",
'sep_conv_3x3',
'sep_conv_5x5',
'dil_conv_3x3',
'dil_conv_5x5'
]

NASNet = Genotype(
normal = [
('sep_conv_5x5', 1),
('sep_conv_3x3', 0),
('sep_conv_5x5', 0),
('sep_conv_3x3', 0),
('avg_pool_3x3', 1),
('skip_connect', 0),
('avg_pool_3x3', 0),
('avg_pool_3x3', 0),
('sep_conv_3x3', 1),
('skip_connect', 1),
],
normal_concat = [2, 3, 4, 5, 6],
reduce = [
('sep_conv_5x5', 1),
('sep_conv_7x7', 0),
('max_pool_3x3', 1),
('sep_conv_7x7', 0),
('avg_pool_3x3', 1),
('sep_conv_5x5', 0),
('skip_connect', 3),
('avg_pool_3x3', 2),
('sep_conv_3x3', 2),
('max_pool_3x3', 1),
],
reduce_concat = [4, 5, 6],
)

AmoebaNet = Genotype(
normal = [
('avg_pool_3x3', 0),
('max_pool_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_5x5', 2),
('sep_conv_3x3', 0),
('avg_pool_3x3', 3),
('sep_conv_3x3', 1),
('skip_connect', 1),
('skip_connect', 0),
('avg_pool_3x3', 1),
],
normal_concat = [4, 5, 6],
reduce = [
('avg_pool_3x3', 0),
('sep_conv_3x3', 1),
('max_pool_3x3', 0),
('sep_conv_7x7', 2),
('sep_conv_7x7', 0),
('avg_pool_3x3', 1),
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('conv_7x1_1x7', 0),
('sep_conv_3x3', 5),
],
reduce_concat = [3, 4, 6]
)

DARTS_V1 = Genotype(normal=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 0), ('sep_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 0), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('avg_pool_3x3', 0)], reduce_concat=[2, 3, 4, 5])
DARTS_V2 = Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)], reduce_concat=[2, 3, 4, 5])


PC_DARTS_cifar = Genotype(normal=[('sep_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 0), ('dil_conv_3x3', 1), ('sep_conv_5x5', 0), ('sep_conv_3x3', 1), ('avg_pool_3x3', 0), ('dil_conv_3x3', 1)], normal_concat=range(2, 6), reduce=[('sep_conv_5x5', 1), ('max_pool_3x3', 0), ('sep_conv_5x5', 1), ('sep_conv_5x5', 2), ('sep_conv_3x3', 0), ('sep_conv_3x3', 3), ('sep_conv_3x3', 1), ('sep_conv_3x3', 2)], reduce_concat=range(2, 6))
PC_DARTS_image = Genotype(normal=[('skip_connect', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 0), ('skip_connect', 1), ('sep_conv_3x3', 1), ('sep_conv_3x3', 3), ('sep_conv_3x3', 1), ('dil_conv_5x5', 4)], normal_concat=range(2, 6), reduce=[('sep_conv_3x3', 0), ('skip_connect', 1), ('dil_conv_5x5', 2), ('max_pool_3x3', 1), ('sep_conv_3x3', 2), ('sep_conv_3x3', 1), ('sep_conv_5x5', 0), ('sep_conv_3x3', 3)], reduce_concat=range(2, 6))

WAPC_DARTS_cifar1 = [1,1, 2, 7, 7, 8, 5, 8, 6, 1, 8, 1, 8, 1, 7, 8, 8, 8, 7, 8, 8, 8, 1, 1, 8, 8, 8, 8, 8, 8, 8, 1, 1, 1, 1, 8, 8, 8, 8, 1, 1, 1, 1, 1, 4]
WAPC_DARTS_cifar2 =[9, 7, 7, 9, 5, 1, 5, 9, 9, 5, 9, 5, 9, 9, 9, 9, 9, 9, 7, 9, 9, 9, 9, 7, 9, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
WAPC_DARTS_cifar3 =[9, 9, 7, 9, 5, 8, 9, 9, 6, 5, 9, 7, 7, 7, 7, 9, 7, 9, 9, 9, 9, 9, 9, 9, 9, 1, 9, 1, 9, 1, 9, 9, 9, 9, 9, 1, 1, 1, 1, 9, 1, 1, 1, 1, 7]
WAPC_DARTS_cifar4 =[9, 1, 8, 9, 9, 9, 9, 9, 7, 9, 9, 9, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]#search-EXP-20190924-112226 95.5
WAPC_DARTS_cifar6 =[2, 7, 9, 8, 8, 9, 7, 9, 9, 9, 9, 9, 8, 9, 9, 9, 9, 9, 9, 9, 8, 9, 8, 9, 9]#search-EXP-20190926-090819 95.47
WAPC_DARTS_cifar7 =[1, 7, 8, 5, 9, 9, 7, 9, 9, 9, 1, 8, 1, 9, 5, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]#search-EXP-20190925-232116
WAPC_DARTS_cifar5 =[9, 7, 7, 9, 5, 1, 5, 8, 8, 9, 7, 6, 7, 9, 9, 1, 9, 1, 1, 9, 9, 9, 9, 9, 9, 9, 8, 1, 9, 9, 9, 9, 9, 9, 9]
WAPC_DARTS_cifar8 =[9, 5, 7, 5, 5, 9, 9, 9, 9, 9, 9, 9, 8, 9, 9, 9, 9, 9, 7, 9, 9, 9, 9, 9, 9]#search-EXP-20190927-004035 95.7
WAPC_DARTS_cifar9 =[1, 4, 5, 6, 6, 6, 6, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 6, 6, 6, 6, 6, 6]#t=6 search-EXP-20190928-171533 95.5

WAPC_nodes=[[0, 1], [0, 2], [0, 2], [0, 4], [0, 1], [0, 1], [0, 2], [0, 4], [0, 1], [0, 2], [0, 1], [0, 1], [0, 1], [0, 1], [1, 2], [0, 1], [0, 1], [0, 1], [0, 3], [1, 3], [0, 1], [1, 2], [0, 1], [0, 4], [0, 1], [1, 2], [0, 1], [2, 4], [0, 1], [0, 1], [1, 2], [3, 4]] #96,48 channels 96.5
WAPC_DARTS_cifar_yuanshi=[4, 1, 5, 4, 5, 4, 5, 1, 4, 4, 4, 4, 4, 4, 3, 5, 3, 1, 5, 1, 5, 4, 1, 5, 4, 4, 6, 5, 4, 5, 5, 5, 1, 1, 1, 7, 7, 1, 5, 4, 7, 1, 5, 1, 4, 4, 4, 5, 5, 7, 5, 7, 4, 5, 5, 5, 1, 5, 5, 5, 4, 5, 4, 1, 1, 7, 7, 1, 1, 1, 5, 5, 5, 4, 5, 5, 5, 7, 7, 4, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 5, 5, 5, 5, 5, 5, 5, 7, 5, 7, 5, 5, 7, 7, 1]
WAPC_nodes2=[[0, 1], [0, 2], [0, 2], [1, 3], [0, 1], [0, 1], [1, 2], [1, 4], [0, 1], [0, 1], [0, 1], [0, 3], [0, 1], [0, 1], [0, 3], [0, 1], [0, 1], [0, 1], [0, 1], [0, 3], [0, 1], [0, 1], [0, 1], [0, 2], [0, 1], [0, 2], [0, 1], [1, 2], [0, 1], [0, 1], [0, 2], [0, 1], [0, 1], [0, 1], [0, 1], [3, 4]]# 96.54 channels36 96.92 600epochs
WAPC_DARTS_cifar_yuanshi2=[1, 3, 1, 1, 4, 4, 1, 4, 5, 5, 5, 4, 6, 5, 4, 5, 4, 4, 4, 4, 4, 4, 4, 1, 4, 4, 4, 4, 1, 5, 1, 4, 1, 4, 1, 2, 1, 1, 5, 1, 4, 1, 4, 1, 1, 4, 4, 1, 6, 7, 7, 5, 5, 7, 7, 1, 1, 5, 5, 4, 7, 5, 5, 1, 5, 5, 5, 7, 7, 1, 7, 7, 5, 7, 7, 7, 7, 7, 7, 7, 4, 7, 5, 1, 7, 5, 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 1, 7, 7, 7, 7, 7, 7, 7, 1, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1, 1, 7, 1, 1, 1, 5, 7, 1, 1, 1]
WAPC_nodes3=[[0, 1], [0, 1], [0, 2], [0, 3], [0, 1], [0, 1], [1, 2], [0, 1], [0, 1], [0, 1], [0, 1], [0, 4], [0, 1], [0, 2], [0, 1], [1, 4], [0, 1], [0, 1], [1, 2], [1, 3], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 2], [0, 4], [0, 1], [0, 1], [1, 3], [0, 1], [0, 1], [0, 1], [0, 1], [0, 4], [0, 1], [0, 1], [1, 2], [3, 4]]#96.59
WAPC_DARTS_cifar_yuanshi3=[3, 3, 1, 4, 1, 3, 5, 7, 4, 1, 1, 1, 1, 1, 1, 5, 3, 4, 5, 5, 5, 1, 4, 1, 4, 4, 1, 4, 4, 5, 5, 4, 5, 4, 5, 4, 4, 4, 5, 5, 1, 4, 1, 1, 5, 4, 1, 1, 5, 5, 6, 5, 5, 5, 1, 1, 5, 4, 1, 1, 4, 4, 4, 4, 5, 5, 4, 5, 4, 4, 4, 1, 4, 6, 7, 1, 1, 1, 1, 5, 5, 7, 1, 7, 5, 1, 4, 5, 7, 5, 5, 7, 7, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 1, 7, 7, 5, 7, 1, 5, 5, 1, 1, 5, 5, 7, 1, 1]
WAPC_nodes4=[[0, 1], [0, 1], [0, 2], [0, 1], [0, 1], [0, 1], [0, 3], [0, 2], [0, 1], [1, 2], [0, 3], [1, 2], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 2], [0, 3], [0, 1], [0, 1], [0, 1], [1, 4], [0, 1], [0, 1], [0, 3], [0, 4], [0, 1], [0, 1], [1, 2], [0, 3], [0, 1], [0, 1], [0, 1], [1, 4], [0, 1], [0, 1], [2, 3], [3, 4]]
WAPC_DARTS_cifar_yuanshi4=[4, 1, 5, 4, 1, 1, 6, 3, 5, 1, 1, 5, 5, 1, 3, 5, 4, 4, 1, 3, 4, 4, 4, 4, 4, 5, 4, 5, 1, 4, 3, 4, 4, 4, 2, 4, 1, 5, 4, 4, 5, 5, 5, 1, 5, 5, 4, 1, 5, 7, 4, 4, 1, 5, 1, 1, 5, 4, 4, 4, 4, 1, 6, 1, 1, 1, 1, 5, 1, 4, 1, 1, 4, 1, 5, 1, 1, 7, 7, 4, 5, 4, 5, 1, 5, 5, 5, 1, 7, 5, 5, 7, 7, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 7, 1, 7, 7, 7, 7, 7, 5, 7, 1, 7, 7, 7, 7, 7, 7, 1, 7, 1, 5, 5, 1, 5, 5, 7, 7, 5, 7, 1]
#epoch 60 batch 160
WAPC_nodes5=[[0, 1], [0, 1], [0, 2], [3, 4], [0, 1], [0, 1], [0, 1], [0, 2], [0, 1], [0, 2], [0, 3], [2, 4], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [1, 2], [1, 2], [0, 1], [0, 1], [0, 2], [0, 1], [0, 1], [0, 2], [0, 3], [0, 4], [0, 1], [0, 1], [1, 2], [0, 1], [0, 1], [1, 2], [0, 1], [1, 4], [0, 1], [1, 2], [2, 3], [2, 4]]
WAPC_DARTS_cifar_yuanshi5=[4, 1, 4, 4, 4, 3, 1, 3, 1, 5, 4, 4, 1, 1, 5, 5, 3, 1, 5, 3, 4, 5, 5, 5, 4, 4, 4, 7, 1, 1, 4, 4, 4, 4, 4, 4, 4, 2, 5, 5, 5, 4, 5, 1, 5, 5, 1, 1, 5, 7, 5, 3, 5, 7, 5, 5, 5, 4, 5, 5, 4, 7, 6, 6, 7, 4, 4, 4, 5, 1, 1, 1, 1, 1, 1, 4, 4, 1, 1, 1, 2, 1, 7, 5, 5, 5, 5, 5, 7, 2, 5, 7, 7, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 7, 5, 7, 7, 1, 7, 7, 7, 1, 7, 5, 1, 7, 7, 7, 7, 7, 1, 5, 5, 7, 1, 1, 5, 7, 5, 1, 5, 7, 1, 1, 1]
#epoch 60 batch 160
WAPC_nodes6=[[0, 1], [0, 2], [0, 2], [1, 4], [0, 1], [0, 1], [0, 2], [0, 1], [0, 1], [1, 2], [0, 3], [0, 2], [0, 1], [0, 1], [0, 1], [1, 4], [0, 1], [0, 1], [0, 2], [1, 3], [0, 1], [0, 1], [1, 2], [1, 2], [0, 1], [0, 1], [0, 2], [0, 1], [0, 1], [1, 2], [1, 3], [1, 3], [0, 1], [1, 2], [2, 3], [3, 4]]#96.769 600epochs
WAPC_DARTS_cifar_yuanshi6=[1, 2, 3, 4, 5, 7, 1, 4, 1, 4, 4, 7, 1, 5, 1, 5, 4, 4, 4, 4, 4, 5, 5, 5, 4, 4, 4, 5, 4, 1, 3, 1, 1, 1, 1, 5, 1, 7, 1, 1, 7, 1, 5, 2, 1, 4, 1, 5, 5, 1, 4, 4, 5, 1, 1, 1, 1, 4, 5, 5, 7, 4, 4, 1, 4, 4, 4, 5, 4, 6, 1, 7, 4, 1, 4, 5, 1, 1, 1, 4, 4, 5, 4, 1, 5, 5, 5, 5, 7, 5, 5, 7, 7, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 1, 7, 5, 5, 1, 5, 5, 1, 7, 5, 7, 7, 1, 1]
#cifar100 layers10 81.519
WAPC_nodes7=[[0, 1], [0, 2], [0, 2], [0, 4], [0, 1], [0, 1], [0, 2], [0, 4], [0, 1], [1, 2], [0, 2], [2, 4], [0, 1], [0, 2], [0, 1], [0, 1], [0, 1], [0, 1], [1, 2], [0, 1], [0, 1], [0, 1], [1, 2], [0, 4], [0, 1], [0, 1], [0, 3], [0, 4], [0, 1], [0, 1], [1, 2], [1, 4], [0, 1], [0, 1], [0, 2], [1, 4], [0, 1], [0, 1], [0, 1], [2, 3]]
WAPC_DARTS_cifar_yuanshi7=[1, 1, 5, 4, 5, 1, 4, 3, 4, 4, 3, 5, 5, 4, 3, 4, 3, 5, 1, 6, 5, 3, 3, 1, 5, 4, 1, 1, 4, 4, 4, 1, 4, 4, 5, 4, 4, 5, 4, 5, 1, 1, 5, 7, 4, 1, 7, 1, 1, 7, 7, 4, 5, 1, 1, 1, 5, 4, 1, 1, 4, 5, 4, 7, 7, 5, 1, 1, 7, 6, 2, 5, 1, 4, 4, 4, 1, 1, 4, 1, 4, 4, 1, 1, 5, 5, 5, 7, 7, 1, 5, 1, 1, 5, 5, 7, 7, 1, 1, 7, 7, 7, 7, 5, 7, 5, 7, 7, 7, 7, 7, 1, 1, 7, 7, 7, 7, 1, 7, 7, 1, 7, 7, 7, 7, 7, 5, 5, 7, 7, 7, 5, 7, 5, 7, 5, 5, 1, 1, 1]
#cifar100 layers10 wd 3e-4
WAPC_nodes8=[[0, 1], [0, 1], [0, 2], [0, 3], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 3], [1, 3], [0, 1], [0, 2], [1, 3], [0, 1], [0, 1], [0, 1], [0, 2], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 2], [2, 3], [0, 4], [0, 1], [0, 1], [1, 3], [0, 1], [0, 1], [0, 1], [0, 2], [1, 4], [0, 1], [0, 1], [0, 2], [2, 3]]
WAPC_DARTS_cifar_yuanshi8=[3, 3, 5, 4, 1, 3, 1, 3, 4, 3, 3, 1, 5, 6, 3, 4, 4, 4, 5, 5, 4, 5, 4, 5, 5, 4, 5, 5, 4, 3, 4, 4, 1, 1, 4, 2, 1, 4, 4, 4, 5, 5, 1, 1, 2, 5, 1, 5, 5, 7, 4, 5, 5, 1, 7, 1, 7, 4, 4, 1, 1, 1, 4, 1, 1, 4, 5, 5, 5, 7, 1, 7, 4, 5, 4, 4, 1, 1, 1, 4, 4, 4, 4, 1, 5, 1, 5, 5, 1, 5, 5, 1, 7, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 5, 7, 7, 5, 7, 7, 1, 7, 7, 4, 7, 7, 7, 7, 7, 5, 7, 7, 7, 1, 5, 5, 7, 5, 5, 5, 5, 5, 5, 5, 5, 1, 1, 5]


PCDARTS = WAPC_DARTS_cifar_yuanshi8

75 changes: 75 additions & 0 deletions logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
import tensorflow as tf
import numpy as np
import scipy.misc
from datetime import datetime

try:
from StringIO import StringIO # Python 2.7
except ImportError:
from io import BytesIO # Python 3.x


class Logger(object):

def __init__(self, log_dir):
"""Create a summary writer logging to log_dir."""
TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now())
log_dir=log_dir+TIMESTAMP
self.writer = tf.summary.FileWriter(log_dir)

def scalar_summary(self, tag, value, step):
"""Log a scalar variable."""
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
self.writer.add_summary(summary, step)

def image_summary(self, tag, images, step):
"""Log a list of images."""

img_summaries = []
for i, img in enumerate(images):
# Write the image to a string
try:
s = StringIO()
except:
s = BytesIO()
scipy.misc.toimage(img).save(s, format="png")

# Create an Image object
img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
height=img.shape[0],
width=img.shape[1])
# Create a Summary value
img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

# Create and write Summary
summary = tf.Summary(value=img_summaries)
self.writer.add_summary(summary, step)

def histo_summary(self, tag, values, step, bins=1000):
"""Log a histogram of the tensor of values."""

# Create a histogram using numpy
counts, bin_edges = np.histogram(values, bins=bins)

# Fill the fields of the histogram proto
hist = tf.HistogramProto()
hist.min = float(np.min(values))
hist.max = float(np.max(values))
hist.num = int(np.prod(values.shape))
hist.sum = float(np.sum(values))
hist.sum_squares = float(np.sum(values ** 2))

# Drop the start of the first bin
bin_edges = bin_edges[1:]

# Add bin edges and counts
for edge in bin_edges:
hist.bucket_limit.append(edge)
for c in counts:
hist.bucket.append(c)

# Create and write Summary
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
self.writer.add_summary(summary, step)
self.writer.flush()
Loading

0 comments on commit 974609b

Please sign in to comment.