diff --git a/src/autograd_gpu/gcnode.jl b/src/autograd_gpu/gcnode.jl index 9f5eeb5a1..89588d17b 100644 --- a/src/autograd_gpu/gcnode.jl +++ b/src/autograd_gpu/gcnode.jl @@ -1,5 +1,5 @@ using CUDA: CuArray, unsafe_free! -using Knet.CuArrays: cuarrays # Profiling shows most of the time is spent in cuarrays +using Knet.CuArrays: cuarrays using Knet.KnetArrays: cuallocator using AutoGrad: Result, Node, Tape @@ -29,9 +29,6 @@ const gcnode_dict = Dict{UInt,WeakRef}() # ObjectId(::Node) for keys to allow gc after gcnode is done. const gcnode_index = Dict{UInt,Int}() -# There is no explicit call to gcnode_init by the user. gcnode just resets everything if the -# input Tape is different from gcnode_tape: -const gcnode_tape = WeakRef(nothing) # During the backward step parents of a node (who have lower indices) may have their # outgrads modified, thus new references to CuArrays that we have already indexed may @@ -45,11 +42,9 @@ function gcnode_setindex!(c::CuArray,v::Int) end function gcnode_init(tape::Tape) - gcnode_tape.value = tape empty!(gcnode_index) empty!(gcnode_queue) empty!(gcnode_dict) - tape isa Tape || return @inbounds for (i,n) in enumerate(tape.list) gcnode_index[objectid(n)] = i if n.Value isa Result @@ -60,14 +55,16 @@ function gcnode_init(tape::Tape) for c in cuarrays(n.outgrad); gcnode_setindex!(c,0); end end end + # Mark this tape so we know when gcnode is called with a new tape and gcnode_init is needed + tape.dict[gcnode_null] = gcnode_null_node end + function gcnode(n::Node, tape::Tape) cuallocator[] || return knetgcnode(n,tape) - if tape !== gcnode_tape.value + if !haskey(tape.dict, gcnode_null) gcnode_init(tape) end - tape isa Tape || return ni = gcnode_index[objectid(n)] if n.Value isa Result for c in cuarrays(n.outgrad); gcnode_setindex!(c, ni); end @@ -84,8 +81,9 @@ function gcnode(n::Node, tape::Tape) end while !isempty(gcnode_queue) && peek(gcnode_queue)[2] >= ni ## 5.62μs (cid,v) = dequeue_pair!(gcnode_queue) + @assert v == ni c = gcnode_dict[cid].value - if v == ni && c isa CuArray ## c turns into nothing if gc'ed + if c isa CuArray ## c turns into nothing if gc'ed unsafe_free!(c) end end @@ -94,7 +92,6 @@ function gcnode(n::Node, tape::Tape) end end -const gcnode_null = Result{Nothing}(nothing,nothing,nothing,nothing) - - +const gcnode_null = Result{Nothing}(nothing,nothing,nothing,nothing) +const gcnode_null_node = Node(gcnode_null)