Skip to content

Commit

Permalink
Merge pull request #641 from denizyuret/dy/fix624
Browse files Browse the repository at this point in the history
gcnode debug version
  • Loading branch information
denizyuret authored Dec 10, 2020
2 parents 547f6d3 + 8a944bf commit b01c548
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions src/autograd_gpu/gcnode.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using CUDA: CuArray, unsafe_free!
using Knet.CuArrays: cuarrays # Profiling shows most of the time is spent in cuarrays
using Knet.CuArrays: cuarrays
using Knet.KnetArrays: cuallocator
using AutoGrad: Result, Node, Tape

Expand Down Expand Up @@ -29,9 +29,6 @@ const gcnode_dict = Dict{UInt,WeakRef}()
# ObjectId(::Node) for keys to allow gc after gcnode is done.
const gcnode_index = Dict{UInt,Int}()

# The user never calls gcnode_init explicitly. gcnode itself resets all state whenever the
# input Tape is different from gcnode_tape:
const gcnode_tape = WeakRef(nothing)

# During the backward step parents of a node (who have lower indices) may have their
# outgrads modified, thus new references to CuArrays that we have already indexed may
Expand All @@ -45,11 +42,9 @@ function gcnode_setindex!(c::CuArray,v::Int)
end

function gcnode_init(tape::Tape)
gcnode_tape.value = tape
empty!(gcnode_index)
empty!(gcnode_queue)
empty!(gcnode_dict)
tape isa Tape || return
@inbounds for (i,n) in enumerate(tape.list)
gcnode_index[objectid(n)] = i
if n.Value isa Result
Expand All @@ -60,14 +55,16 @@ function gcnode_init(tape::Tape)
for c in cuarrays(n.outgrad); gcnode_setindex!(c,0); end
end
end
# Mark this tape as initialized: gcnode calls gcnode_init whenever it is given a tape without this mark
tape.dict[gcnode_null] = gcnode_null_node
end


function gcnode(n::Node, tape::Tape)
cuallocator[] || return knetgcnode(n,tape)
if tape !== gcnode_tape.value
if !haskey(tape.dict, gcnode_null)
gcnode_init(tape)
end
tape isa Tape || return
ni = gcnode_index[objectid(n)]
if n.Value isa Result
for c in cuarrays(n.outgrad); gcnode_setindex!(c, ni); end
Expand All @@ -84,8 +81,9 @@ function gcnode(n::Node, tape::Tape)
end
while !isempty(gcnode_queue) && peek(gcnode_queue)[2] >= ni ## 5.62μs
(cid,v) = dequeue_pair!(gcnode_queue)
@assert v == ni
c = gcnode_dict[cid].value
if v == ni && c isa CuArray ## c turns into nothing if gc'ed
if c isa CuArray ## c turns into nothing if gc'ed
unsafe_free!(c)
end
end
Expand All @@ -94,7 +92,6 @@ function gcnode(n::Node, tape::Tape)
end
end

const gcnode_null = Result{Nothing}(nothing,nothing,nothing,nothing)



const gcnode_null = Result{Nothing}(nothing,nothing,nothing,nothing)
const gcnode_null_node = Node(gcnode_null)

0 comments on commit b01c548

Please sign in to comment.