Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Pannous committed Mar 2, 2018
1 parent e049638 commit b95b30b
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 11 deletions.
67 changes: 67 additions & 0 deletions deep_prior.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os
import scipy.misc
import numpy as np
import tensorflow as tf
from tensorflow.contrib.layers import conv2d, conv2d_transpose


def pool(X):
return tf.nn.max_pool(X, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


def uppool(X):
height, width = X.get_shape().as_list()[1:3]
return tf.image.resize_images(X, (height * 2, width * 2))


# download from https://i.imgur.com/ytjR2QF.png
image = scipy.misc.imread("snail256.png").astype(np.float32) / 255.0


def make_unet(X):
depths = [16, 32, 64, 128, 256, 512]
# TODO figure out how to make batchnorm work

activation_fn = tf.nn.tanh
# convolve and half image size a few times
for depth in depths:
# X = convolution(X, kernel_size=depth, stride=3, activation_fn=activation_fn)
X = conv2d(X, depth, 3, activation_fn=activation_fn)
X = pool(X)

X = conv2d(X, depth, 3, activation_fn=activation_fn)

# upconcolve and double image size a few times
for depth in reversed(depths):
X = uppool(X)
X = conv2d_transpose(X, depth, 3, activation_fn=activation_fn)

X = conv2d(X, 3, 3, activation_fn=None)

return X


input = tf.constant(image.reshape((1, 256, 256, 3)))

output = make_unet(input)

loss = tf.reduce_mean(tf.square(input - output))

# TODO L-BFGS-B should be faster here, but could not get it to work
optimizer = tf.train.AdamOptimizer()
train_op = optimizer.minimize(loss)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

if not os.path.exists("frames"):
os.mkdir("frames")

def save(frame):
scipy.misc.imsave("frames/%d.png" % frame, sess.run(output).reshape((256, 256, 3)).clip(0, 1))

for i in range(10000):
print("\r#" + str(i), end='', flush=True)
sess.run(train_op)
if not i % 100:
save(i/100)
10 changes: 5 additions & 5 deletions extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def wrap(self):
return str(self) # TODO!?


# WOW YAY WORKS!!!!!
# WOW YAY WORKS!!!
# ONLY VIA EXPLICIT CONSTRUCTOR!
# NOOOO!! BAAAD! isinstance(my_xlist,list) FALSE !!

Expand Down Expand Up @@ -728,10 +728,10 @@ def fix_int(self, i):
def character(self, nr):
return self.item(nr)

def item(self, nr): # -1 AppleScript style !!! BUT list[0] !!!
def item(self, nr): # -1 AppleScript style ! BUT list[0] !
return self[xlist(self).fix_int(nr)]

def word(self, nr): # -1 AppleScript style !!! BUT list[0] !!!):
def word(self, nr): # -1 AppleScript style ! BUT list[0] !):
return self[xlist(self).fix_int(nr)]

def invert(self): # ! Self modifying !
Expand Down Expand Up @@ -920,7 +920,7 @@ def __sub__(self, x):
def synsets(self, param):
pass

def is_noun(self): # expensive!!!):
def is_noun(self): # expensive!):
# Sequel::InvalidOperation Invalid argument used for IS operator
return self.synsets('noun') or self.gsub(r's$', "").synsets('noun') # except False

Expand Down Expand Up @@ -1480,7 +1480,7 @@ def find_class(match=""): # all
# def __getattr__(self, attr):
# import sys
# import math
# # ruby method_missing !!!
# # ruby method_missing !
# import inspect
# for name, obj in inspect.getmembers(sys.modules['math']):
# if name==attr: return obj
Expand Down
2 changes: 1 addition & 1 deletion letter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# overfit = False
overfit = True
if overfit:
print("using OVERFIT DEBUG DATA!!!")
print("using OVERFIT DEBUG DATA!")
min_size = 24
max_size = 24
max_padding = 8
Expand Down
8 changes: 4 additions & 4 deletions net.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def train(self, data=0, steps=-1, dropout=None, display_step=10, test_step=100,
acc, summary = session.run([self.accuracy, self.summaries], feed_dict=feed)
# self.summary_writer.add_summary(summary, step) # only test summaries for smoother curve
print("\rStep {:d} Loss= {:.6f} Accuracy= {:.3f} Time= {:d}s".format(step, loss, acc, seconds), end=' ')
if str(loss) == "nan": return print("\nLoss gradiant explosion, exiting!!!") # restore!
if str(loss) == "nan": return print("\nLoss gradiant explosion, exiting!") # restore!
if step % test_step == 0: self.test(step)
if step % save_step == 0 and step > 0:
print("SAVING snapshot %s" % snapshot)
Expand Down Expand Up @@ -479,20 +479,20 @@ def resume(self, session):
checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if checkpoint:
if self.name and not self.name in checkpoint:
print("IGNORING checkpoint of other run : " + checkpoint + " !!!")
print("IGNORING checkpoint of other run : " + checkpoint + " !")
checkpoint = None
else:
print("NO checkpoint, nothing to resume")
if checkpoint:
print("LOADING " + checkpoint + " !!!")
print("LOADING " + checkpoint + " !")
try:
persister = tf.train.Saver(tf.global_variables())
persister.restore(session, checkpoint)
print("resume checkpoint successful!")
return True
except Exception as ex:
print(ex)
print("CANNOT LOAD checkpoint %s !!!" % checkpoint)
print("CANNOT LOAD checkpoint %s !" % checkpoint)
return False

def restore(self): # name
Expand Down
2 changes: 1 addition & 1 deletion word_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# https://github.com/fchollet/keras/pull/4115/files

# AtrousConv1D = AtrousConvolution1D
# YAY!!!!!!!!!!
# YAY!!!!

# https://github.com/nrTQgc/deep-anpr/commit/adfbbe7f3deeaa39bdecc36d8398434b76fdf211
# fork confusion : ^^ 13 days ago BUUUUT: # https://github.com/nrTQgc/deep-anpr two months ago!?!?!?!

0 comments on commit b95b30b

Please sign in to comment.