-
Notifications
You must be signed in to change notification settings - Fork 32
/
utils.py
33 lines (25 loc) · 1.02 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import functools
from functional import compose, partial
import tensorflow as tf
def composeAll(*args):
    """Compose a sequence of functions right-to-left into one callable.

    i.e. composed = composeAll([f, g, h])
         composed(x) == f(g(h(x)))

    The innermost (last) function receives every positional and keyword
    argument passed to ``composed``. An empty sequence composes to the
    identity function, and a single-element sequence returns that function
    itself. As with ``functools.reduce``, an optional second argument is
    used as the initial (outermost) function.
    """
    # adapted from https://docs.python.org/3.1/howto/functional.html

    def _compose2(outer, inner):
        # outer(inner(...)), forwarding all call arguments to `inner`.
        return lambda *a, **kw: outer(inner(*a, **kw))

    if len(args) == 1:
        functions = list(args[0])
        if not functions:
            # Identity is the unit of composition; without this, reduce()
            # raises an unhelpful TypeError on an empty sequence.
            return lambda x: x
        return functools.reduce(_compose2, functions)
    # Other arities are forwarded to reduce() exactly as before
    # (sequence plus initial value, or a TypeError for invalid calls).
    return functools.reduce(_compose2, *args)
def print_(var, name: str, first_n: int = 5, summarize: int = 5):
    """Debugging aid: wrap `var` so its value is printed when evaluated.

    Builds a ``tf.Print`` op whose output is `var` itself. Per the
    ``tf.Print`` API, evaluating it logs `var` prefixed by ``"<name>: "``,
    at most `first_n` times and showing at most `summarize` entries.
    Use the returned tensor in place of `var` so the print op is executed
    (e.g. during training).

    Signature: (tf.Tensor, str, int, int) -> tf.Tensor
    """
    message = "{}: ".format(name)
    return tf.Print(var, [var], message=message,
                    first_n=first_n, summarize=summarize)
def _class_id(label):
    """Return the class id encoded by one MNIST label.

    Scalar labels are returned unchanged. Labels with one or more dimensions
    (one-hot vectors, e.g. from ``read_data_sets(..., one_hot=True)``) are
    reduced to the index of their largest entry.
    """
    if getattr(label, "ndim", 0) >= 1:
        return int(label.argmax())
    return label


def get_mnist(n, mnist, batch_size=500):
    """Return the image of a randomly chosen MNIST example of digit `n`.

    Draws one batch via ``mnist.train.next_batch(batch_size)`` and returns
    the image of a randomly chosen example in that batch whose label is `n`.

    Args:
        n: digit to look for; must satisfy ``0 <= n <= 9``.
        mnist: dataset object whose ``train.next_batch(size)`` returns
            ``(images, labels)``. Labels may be scalar class ids or one-hot
            vectors.
        batch_size: number of examples to draw (default 500).

    Returns:
        The matching row of the image batch (784-D for flattened MNIST).

    Raises:
        ValueError: if `n` is outside 0-9.
        LookupError: if the drawn batch contains no example of digit `n`.
    """
    import random

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not 0 <= n <= 9:
        raise ValueError("Must specify digit 0 - 9!")

    imgs, labels = mnist.train.next_batch(batch_size)
    # Visit the batch in random order (non-in-place shuffle of the indices),
    # so the result is a random match rather than always the first one.
    # Size by the labels actually returned, not the requested count.
    count = len(labels)
    for i in random.sample(range(count), count):
        if _class_id(labels[i]) == n:
            return imgs[i]
    raise LookupError(
        "No example of digit {} in a batch of {} MNIST images".format(n, count))