-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
89 lines (73 loc) · 1.77 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import tensorflow as tf
import numpy as np
from dataset import get_dataset, prepare_dataset
from model import get_model
# Training script: loads the "fr-en" corpus with get_dataset, vectorises the
# first TRAIN_LINES lines with prepare_dataset, builds a Transformer with
# get_model and fits it. TensorBoard logs go to logs/<name>; the checkpoint
# with the best val_accuracy is kept, and training stops early when
# val_accuracy stops improving.

# Hyperparameters and run settings, collected in one place so they can be
# tuned without editing the calls below. Values are unchanged from the
# previously inlined literals.
DATASET_NAME = "fr-en"
TRAIN_LINES = 100000
MAX_WINDOW_SIZE = 20
EMBEDDING_SIZE = 64
ENCODER_LAYERS = 2
DECODER_LAYERS = 2
NUMBER_HEADS = 4
DENSE_LAYER_SIZE = 128
EPOCHS = 15
BATCH_SIZE = 32
VALIDATION_SPLIT = 0.1
EARLY_STOPPING_PATIENCE = 2
EARLY_STOPPING_MIN_DELTA = 0.001

dataset = get_dataset(DATASET_NAME)
print("Dataset loaded. Length:", len(dataset), "lines")

train_dataset = dataset[0:TRAIN_LINES]
print("Train data loaded. Length:", len(train_dataset), "lines")

# NOTE(review): Keras takes the validation_split samples from the *end* of
# x/y, before any shuffling. With shuffle=False here, the validation set is
# the last VALIDATION_SPLIT fraction of these lines. If the corpus is ordered
# (e.g. by sentence length), that split is not representative. Confirm what
# prepare_dataset's `shuffle` does before trusting val_accuracy.
(encoder_input,
 decoder_input,
 decoder_output,
 encoder_vocab,
 decoder_vocab,
 encoder_inverted_vocab,
 decoder_inverted_vocab) = prepare_dataset(
    train_dataset,
    shuffle=False,
    lowercase=True,
    max_window_size=MAX_WINDOW_SIZE,
)

transformer_model = get_model(
    EMBEDDING_SIZE=EMBEDDING_SIZE,
    ENCODER_VOCAB_SIZE=len(encoder_vocab),
    DECODER_VOCAB_SIZE=len(decoder_vocab),
    ENCODER_LAYERS=ENCODER_LAYERS,
    DECODER_LAYERS=DECODER_LAYERS,
    NUMBER_HEADS=NUMBER_HEADS,
    DENSE_LAYER_SIZE=DENSE_LAYER_SIZE,
)

transformer_model.compile(
    optimizer="adam",
    loss=["sparse_categorical_crossentropy"],
    metrics=["accuracy"],
)
transformer_model.summary()

# The model is fed two inputs (encoder sequences, decoder sequences) and
# trained against decoder_output.
x = [np.array(encoder_input), np.array(decoder_input)]
y = np.array(decoder_output)

name = "transformer"

# Checkpoints are kept by val_accuracy (see `monitor` below), so the filename
# records val_accuracy next to the training metrics. Otherwise the name would
# show only training loss/accuracy, which is not what selected the checkpoint.
# NOTE(review): a `.ckpt` path with save_weights_only=True is the tf.keras 2.x
# TF-format weights path. Keras 3 (the default from TF 2.16) requires such
# paths to end in ".weights.h5". Confirm the installed Keras version.
checkpoint_filepath = (
    "./logs/transformer_ep-{epoch:02d}"
    "_loss-{loss:.2f}_acc-{accuracy:.2f}"
    "_val_acc-{val_accuracy:.2f}.ckpt"
)

tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=f"logs/{name}",
)

# Save weights only when val_accuracy reaches a new maximum.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor="val_accuracy",
    mode="max",
    save_weights_only=True,
    save_best_only=True,
    verbose=1,
)

# Stop when val_accuracy has not improved by at least min_delta for
# `patience` consecutive epochs.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor="val_accuracy",
    mode="max",
    patience=EARLY_STOPPING_PATIENCE,
    min_delta=EARLY_STOPPING_MIN_DELTA,
    verbose=1,
)

transformer_model.fit(
    x,
    y,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
    callbacks=[
        model_checkpoint_callback,
        tensorboard_callback,
        early_stopping_callback,
    ],
)