forked from cmusatyalab/openface
-
Notifications
You must be signed in to change notification settings - Fork 0
/
OpenFaceOptim.lua
77 lines (65 loc) · 2.46 KB
/
OpenFaceOptim.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
-- Modified from https://github.com/facebook/fbnn/blob/master/fbnn/Optim.lua.
local pl = require('pl.import_into')()
-- Register the 'OpenFaceOptim' class as a subclass of 'nn.Optim'; `parent`
-- is the parent class table, used to reach its constructor.
local OpenFaceOptim, parent = torch.class('OpenFaceOptim', 'nn.Optim')

--- Constructor.
-- Adds no state of its own: the three arguments are forwarded unchanged to
-- the parent nn.Optim constructor.
-- @param model  the network to optimize
-- @param optState  optimizer state/configuration passed to nn.Optim
-- @param checkpoint_data  checkpoint data passed to nn.Optim
OpenFaceOptim.__init = function(self, model, optState, checkpoint_data)
    local parentInit = parent.__init
    parentInit(self, model, optState, checkpoint_data)
end
--- Find the CUDA device that holds a module's tensors.
-- Only the module's direct fields are scanned (nested submodules are not
-- visited). Fields whose torch type is 'torch.CudaTensor' contribute their
-- device id; a tensor reporting device 0 is ignored (NOTE(review): cutorch
-- appears to report 0 for a tensor with no storage allocated -- confirm for
-- the cutorch version in use).
-- @param mod table  module whose fields are inspected
-- @return number|nil  the device id shared by the module's CudaTensors, or
--   nil (never 0) when no field is a CudaTensor on a non-zero device
-- @raise error if the module's CudaTensors live on different devices
local function get_device_for_module(mod)
    local dev_id = nil
    for name, val in pairs(mod) do
        if torch.typename(val) == 'torch.CudaTensor' then
            local this_dev = val:getDevice()
            if this_dev ~= 0 then
                -- All of one module's tensors must share a device; otherwise
                -- there is no single device to run its update on.
                if dev_id ~= nil and dev_id ~= this_dev then
                    error(string.format(
                        "module field '%s' is on device %d but other tensors " ..
                        "of the same module are on device %d",
                        tostring(name), this_dev, dev_id))
                end
                dev_id = this_dev
            end
        end
    end
    return dev_id -- nil (not 0) when no allocated CudaTensor was found
end
--- Call `f` with the current CUDA device set to the module's device.
-- When the module's tensors have no known device (get_device_for_module
-- returns nil), `f` is called directly in the current device context.
-- @param mod  module whose device is looked up
-- @param f function  zero-argument callback
-- @return whatever `f` returns (via cutorch.withDevice or a direct call)
local function on_device_for_module(mod, f)
    local device = get_device_for_module(mod)
    if device == nil then
        return f()
    end
    return cutorch.withDevice(device, f)
end
--- Run one forward/backward pass on a batch and update the parameters.
-- The criterion is evaluated on the model output alone (no target is
-- passed). The optimizer is then stepped separately for the weight and
-- bias parameters of every module in `self.modulesToOptState`, each step
-- using that slot's own state `opt[i]`. Each module's update runs on the
-- module's CUDA device when one is found.
-- @param optimMethod function  called as optimMethod(feval, param, state)
-- @param inputs  batch passed to self.model:forward / :backward
-- @param criterion  criterion whose forward/backward take only the output
-- @return err  value returned by criterion:forward(output)
-- @return output  model output for `inputs`
function OpenFaceOptim:optimizeTriplet(optimMethod, inputs, criterion)
    assert(optimMethod, 'optimizeTriplet: optimMethod is required')
    assert(inputs, 'optimizeTriplet: inputs is required')
    assert(criterion, 'optimizeTriplet: criterion is required')
    assert(self.modulesToOptState,
           'optimizeTriplet: self.modulesToOptState is not set')

    -- Lua 5.1 / LuaJIT only provide the global `unpack`; Lua 5.2+ moved it
    -- to `table.unpack`. Support both runtimes.
    local unpack = table.unpack or unpack

    self.model:zeroGradParameters()
    local output = self.model:forward(inputs)
    local err = criterion:forward(output)
    local df_do = criterion:backward(output)
    self.model:backward(inputs, df_do)

    -- Gradients are fully computed above, so the closure handed to the
    -- optimizer only reports them. curParam/curGrad are upvalues that are
    -- reassigned for each (param, grad) pair just before optimMethod runs.
    local curGrad
    local curParam
    local function fEvalMod(_)
        return err, curGrad
    end

    for curMod, opt in pairs(self.modulesToOptState) do
        on_device_for_module(curMod, function()
            local curModParams = self.weight_bias_parameters(curMod)
            -- Expect either an empty table or a 2-element table: slot 1 for
            -- weights and slot 2 for biases.
            local nParams = pl.tablex.size(curModParams)
            assert(nParams == 0 or nParams == 2,
                   'optimizeTriplet: expected 0 or 2 parameter groups, got ' ..
                   tostring(nParams))
            -- Index the two slots explicitly (rather than ipairs) so a nil
            -- slot is skipped instead of silently ending the loop.
            for i = 1, 2 do
                local pair = curModParams[i]
                if pair then
                    -- expect a {param, gradParam} pair
                    curParam, curGrad = unpack(pair)
                    assert(curParam and curGrad,
                           'optimizeTriplet: parameter group ' .. i ..
                           ' is missing its tensor or gradient')
                    optimMethod(fEvalMod, curParam, opt[i])
                end
            end
        end)
    end
    return err, output
end