-- Exec_CharLM_CTC.lua
require 'nn'
require 'cudnn'
require 'cunn'
require 'cutorch'
local threads = require 'threads'
threads.Threads.serialization('threads.sharedserialize')
require 'CTC_CLM_NN_lang_multithread'
require 'SequenceError'
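-- Note: the CUDA packages above (cutorch/cunn/cudnn) are loaded up front, presumably
-- because the saved predictions and the decoder are expected to run on the GPU.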
local final_threading = torch.class('Final_lua_threading')
local cmd = torch.CmdLine()
cmd:option('-threads', 16, 'Number of worker threads used for decoding')
cmd:option('-log_probs_arr', './Test_With_Char_lm/predictions_log_vec.t7', 'Path to the saved log probabilities for the test data')
cmd:option('-target_label', './Test_With_Char_lm/targets_vec.t7', 'Path to the ground-truth transcripts')
cmd:option('-predictions_without_lm', './Test_With_Char_lm/predicts_vec.t7', 'Path to the predictions generated without a language model')
cmd:option('-alpha', 1.25, 'Alpha value for the character-level beam search')
cmd:option('-beta', 0, 'Beta value for the character-level beam search')
cmd:option('-beam_val', 40, 'Beam width for the beam search')
cmd:option('-dictionary', 'dictionary', 'Path to the dictionary')
cmd:option('-path', '/home/sbp3624/CTCSpeechRecognition/Test_With_Char_lm/', 'Directory in which results are stored')
local opt = cmd:parse(arg)
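-- Typical invocation with the Torch `th` interpreter (the values shown are
-- illustrative; any of the defaults above can be overridden on the command line):
--   th Exec_CharLM_CTC.lua -threads 8 -alpha 1.25 -beta 0 -beam_val 40 \
--       -dictionary dictionary -path ./Test_With_Char_lm/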
-- Each worker thread runs in its own Lua state, so the decoder module must be
-- required again inside the pool's init callback.
function final_threading:init(nthreads)
    self.pool = threads.Threads(nthreads, function()
        require 'CTC_CLM_NN_lang_multithread'
    end)
end
-- Load the saved log probabilities, ground-truth transcripts and baseline (no-LM) predictions.
local predictions_arr = torch.load(opt.log_probs_arr)
local targets = torch.load(opt.target_label)
local predicts_without_lm = torch.load(opt.predictions_without_lm)
local numberOfSamples = #predictions_arr
-- numberOfSamples = 16 -- uncomment to run on a small subset
local alpha = opt.alpha
local beta = opt.beta
local beam_val = opt.beam_val
--local ctc_clm_obj = CTC_CLM_NN_lang_multithread('dictionary')
local final_sent_tbl={}
local nthreads = opt.threads
final_threading:init(nthreads)
local tbl_objs={}
-- Create one decoder object per worker thread.
for i = 1, nthreads do
    tbl_objs[i] = CTC_CLM_NN_lang_multithread(opt.dictionary)
end
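-- Each decoding job looks up its thread's private decoder via __threadid, so
-- worker threads never share mutable decoder state.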
-- Decode all utterances in parallel using the character language model; WER/CER are computed afterwards.
local start_time = os.time()
for i = 1, numberOfSamples do
    final_threading.pool:addjob(
        function()
            print(__threadid, i)
            local sent = tbl_objs[__threadid]:decode_beam_search(predictions_arr[i], alpha, beta, beam_val)
            return i, sent
        end,
        function(idx, decoded)
            final_sent_tbl[idx] = decoded
        end
    )
end
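-- Block until every queued decoding job has finished, then shut down the worker threads.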
final_threading.pool:synchronize()
final_threading.pool:terminate()
local end_time = os.time()
local targetTranscript, predictTranscript_lm, predictTranscript
local CER, WER
local cumCER = 0
local cumWER = 0
local evaluationPredictions = {}
local prev_CER,prev_WER
local prev_cumCER = 0
local prev_cumWER = 0
-- Compute WER/CER for the LM-rescored predictions and for the baseline (no-LM) predictions.
for i = 1, numberOfSamples do
    targetTranscript = targets[i]
    predictTranscript_lm = final_sent_tbl[i]
    predictTranscript = predicts_without_lm[i]
    CER = SequenceError:calculateCER(targetTranscript, predictTranscript_lm)
    WER = SequenceError:calculateWER(targetTranscript, predictTranscript_lm)
    cumCER = cumCER + CER
    cumWER = cumWER + WER
    table.insert(evaluationPredictions, {
        wer = WER * 100,
        cer = CER * 100,
        target = targetTranscript,
        prediction = predictTranscript_lm,
        prediction_without = predictTranscript
    })
    prev_CER = SequenceError:calculateCER(targetTranscript, predictTranscript)
    prev_WER = SequenceError:calculateWER(targetTranscript, predictTranscript)
    prev_cumCER = prev_cumCER + prev_CER
    prev_cumWER = prev_cumWER + prev_WER
end
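-- Sort the per-utterance results by WER (ascending) so the log lists the
-- best-scoring utterances first.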
local function comp(a, b) return a.wer < b.wer end
table.sort(evaluationPredictions, comp)
-- path to save results
local path = opt.path
local suffix = '_'..os.date('%Y-%m-%d_%H-%M-%S')
local f = assert(io.open(path..'Evaluation'..suffix..'.log', 'a'))
for _, eval in ipairs(evaluationPredictions) do
    f:write(string.format("WER = %.2f | CER = %.2f | Text = \"%s\" | Predict = \"%s\" | Predict_without_lm = \"%s\"\n",
        eval.wer, eval.cer, eval.target, eval.prediction, eval.prediction_without))
end
f:close()
local averageWER = cumWER / numberOfSamples
local averageCER = cumCER / numberOfSamples
local prev_averageWER = prev_cumWER / numberOfSamples
local prev_averageCER = prev_cumCER / numberOfSamples
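-- Append the corpus-level averages to the same evaluation log.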
local f = assert(io.open(path..'Evaluation'..suffix..'.log', 'a'))
f:write(string.format("Average WER = %.2f | CER = %.2f", averageWER * 100, averageCER * 100))
f:close()
print(string.format('Average WER = %.2f | Average CER = %.2f (with character LM)', averageWER * 100, averageCER * 100))
print(string.format('Average WER = %.2f | Average CER = %.2f (without LM)', prev_averageWER * 100, prev_averageCER * 100))
print(string.format('Total decoding time = %.2f minutes', (end_time - start_time) / 60))