web_backend.lua
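-- Redis-backed sampling worker for a char-rnn model: it subscribes to a
-- pub/sub channel, primes the network with the request text, samples a
-- continuation character by character, and writes the result back to Redis
-- under the requested session id.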
require 'torch'
require 'nngraph'
require 'optim'
require 'lfs'
require 'nn'
require 'util.OneHot'
require 'util.misc'
JSON = (loadfile "util/JSON.lua")()
local redis = require 'redis'
local client = redis.connect('127.0.0.1', 6379)
local client2 = redis.connect('127.0.0.1', 6379)
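-- client blocks inside the pub/sub loop below (a subscribed connection can
-- only issue subscribe commands), so client2 is kept for writing results back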
local channels = {'cv_channel'}
local model_file = './onlie_model/model.t7'
local gpuid = 0
local seed = 123
-- check that cunn/cutorch are installed if user wants to use the GPU
if gpuid >= 0 then
    local ok, cunn = pcall(require, 'cunn')
    local ok2, cutorch = pcall(require, 'cutorch')
    if not ok then print('package cunn not found!') end
    if not ok2 then print('package cutorch not found!') end
    if ok and ok2 then
        print('using CUDA on GPU ' .. gpuid .. '...')
        cutorch.setDevice(gpuid + 1) -- note +1: cutorch device ids are 1-indexed
        cutorch.manualSeed(seed)
    else
        print('Falling back on CPU mode')
        gpuid = -1 -- overwrite user setting
    end
end
if not lfs.attributes(model_file, 'mode') then
    error('File ' .. model_file .. ' does not exist.')
end
checkpoint = torch.load(model_file)
protos = checkpoint.protos
protos.rnn:evaluate() -- put in eval mode so that dropout works properly
-- initialize the vocabulary (and its inverted version)
local vocab = checkpoint.vocab
local ivocab = {}
for c,i in pairs(vocab) do ivocab[i] = c end
-- split a (possibly multi-byte, UTF-8 encoded) string into an array of characters
function get_char(str)
    local len = #str
    local left = 0
    local arr = {0, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc}
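    -- arr[i] is the smallest lead byte of an i-byte UTF-8 sequence, so the
    -- scan below determines how many bytes the current character occupies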
    local unordered = {}
    local start = 1
    local wordLen = 0
    while len ~= left do
        local tmp = string.byte(str, start)
        local i = #arr
        while arr[i] do
            if tmp >= arr[i] then
                break
            end
            i = i - 1
        end
        wordLen = i + wordLen
        local tmpString = string.sub(str, start, wordLen)
        start = start + i
        left = left + i
        unordered[#unordered+1] = tmpString
    end
    return unordered
end
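-- e.g. get_char('ab中') returns {'a', 'b', '中'}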
-- start listening for generation requests
for msg in client:pubsub({subscribe = channels}) do
    if msg.kind == 'subscribe' then
        print('Subscribed to channel '..msg.channel)
    elseif msg.kind == 'message' then
        -- print('Received the following message from '..msg.channel.."\n "..msg.payload.."\n")
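        -- expected payload: a JSON object carrying the prime text, session id,
        -- sampling seed and temperature: {"text": ..., "sid": ..., "seed": ..., "temp": ...}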
        local req = JSON:decode(msg.payload)
        local primetext = '|' .. req['text'] .. '| ' -- wrap the text in '|' delimiters (presumably the format seen during training)
        local session_id = req['sid']
        local seed = req['seed']
        local temperature = req['temp']
        -- initialize the rnn state to all zeros
        local num_layers = checkpoint.opt.num_layers
        local current_state = {}
        for L = 1, num_layers do
            -- c and h for all layers
            local h_init = torch.zeros(1, checkpoint.opt.rnn_size):float()
            if gpuid >= 0 then h_init = h_init:cuda() end
            table.insert(current_state, h_init:clone())
            table.insert(current_state, h_init:clone())
        end
        state_size = #current_state -- 2 * num_layers: one c and one h per layer
        -- use the prime text to initialize the rnn state
        torch.manualSeed(seed)
        for i, c in ipairs(get_char(primetext)) do
            prev_char = vocab[c]
            if prev_char then -- skip characters that are not in the vocab
                prev_char = torch.Tensor{vocab[c]}
                io.write(ivocab[prev_char[1]])
                if gpuid >= 0 then prev_char = prev_char:cuda() end
                local lst = protos.rnn:forward{prev_char, unpack(current_state)}
                -- lst is a list of [state1,state2,..stateN,output]. We want everything but the last piece
                current_state = {}
                for i = 1, state_size do table.insert(current_state, lst[i]) end
                prediction = lst[#lst] -- last element holds the log probabilities
            end
        end
        -- start sampling/argmaxing
        result = ''
        not_end = true
        torch.manualSeed(seed + 1) -- seed once so sampling is reproducible for a given request
        for i = 1, 1000 do
            -- sample the next character from the log probabilities of the previous
            -- timestep, retrying while the model emits the out-of-vocabulary token
            -- ('UNKNOW', as spelled in this model's vocab)
            real_char = 'UNKNOW'
            while real_char == 'UNKNOW' do
                local scaled = prediction:clone():div(temperature) -- scale by temperature
                local probs = torch.exp(scaled):squeeze()
                probs:div(torch.sum(probs)) -- renormalize so probs sum to one
                prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
                real_char = ivocab[prev_char[1]]
            end
            -- forward the rnn for the next character
            local lst = protos.rnn:forward{prev_char, unpack(current_state)}
            current_state = {}
            for i = 1, state_size do table.insert(current_state, lst[i]) end
            prediction = lst[#lst] -- last element holds the log probabilities
            result = result .. ivocab[prev_char[1]]
            -- treat five consecutive newlines as the end of a complete sample
            if string.find(result, '\n\n\n\n\n') then
                not_end = false
                break
            end
        end
        if not_end then result = result .. '……' end -- mark output cut off at the length limit
        -- client2:set(session_id, result)
        client2:setex(session_id, 100, result) -- store under the session id, expiring after 100 seconds
    end
end
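
-- A request can be issued from redis-cli (illustrative values):
--   PUBLISH cv_channel '{"text": "hello", "sid": "session-1", "seed": 42, "temp": 0.8}'
-- then the caller polls for the result, which expires after 100 seconds:
--   GET session-1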