diff --git a/Container.lua b/Container.lua index bf31dff45..ce6cd4819 100644 --- a/Container.lua +++ b/Container.lua @@ -128,19 +128,21 @@ end function Container:clearState() -- don't call set because it might reset referenced tensors - local function clear(f) - if self[f] then - if torch.isTensor(self[f]) then - self[f] = self[f].new() - elseif type(self[f]) == 'table' then - self[f] = {} - else - self[f] = nil + local function clear(t) + if torch.isTensor(t) then + return t.new() + elseif type(t) == 'table' then + local cleared = {} + for k,v in pairs(t) do + cleared[k] = clear(v) end + return cleared + else + return nil end end - clear('output') - clear('gradInput') + if self.output then self.output = clear(self.output) end + if self.gradInput then self.gradInput = clear(self.gradInput) end if self.modules then for i,module in pairs(self.modules) do module:clearState() diff --git a/Identity.lua b/Identity.lua index 5e6ccb624..881c0c720 100644 --- a/Identity.lua +++ b/Identity.lua @@ -13,18 +13,20 @@ end function Identity:clearState() -- don't call set because it might reset referenced tensors - local function clear(f) - if self[f] then - if torch.isTensor(self[f]) then - self[f] = self[f].new() - elseif type(self[f]) == 'table' then - self[f] = {} - else - self[f] = nil + local function clear(t) + if torch.isTensor(t) then + return t.new() + elseif type(t) == 'table' then + local cleared = {} + for k,v in pairs(t) do + cleared[k] = clear(v) end + return cleared + else + return nil end end - clear('output') - clear('gradInput') + if self.output then self.output = clear(self.output) end + if self.gradInput then self.gradInput = clear(self.gradInput) end return self end diff --git a/utils.lua b/utils.lua index d81e2777b..4ef5c4ad2 100644 --- a/utils.lua +++ b/utils.lua @@ -175,18 +175,20 @@ function nn.utils.clear(self, ...) if #arg > 0 and type(arg[1]) == 'table' then arg = arg[1] end - local function clear(f) - if self[f] then - if torch.isTensor(self[f]) then - self[f]:set() - elseif type(self[f]) == 'table' then - self[f] = {} - else - self[f] = nil + local function clear(t) + if torch.isTensor(t) then + return t:set() + elseif type(t) == 'table' then + local cleared = {} + for k,v in pairs(t) do + cleared[k] = clear(v) end + return cleared + else + return nil end end - for i,v in ipairs(arg) do clear(v) end + for i,v in ipairs(arg) do if self[v] then self[v] = clear(self[v]) end end return self end