train.lua
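-- Example invocations (the flags are this script's own options, declared below;
-- `th` is the Torch script runner):
--   th train.lua --cuda --dataset 50000 --hiddenSize 1000
--   th train.lua --opencl --batchSize 500
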
require 'neuralconvo'
require 'xlua'
cmd = torch.CmdLine()
cmd:text('Options:')
cmd:option('--dataset', 0, 'approximate size of dataset to use (0 = all)')
cmd:option('--minWordFreq', 1, 'minimum frequency of words kept in vocab')
cmd:option('--cuda', false, 'use CUDA')
cmd:option('--opencl', false, 'use OpenCL')
cmd:option('--hiddenSize', 300, 'number of hidden units in LSTM')
cmd:option('--learningRate', 0.05, 'learning rate at t=0')
cmd:option('--momentum', 0.9, 'momentum')
cmd:option('--minLR', 0.00001, 'minimum learning rate')
cmd:option('--saturateEpoch', 20, 'epoch at which linear decayed LR will reach minLR')
cmd:option('--maxEpoch', 50, 'maximum number of epochs to run')
cmd:option('--batchSize', 1000, 'number of examples to load at once')
cmd:text()
options = cmd:parse(arg)
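
-- A --dataset of 0 means "use the full corpus": clear it so loadFirst below is nil.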
if options.dataset == 0 then
  options.dataset = nil
end

-- Data
print("-- Loading dataset")
dataset = neuralconvo.DataSet(neuralconvo.CornellMovieDialogs("data/cornell_movie_dialogs"), {
  loadFirst = options.dataset,
  minWordFreq = options.minWordFreq
})

print("\nDataset stats:")
print(" Vocabulary size: " .. dataset.wordsCount)
print(" Examples: " .. dataset.examplesCount)
-- Model
model = neuralconvo.Seq2Seq(dataset.wordsCount, options.hiddenSize)
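-- Reuse the dataset's go/eos token ids so the decoder knows where a reply starts
-- and ends.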
model.goToken = dataset.goToken
model.eosToken = dataset.eosToken
-- Training parameters
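-- SequencerCriterion applies the wrapped ClassNLLCriterion at each timestep of the
-- decoded sequence and sums the per-step losses.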
model.criterion = nn.SequencerCriterion(nn.ClassNLLCriterion())
model.learningRate = options.learningRate
model.momentum = options.momentum
local decayFactor = (options.minLR - options.learningRate) / options.saturateEpoch
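-- decayFactor is negative: adding it once per epoch moves the LR linearly from
-- learningRate down to minLR over saturateEpoch epochs (clamped at the bottom of
-- the training loop).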
local minMeanError = nil

-- Enable CUDA or OpenCL
if options.cuda then
  require 'cutorch'
  require 'cunn'
  model:cuda()
elseif options.opencl then
  require 'cltorch'
  require 'clnn'
  model:cl()
end

-- Run the experiment
for epoch = 1, options.maxEpoch do
  print("\n-- Epoch " .. epoch .. " / " .. options.maxEpoch)
  print("")

  local errors = torch.Tensor(dataset.examplesCount):fill(0)
  local timer = torch.Timer()
  local i = 1

  for examples in dataset:batches(options.batchSize) do
    collectgarbage()

    for _, example in ipairs(examples) do
      local input, target = unpack(example)

      if options.cuda then
        input = input:cuda()
        target = target:cuda()
      elseif options.opencl then
        input = input:cl()
        target = target:cl()
      end

      local err = model:train(input, target)

      -- Check if error is NaN. If so, it's probably a bug.
      if err ~= err then
        error("Invalid error! Exiting.")
      end

      errors[i] = err
      xlua.progress(i, dataset.examplesCount)
      i = i + 1
    end
  end

  timer:stop()
  print("\nFinished in " .. xlua.formatTime(timer:time().real) .. " " ..
    (dataset.examplesCount / timer:time().real) .. ' examples/sec.')

  print("\nEpoch stats:")
  print("  LR= " .. model.learningRate)
  print("  Errors: min= " .. errors:min())
  print("          max= " .. errors:max())
  print("       median= " .. errors:median()[1])
  print("         mean= " .. errors:mean())
  print("          std= " .. errors:std())

  -- Save the model if it improved.
  if minMeanError == nil or errors:mean() < minMeanError then
    print("\n(Saving model ...)")
    torch.save("data/model.t7", model)
    minMeanError = errors:mean()
  end

  model.learningRate = model.learningRate + decayFactor
  model.learningRate = math.max(options.minLR, model.learningRate)
end

-- Load testing script
require "eval"