-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
116 lines (85 loc) · 2.84 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from decoder import Decoder
from encoder import Encoder
from Embedding import getPreTrainedEmbeddingRunner
from evaluate_captions import evaluate_captions
from runner import Runner
from settings import LSTM_HIDDEN_SIZE, EMBEDDED_SIZE, TASK_NAME
from utils import load_datasets, get_device
from glove import loadGlove
def run_network():
    """
    Build an encoder/decoder pair and train it through a Runner.

    The three dataset splits come from ``load_datasets()``; the decoder's
    vocabulary size is the number of entries in the training split's
    ``dataset.vocab.wordToIndex`` mapping. Both models are moved to the
    device returned by ``get_device()`` before being handed to the Runner.
    """
    train_split, val_split, test_split = load_datasets()
    vocab_size = len(train_split.dataset.vocab.wordToIndex)
    device = get_device()

    encoder_net = Encoder(EMBEDDED_SIZE).to(device)
    decoder_net = Decoder(EMBEDDED_SIZE, LSTM_HIDDEN_SIZE, vocab_size).to(device)

    Runner(encoder_net, decoder_net, train_split, val_split, test_split).train()
def show_options():
    """
    Print the task menu to standard output.

    The keys shown here match the ones the interactive loop at the bottom
    of this file actually accepts: task 1 is split into ``1i`` (LSTM) and
    ``1ii`` (Vanilla RNN), so the menu lists both instead of a bare ``1``
    that would be silently ignored.
    """
    menu_lines = (
        "(1i): training and validation loss for LSTM",
        "(1ii): training and validation loss for Vanilla RNN",
        "(2): Cross Entropy and Perplexity score on test set",
        "(3): BLEU-1 and BLEU-4 scores on deterministic LSTM and Vanilla RNN",
        "(4): Experiment with Temperatures",
        "(5): Pre-trained word embeddings",
        "(q): quit program",
    )
    for line in menu_lines:
        print(line)
def task_4_1_lstm():
    """
    Task 4-1 (LSTM): label the run, select the LSTM variant, then train.

    Sets this module's ``TASK_NAME`` and ``LSTM`` globals before calling
    ``run_network()``.
    """
    # NOTE(review): TASK_NAME was brought in with `from settings import ...`,
    # so rebinding it here only changes this module's global. Whether the
    # Runner/Decoder ever observe this value (or LSTM) cannot be seen from
    # this file - confirm against their implementations.
    global TASK_NAME
    global LSTM
    TASK_NAME = "Task 4-1 lstm"
    LSTM = True
    run_network()
def task_4_1_rnn():
    """
    Task 4-1 (Vanilla RNN): label the run, deselect LSTM, then train.

    Sets this module's ``TASK_NAME`` and ``LSTM`` globals before calling
    ``run_network()``.
    """
    # NOTE(review): TASK_NAME was brought in with `from settings import ...`,
    # so rebinding it here only changes this module's global. Whether the
    # Runner/Decoder ever observe this value (or LSTM) cannot be seen from
    # this file - confirm against their implementations.
    global TASK_NAME
    global LSTM
    TASK_NAME = "Task 4-1 rnn"
    LSTM = False
    run_network()
def task_4_2():
    """
    Task 4-2: Cross Entropy and Perplexity score on the test set.

    Currently a stub: it performs no work and returns ``None``.
    """
    return None
def task_4_3():
    """
    Task 4-3: print BLEU-1 and BLEU-4 for the deterministic LSTM and
    Vanilla RNN captions.

    For each model a banner is printed, ``evaluate_captions`` is called with
    the reference-caption path and that model's generated-caption path, and
    the two returned scores are printed. Every path is currently ``'./'`` -
    presumably a placeholder to be replaced with real caption files.
    """
    true_captions_path = './'

    # (banner, path to captions produced by deterministic generation)
    scoring_runs = (
        ("Scoring Deterministic LSTM", './'),
        ("Scoring Deterministic Vanilla RNN", './'),
    )

    for banner, generated_captions_path in scoring_runs:
        print(banner)
        bleu1, bleu4 = evaluate_captions(true_captions_path, generated_captions_path)
        print("BLEU-1 Score:", bleu1)
        print("BLEU-4 Score:", bleu4)
def task_4_4():
    """
    Task 4-4: experiment with sampling temperatures.

    At present this only sets this module's ``TEMPERATURE`` global to 1 and
    returns ``None``; no generation is triggered.
    """
    # NOTE(review): this binds TEMPERATURE in this module only; no reader of
    # it is visible in this file - confirm where the value is consumed.
    global TEMPERATURE
    TEMPERATURE = 1
    return None
def task_4_5():
    """
    Task 4-5: pre-trained word embeddings.

    Obtains a runner from ``getPreTrainedEmbeddingRunner()`` and calls its
    ``train()`` method.
    """
    getPreTrainedEmbeddingRunner().train()
def main():
    """
    Run the interactive task-selection loop.

    Prints the menu once, then repeatedly prompts for a choice and runs the
    matching task until the user enters ``q`` (case-insensitive) or standard
    input is closed. Unrecognized choices are reported instead of being
    silently ignored.
    """
    # Single source of truth for the accepted keys; replaces the previous
    # mixed if / if-elif chain.
    handlers = {
        "1i": task_4_1_lstm,
        "1ii": task_4_1_rnn,
        "2": task_4_2,
        "3": task_4_3,
        "4": task_4_4,
        "5": task_4_5,
    }
    valid_choices = ", ".join([*handlers, "q"])

    show_options()
    while True:
        try:
            choice = input("Please select your task: ").strip().lower()
        except EOFError:
            # Closed/piped stdin: treat as a request to quit rather than
            # crashing with a traceback.
            break

        if choice == "q":
            break

        handler = handlers.get(choice)
        if handler is None:
            print(f"Unrecognized option {choice!r}; choose one of: {valid_choices}")
            continue
        handler()


if __name__ == "__main__":
    main()