-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcategorical_grid_plots.py
172 lines (151 loc) · 6.21 KB
/
categorical_grid_plots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import numpy as np
import tensorflow as tf
from PIL import Image
def create_image_strip(images, zoom=1.0, gutter=5):
    """Tile a batch of images horizontally into a single RGB strip.

    Images are pasted left to right onto a white canvas, separated by
    ``gutter`` pixels of background. The finished strip is optionally
    rescaled by ``zoom`` with nearest-neighbour resampling.

    :param images: array of shape (num_images, height, width, channels).
        Pixel values are multiplied by 255, so they are presumably floats in
        [0, 1] (confirm against the generator). Values outside that range are
        clipped to [0, 255] rather than wrapping around when cast to uint8.
    :param zoom: scale factor applied to the finished strip (1 = no resize).
    :param gutter: width in pixels of the white gap between neighbouring
        images.
    :return: uint8 numpy array holding the RGB strip, as produced by
        ``np.array`` on the PIL image.
    :raises ValueError: if ``images`` contains no images.
    """
    num_images, image_height, image_width, channels = images.shape
    if num_images == 0:
        # Without this check the canvas width below would be -gutter and
        # PIL would fail with a far less obvious error.
        raise ValueError("create_image_strip needs at least one image")
    if channels == 1:
        # Drop the trailing channel axis so each image is a 2-D (H, W) array.
        images = images.reshape(num_images, image_height, image_width)
    # Scale once for the whole batch. Clipping makes out-of-range values
    # saturate instead of wrapping; in-range values are unchanged.
    pixels = np.clip(images * 255, 0, 255).astype(np.uint8)
    # Gutters sit only between images, not after the last one.
    collage_width = num_images * (image_width + gutter) - gutter
    # use white as background
    collage = Image.new('RGB', (collage_width, image_height), (255, 255, 255))
    offset = 0
    for image in pixels:
        collage.paste(
            Image.fromarray(image),
            box=(offset, 0, offset + image_width, image_height)
        )
        offset += image_width + gutter
    if zoom != 1:
        collage = collage.resize(
            (int(collage.size[0] * zoom), int(collage.size[1] * zoom)),
            Image.NEAREST
        )
    return np.array(collage)
def create_continuous_noise(num_continuous, style_size, size):
    """Sample ``size`` rows of continuous latent noise.

    Each row holds ``num_continuous`` values drawn uniformly from [-1, 1),
    followed by ``style_size`` standard-normal style values. The style block
    is drawn from the global numpy RNG first, then the uniform block.

    :param num_continuous: number of uniform continuous variables per row.
        When it is not positive, only the style block is returned.
    :param style_size: number of standard-normal style values per row.
    :param size: number of rows to sample.
    :return: float array of shape (size, num_continuous + style_size), or
        (size, style_size) when ``num_continuous`` is not positive.
    """
    style_block = np.random.standard_normal(size=(size, style_size))
    if num_continuous <= 0:
        return style_block
    uniform_block = np.random.uniform(-1.0, 1.0, size=(size, num_continuous))
    return np.concatenate((uniform_block, style_block), axis=1)
class CategoricalPlotter(object):
    """Write TF image summaries that visualise a generator's latent codes.

    Every latent vector passed to ``generate`` is a one-hot categorical code
    of width ``categorical_cardinality``, followed by the noise produced by
    ``create_continuous_noise``: ``num_continuous`` uniform values, then
    ``style_size`` standard-normal values.
    """

    def __init__(self,
                 journalist,
                 categorical_cardinality,
                 num_continuous,
                 style_size,
                 generate,
                 row_size=10,
                 zoom=2.0,
                 gutter=3):
        """
        :param journalist: summary sink that must provide ``add_summary`` and
            ``flush``. Presumably a TF summary ``FileWriter``; confirm.
        :param categorical_cardinality: number of categories in the one-hot
            code.
        :param num_continuous: number of uniform continuous latent variables.
        :param style_size: number of standard-normal style variables.
        :param generate: callable ``generate(session, z_c_vectors)`` that
            returns a batch of images shaped (N, H, W, C).
        :param row_size: default number of images per strip, used by
            ``generate_images`` when it is called without ``row_size``.
        :param zoom: scale factor applied to each image strip.
        :param gutter: pixel gap between images within a strip.
        """
        self._journalist = journalist
        self._gutter = gutter
        self.categorical_cardinality = categorical_cardinality
        self.style_size = style_size
        self.num_continuous = num_continuous
        self._generate = generate
        self._row_size = row_size
        self._zoom = zoom
        # Caches, so repeated calls reuse the same placeholders and summary ops.
        self._placeholders = {}
        self._image_summaries = {}

    def _one_hot(self, row_size, category):
        """Return a (row_size, categorical_cardinality) float32 matrix whose
        ``category`` column is 1 and all other entries are 0."""
        one_hot = np.zeros((row_size, self.categorical_cardinality), dtype=np.float32)
        one_hot[:, category] = 1.0
        return one_hot

    def _render_row(self, session, z_c_vectors):
        """Generate one image per latent vector and tile them into one strip."""
        return create_image_strip(
            self._generate(session, z_c_vectors),
            zoom=self._zoom,
            gutter=self._gutter
        )

    def generate_categorical_variations(self, session, row_size, iteration=None):
        """Keep the continuous noise fixed and vary only the categorical code.

        One strip is produced per category. Every strip reuses the same
        ``row_size`` continuous/style noise rows, so the strips differ only
        in the category that is switched on.

        :param session: session passed to ``generate`` and used to run the
            summary op.
        :param row_size: number of images per strip.
        :param iteration: optional global step recorded with the summary.
        """
        continuous_noise = create_continuous_noise(
            num_continuous=self.num_continuous,
            style_size=self.style_size,
            size=row_size
        )
        images = []
        for i in range(self.categorical_cardinality):
            z_c_vectors = np.hstack([self._one_hot(row_size, i), continuous_noise])
            images.append((self._render_row(session, z_c_vectors), "category_%d" % (i,)))
        self._add_image_summary(session, images, iteration=iteration)

    def _get_placeholder(self, name):
        """Return the cached uint8 (H, W, 3) image placeholder for ``name``,
        creating it on first use."""
        if name not in self._placeholders:
            self._placeholders[name] = tf.placeholder(tf.uint8, [None, None, 3])
        return self._placeholders[name]

    def _get_image_summary_op(self, names):
        """Return a merged image-summary op covering the placeholders in ``names``.

        The op is cached per exact sequence of names. The cache key is a
        tuple rather than the concatenated string, so name lists such as
        ``["ab", "c"]`` and ``["a", "bc"]`` cannot share a cached op.
        """
        key = tuple(names)
        if key not in self._image_summaries:
            summaries = []
            for name in names:
                image_placeholder = self._get_placeholder(name)
                # Add a leading batch axis: [H, W, 3] -> [1, H, W, 3].
                decoded_image = tf.expand_dims(image_placeholder, 0)
                summaries.append(tf.summary.image(name, decoded_image, max_outputs=1))
            self._image_summaries[key] = tf.summary.merge(summaries)
        return self._image_summaries[key]

    def _add_image_summary(self, session, images, iteration=None):
        """Feed (image, name) pairs into their placeholders and write one
        merged summary to the journalist.

        :param session: session used to evaluate the merged summary op.
        :param images: list of ``(image, name)`` pairs. Each image should
            match the uint8 [None, None, 3] placeholder, for example the
            output of ``create_image_strip``.
        :param iteration: optional global step passed to ``add_summary``.
        """
        if not images:
            # Nothing to write (e.g. num_continuous == 0). Presumably
            # tf.summary.merge would also reject an empty list; confirm.
            return
        feed_dict = {self._get_placeholder(name): image for image, name in images}
        summary_op = self._get_image_summary_op([name for _, name in images])
        summary = session.run(summary_op, feed_dict=feed_dict)
        if iteration is None:
            self._journalist.add_summary(summary)
        else:
            self._journalist.add_summary(summary, iteration)
        self._journalist.flush()

    def generate_continuous_variations(self, session, row_size, variations=3, iteration=None):
        """Sweep each continuous variable while the categorical code stays fixed.

        First draw ``variations`` random base latent vectors, each a random
        category plus random continuous/style noise. Then, for every
        continuous variable and every base vector, produce a strip in which
        that one variable moves linearly from -1 to 1 across the row while
        all other inputs stay constant.

        :param session: session passed to ``generate`` and used to run the
            summary op.
        :param row_size: number of images (sweep steps) per strip.
        :param variations: number of random base vectors per continuous
            variable.
        :param iteration: optional global step recorded with the summary.
        """
        categorical_noise = np.random.randint(0, self.categorical_cardinality, size=variations)
        continuous_fixed = create_continuous_noise(
            num_continuous=self.num_continuous,
            style_size=self.style_size,
            size=variations
        )
        linear_variation = np.linspace(-1.0, 1.0, row_size)
        images = []
        for contig_idx in range(self.num_continuous):
            for var_idx in range(variations):
                # repeat() returns a copy, so editing the sweep column below
                # leaves continuous_fixed untouched for later iterations.
                continuous_modified = continuous_fixed[var_idx:var_idx + 1, :].repeat(
                    row_size, axis=0
                )
                # The continuous variables occupy the first num_continuous
                # columns, because create_continuous_noise puts them before
                # the style noise.
                continuous_modified[:, contig_idx] = linear_variation
                z_c_vectors = np.hstack(
                    [self._one_hot(row_size, categorical_noise[var_idx]), continuous_modified]
                )
                images.append(
                    (self._render_row(session, z_c_vectors),
                     "continuous_variable_%d, variation_%d" % (contig_idx, var_idx)))
        self._add_image_summary(session, images, iteration=iteration)

    def generate_images(self, session, row_size=None, iteration=None, variations=3):
        """Write the categorical variations and, if the model has continuous
        variables, the continuous sweeps.

        :param session: session passed to ``generate`` and used to run the
            summary ops.
        :param row_size: images per strip. Defaults to the ``row_size`` given
            to the constructor.
        :param iteration: optional global step recorded with the summaries.
        :param variations: number of random base vectors per continuous
            variable.
        """
        if row_size is None:
            row_size = self._row_size
        self.generate_categorical_variations(
            session, row_size, iteration=iteration
        )
        if self.num_continuous > 0:
            self.generate_continuous_variations(
                session, row_size, variations=variations, iteration=iteration
            )