train_attention.py
#!/usr/bin/python
import tensorflow as tf
import layer   # local model-building wrapper module (this repo)
import letter  # local letter-image dataset helpers (this repo)
from letter import Target, batch
target = Target.position
learning_rate = 0.0001  # 0.0003 also tried
nClasses = letter.nClasses[target]
size = letter.max_size
training_iters = 500000
batch_size = 64

data = batch(batch_size, target)
print("data.shape %s" % data.shape)
# best with lr ~0.001
def baseline(net):
    # type: (layer.net) -> None
    # net.batchnorm()  # start lower, else no effect
    net.dense(400, activation=tf.nn.tanh)  # single hidden layer, 400 tanh units
    net.regression(nClasses)               # linear regression output head
    # net.denseNet(40, depth=4)            # alternative: small dense-net stack
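
# For reference, a rough sketch of the graph baseline() presumably assembles
# once layer.net wires it up. Assumptions: the wrapper emits plain TF 1.x ops
# with a mean-squared-error loss and an Adam optimizer; baseline_raw_tf is a
# hypothetical illustration and is never called in this script.
def baseline_raw_tf(x, y):
    # x: [batch, size*size] flattened input, y: [batch, nClasses] regression target
    hidden = tf.layers.dense(x, 400, activation=tf.nn.tanh)  # 400-unit tanh layer
    pred = tf.layers.dense(hidden, nClasses)                 # linear regression head
    loss = tf.reduce_mean(tf.square(pred - y))               # MSE loss
    train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)
    return pred, loss, train_op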
# Build the network from the baseline graph definition. Earlier experiments
# (alex, denseConv) are kept commented out for reference:
# net = layer.net(alex, data, learning_rate=0.001)  # NOPE!?
# net = layer.net(denseConv, input_shape=[size, size], output_width=nClasses, learning_rate=learning_rate)
net = layer.net(baseline, input_shape=[size, size], output_width=nClasses, learning_rate=learning_rate)

# net.train(steps=50000, dropout=0.6, display_step=1, test_step=1)   # debug
# net.train(steps=50000, dropout=0.6, display_step=5, test_step=20)  # test
net.train(data=data, steps=training_iters, dropout=.6, display_step=10, test_step=100)  # run
# net.predict() # nil=random
# net.generate(3) # nil=random