Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create backward method with updateGradInput #147

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions gmodule.lua
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,78 @@ function gModule:accGradParameters(input,gradOutput,lr)
end
end

function gModule:backward(input,gradOutput,scale)
local function neteval(node)
if node.data.selectindex then
assert(not node.data.module, "the selectindex-handling nodes should have no module")
assert(#node.children == 1, "only the splitted node should be the input")
local child = node.children[1]
local go = getTotalGradOutput(node)
child.data.gradOutput = child.data.gradOutput or {}
assert(#child.data.gradOutput <= 1, "the splitted node should be used only once")
-- The data.gradOutput holds the to-be-summed gradients.
child.data.gradOutput[1] = child.data.gradOutput[1] or {}
assert(not child.data.gradOutput[1][node.data.selectindex], "no gradOutput should be assigned yet")
child.data.gradOutput[1][node.data.selectindex] = go
else
local gradOutput = getTotalGradOutput(node)
-- backward through this node
-- If no module is present, the node behaves like nn.Identity.
local gradInput
if not node.data.module then
gradInput = gradOutput
else
local input = node.data.input
-- a parameter node is captured
if input == nil and node.data.module ~= nil then
input = {}
end
if #input == 1 then
input = input[1]
end
local module = node.data.module
gradInput = module:backward(input,gradOutput,scale)
end
-- propagate the output to children
for i,child in ipairs(node.children) do
child.data.gradOutput = child.data.gradOutput or {}
local mapindex = node.data.mapindex[child.data]
local gi
if #node.children == 1 then
gi = gradInput
else
gi = gradInput[mapindex]
end
table.insert(child.data.gradOutput,gi)
end
end
if self.verbose then
print(' V : ' .. node:label())
end
end
local outnode = self.outnode
if #outnode.children > 1 and #gradOutput ~= #outnode.children then
error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children))
end
for _,node in ipairs(self.backwardnodes) do
local gradOutput = node.data.gradOutput
while gradOutput and #gradOutput >0 do
table.remove(gradOutput)
end
end
-- Set the starting gradOutput.
outnode.data.gradOutput = outnode.data.gradOutput or {}
outnode.data.gradOutput[1] = gradOutput

for i,node in ipairs(self.backwardnodes) do
neteval(node)
end

assert(#self.innode.data.gradOutput == 1, "expecting the innode to be used only once")
self.gradInput = self.innode.data.gradOutput[1]
return self.gradInput
end

function gModule:read(file)
local data = file:readObject()
for k, v in pairs(data) do
Expand Down