forked from pearsonkyle/Artificial-Art
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbubble_train.py
55 lines (41 loc) · 1.63 KB
/
bubble_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import argparse
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from dcgan import DCGAN, create_dataset
def parse_args(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``, so existing callers that
            pass nothing behave exactly as before.

    Returns:
        argparse.Namespace with an ``epochs`` attribute (int, default 101).

    Note:
        On invalid arguments argparse prints usage and raises SystemExit;
        it never returns ``None``.
    """
    parser = argparse.ArgumentParser()
    help_ = "Number of training epochs"
    parser.add_argument("-e", "--epochs", help=help_, default=101, type=int)
    return parser.parse_args(argv)
def filter_low_contrast(images, percentile=25):
    """Drop the lowest-contrast images from a batch.

    Contrast is the standard deviation of each image after averaging over
    its last (channel) axis. Images whose contrast is not strictly above
    the given percentile of all contrasts are removed.

    Args:
        images: NumPy array indexed as (n, rows, cols, channels), assumed
            channels-last. Must contain at least one image, because
            ``np.percentile`` of an empty array raises.
        percentile: cutoff percentile in [0, 100].

    Returns:
        The subset of ``images`` whose contrast exceeds the cutoff. This may
        be empty, for example when every image has identical contrast.
    """
    # Vectorized form of np.std(images[i].mean(2)) for every i.
    stds = images.mean(axis=3).std(axis=(1, 2))
    return images[stds > np.percentile(stds, percentile)]


def main():
    """Train the 'bubble' DCGAN on images loaded from 'SoapBubble/output/'.

    Steps:
        1. Load tiles via ``create_dataset``.
        2. Scale pixel values by 1/255.
        3. Drop images at or below the 25th-percentile contrast.
        4. Show a 10x10 preview grid.
        5. Resume from saved weights if they load.
        6. Call ``DCGAN.train``.

    Raises:
        SystemExit: if no images are loaded, or if the contrast filter
            removes all of them.
    """
    args = parse_args()

    # NOTE(review): argument semantics belong to dcgan.create_dataset;
    # the original comment states the result has 3 (RGB) channels.
    x_train, _ = create_dataset(64, 64, nSlices=20, resize=0.75, directory='SoapBubble/output/')
    if x_train.shape[0] == 0:
        raise SystemExit("no training images found in 'SoapBubble/output/'")

    # Out-of-place division so an integer-typed array is promoted to float
    # instead of raising a casting error, as `x_train /= 255` would.
    x_train = x_train / 255

    # Presumably removes near-uniform/background tiles; keeps images
    # strictly above the 25th-percentile contrast.
    x_train = filter_low_contrast(x_train, percentile=25)
    if x_train.shape[0] == 0:
        raise SystemExit("contrast filter removed every training image")

    # Visual sanity check of a random sample. plt.show() blocks until the
    # window is closed, so training starts only afterwards.
    fig, axs = plt.subplots(10, 10)
    for ax in axs.flat:
        ax.imshow(x_train[np.random.randint(x_train.shape[0])])
        ax.axis('off')
    plt.show()

    img_rows, img_cols, channels = x_train.shape[1:4]
    dcgan = DCGAN(img_rows=img_rows,
                  img_cols=img_cols,
                  channels=channels,
                  latent_dim=256,
                  name='bubble')

    # Best-effort resume: a missing or incompatible checkpoint is reported
    # and training starts from fresh weights instead of aborting.
    try:
        dcgan.load_weights(
            generator_file="generator ({}).h5".format(dcgan.name),
            discriminator_file="discriminator ({}).h5".format(dcgan.name)
        )
    except Exception as e:
        print("failed to load weights:", e)

    # NOTE(review): save_interval=1000 vs a default of 101 epochs. If
    # DCGAN.train counts the interval in epochs, checkpoints may never be
    # written; confirm against dcgan.DCGAN.train.
    dcgan.train(x_train, epochs=args.epochs, batch_size=32, save_interval=1000)


if __name__ == '__main__':
    main()