-
Notifications
You must be signed in to change notification settings - Fork 2
/
multigpu.lua
133 lines (120 loc) · 3.84 KB
/
multigpu.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
require 'cunn'
-- Whitelist of standard nn container types we are willing to recurse into.
-- Kept as a set (type-name -> true) so membership is an O(1) lookup.
local CONTAINER_TYPES = {
   ['nn.Sequential']    = true,
   ['nn.Concat']        = true,
   ['nn.DepthConcat']   = true,
   ['nn.Parallel']      = true,
   ['nn.ConcatTable']   = true,
   ['nn.ParallelTable'] = true,
}

--- Return true iff `module` is one of the standard nn containers above.
-- Anything else (leaf layers, custom containers) yields false.
local function isContainer(module)
   return CONTAINER_TYPES[torch.type(module)] == true
end
--- Build a fresh, empty container of the same type as `module`, forwarding
-- the constructor arguments that each container type requires.
-- @param module a standard nn container instance (see isContainer)
-- @return a new empty container of the same nn class
local function copyContainer(module)
   local modType = torch.type(module)
   -- strip the leading 'nn.' prefix to index the nn package table
   modType = modType:sub(4, #modType)
   -- FIX: declare as local — the original leaked `newModel` as a global,
   -- which could be clobbered by (or clobber) unrelated code.
   local newModel = nn[modType]
   if modType == 'Concat' or modType == 'DepthConcat' then
      -- these containers join outputs along a dimension
      return newModel(module.dimension)
   elseif modType == 'Parallel' then
      -- Parallel needs both the input split dim and the output join dim
      return newModel(module.inputDimension, module.outputDimension)
   else
      -- Sequential / ConcatTable / ParallelTable take no constructor args
      return newModel()
   end
end
--- Recursively strip multi-GPU replication for serialization.
-- Assumes every DataParallelTable in the tree was produced by
-- makeDataParallel: its children are clones of one network on different
-- GPUs, so keeping the single replica on OPT.GPU is enough when saving
-- the model to disk.
local function cleanDPT(module)
   if torch.isTypeOf(module, nn.DataParallelTable) then
      -- collapse the DPT to a single-GPU copy
      local dpt = nn.DataParallelTable(1, true, true)
      cutorch.setDevice(OPT.GPU)
      dpt:add(module:get(1), OPT.GPU)
      return dpt
   end

   if isContainer(module) then
      -- rebuild the container shell and clean each child recursively
      local shell = copyContainer(module)
      for _, child in ipairs(module.modules) do
         shell:add(cleanDPT(child))
      end
      return shell
   end

   -- leaf module: nothing to clean, keep as-is
   return module
end
--- Search a container tree for nn.DataParallelTable nodes (possibly deeply
-- nested) and re-replicate each one onto `nGPU` GPUs via makeDataParallel.
-- Mutates `module` in place.
-- @param module the (possibly container) module to walk
-- @param nGPU   number of GPUs to replicate onto
-- @param net    table providing net.packages() for worker-thread init
local function retrieveDPT(module, nGPU, net)
   if torch.isTypeOf(module, nn.Container) then
      for i, mod in ipairs(module.modules) do
         if torch.type(mod) == 'nn.DataParallelTable' then
            -- FIX: write into module.modules[i]; the original wrote
            -- `module[i] = ...`, which only adds a numeric key on the
            -- container table and never replaces the child actually used
            -- in the forward pass.
            module.modules[i] = makeDataParallel(mod:get(1), nGPU, net)
         else
            retrieveDPT(mod, nGPU, net)
         end
      end
   end
end
--- Replace every nn.DataParallelTable found in a container tree (possibly
-- deeply nested) with a CUDA clone of its single underlying network,
-- producing a single-GPU model. Mutates `module` in place.
-- @param module the (possibly container) module to walk
local function removeDPT(module)
   if torch.isTypeOf(module, nn.Container) then
      for i, mod in ipairs(module.modules) do
         if torch.type(mod) == 'nn.DataParallelTable' then
            -- FIX: write into module.modules[i]; the original wrote
            -- `module[i] = ...`, which only adds a numeric key on the
            -- container table and never replaces the child actually used
            -- in the forward pass.
            module.modules[i] = mod:get(1):clone():cuda()
         else
            removeDPT(mod)
         end
      end
   end
end
--- Wrap `model` in an nn.DataParallelTable replicated across GPUs 1..nGPU.
-- Each replica is a CUDA clone placed on its own device; worker threads are
-- initialized with net.packages() so required modules are loaded per thread.
-- Leaves the current device set back to OPT.GPU and returns the wrapped
-- model (or the original model untouched when nGPU < 1).
function makeDataParallel(model, nGPU, net)
   if nGPU >= 1 then
      print('converting module to nn.DataParallelTable')
      assert(nGPU <= cutorch.getDeviceCount(), 'number of GPUs less than nGPU specified')
      local replicaSource = model
      model = nn.DataParallelTable(1, true, true)
      for gpu = 1, nGPU do
         cutorch.setDevice(gpu)
         model:add(replicaSource:clone():cuda(), gpu)
      end
      -- capture `net` in a local so the thread-init closure serializes it;
      -- each worker thread loads the packages the network needs
      local threadNet = net
      model:threads(function()
         threadNet.packages()
      end)
   end
   -- restore the primary device regardless of replication
   cutorch.setDevice(OPT.GPU)
   return model
end
--- Serialize `model` to `filename`, first collapsing any DataParallelTable
-- nodes to a single replica (see cleanDPT) so the file holds one copy.
function saveDataParallel(filename, model)
   torch.save(filename, cleanDPT(model))
end
--- Load a serialized model and re-replicate it onto `nGPU` GPUs.
-- A top-level DataParallelTable is unwrapped and rewrapped directly;
-- otherwise the tree is walked to find and re-replicate nested DPTs.
-- @param filename path of the torch-serialized model
-- @param nGPU    number of GPUs to replicate onto
-- @param net     table providing net.packages() (loaded before deserializing)
function loadDataParallel(filename, nGPU, net)
   -- packages must be loaded so torch.load can reconstruct custom modules
   net.packages()
   local model = torch.load(filename)
   if torch.type(model) ~= 'nn.DataParallelTable' then
      retrieveDPT(model, nGPU, net)
      return model
   end
   return makeDataParallel(model:get(1), nGPU, net)
end
--- Load a serialized model and strip all multi-GPU replication, yielding a
-- plain single-GPU CUDA model.
-- @param filename path of the torch-serialized model
-- @param net      table providing net.packages() (loaded before deserializing)
function loadAndRemoveDPT(filename, net)
   -- packages must be loaded so torch.load can reconstruct custom modules
   net.packages()
   local model = torch.load(filename)
   if torch.type(model) ~= 'nn.DataParallelTable' then
      removeDPT(model)
      return model
   end
   return model:get(1):clone():cuda()
end