gpu fix #12
src/cmean_gaussian.jl
Outdated
@@ -65,7 +65,7 @@ mean_var(p::CMeanGaussian, z::AbstractArray) = (mean(p, z), variance(p, z))
 function Flux.functor(p::CMeanGaussian{V,S,M}) where {V,S,M}
     fs = fieldnames(typeof(p))
     nt = (; (name=>getfield(p, name) for name in fs)...)
-    nt, y -> CMeanGaussian{V,S,M}(y...)
+    nt, y -> CMeanGaussian{V,S,typeof(gpu(p.mapping))}(y...)
This only fixes the mapping conversion; the σ field is still not converted.
Also, this will probably break if CuArrays is loaded, the model is on the CPU, and functor is called outside the GPU conversion process. Where else is functor called?
Codecov Report
@@ Coverage Diff @@
## master #12 +/- ##
==========================================
- Coverage 92.3% 91.42% -0.88%
==========================================
Files 8 8
Lines 104 105 +1
==========================================
Hits 96 96
- Misses 8 9 +1
I cannot test on GPU atm, but does this work?

    function Flux.functor(p::CMeanGaussian{V}) where V
        fs = fieldnames(typeof(p))
        nt = (; (name=>getfield(p, name) for name in fs)...)
        nt, y -> CMeanGaussian{V}(y...)
    end
putting a |
It helps with |
Now it fails with:
julia> p = CMeanGaussian{DiagVar}(mapping, var) |> gpu
ERROR: MethodError: no method matching CMeanGaussian{DiagVar,M,S} where S<:AbstractArray where M(::Dense{typeof(identity),CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}, ::CuArray{Float32,1,Nothing}, ::Int64, ::Dict{Symbol,Bool})
Closest candidates are:
CMeanGaussian{DiagVar,M,S} where S<:AbstractArray where M(::M, ::Any, ::Int64) where {V, M} at /home/vit/.julia/dev/ConditionalDists/src/cmean_gaussian.jl:40
CMeanGaussian{DiagVar,M,S} where S<:AbstractArray where M(::Any, ::Any) at /home/vit/.julia/dev/ConditionalDists/src/cmean_gaussian.jl:46
Stacktrace:
[1] #14 at /home/vit/.julia/dev/ConditionalDists/src/cmean_gaussian.jl:68 [inlined]
[2] fmap1(::Function, ::CMeanGaussian{DiagVar,Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Array{Float32,1}}) at /home/vit/.julia/packages/Flux/oObnA/src/functor.jl:32
[3] #fmap#41(::IdDict{Any,Any}, ::typeof(fmap), ::typeof(cu), ::CMeanGaussian{DiagVar,Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Array{Float32,1}}) at /home/vit/.julia/packages/Flux/oObnA/src/functor.jl:37
[4] fmap at /home/vit/.julia/packages/Flux/oObnA/src/functor.jl:36 [inlined]
I am really confused with this. When debugging, it seems that the overloaded |
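One possible reading of the MethodError above: the rebuild closure splats one argument per field (four in the failing call), while the listed candidate constructors take only two or three arguments. The sketch below reproduces the pattern without Flux or a GPU. `ToyDist`, `DiagTag`, `split_fields` and `convert_leaf` are hypothetical names, the field layout is only inferred from the error message, and the Float32 to Float64 conversion stands in for the CPU to GPU conversion.

```julia
# Hypothetical stand-in for CMeanGaussian; the field layout is an assumption
# inferred from the four arguments shown in the MethodError.
struct DiagTag end  # plays the role of a variance tag such as DiagVar

struct ToyDist{V,M,S<:AbstractArray}
    mapping::M
    σ::S
    xlength::Int
    _nograd::Dict{Symbol,Bool}
end

# Outer constructor that fixes only V, takes one argument per field,
# and infers M and S from the values it receives.
function ToyDist{V}(m::M, σ::S, xlength::Int, nograd::Dict{Symbol,Bool}) where {V,M,S<:AbstractArray}
    ToyDist{V,M,S}(m, σ, xlength, nograd)
end

# Split a struct into a NamedTuple of its fields plus a closure that rebuilds it,
# the same shape as the pair returned by the functor overload in this thread.
function split_fields(p::ToyDist{V}) where {V}
    fs = fieldnames(typeof(p))
    nt = NamedTuple{fs}(map(name -> getfield(p, name), fs))
    return nt, y -> ToyDist{V}(y...)
end

# Float32 -> Float64 stands in for the CPU -> GPU array conversion (assumption).
convert_leaf(x::AbstractArray) = Float64.(x)
convert_leaf(x) = x

p = ToyDist{DiagTag}(rand(Float32, 2, 3), ones(Float32, 2), 2, Dict(:σ => false))
nt, rebuild = split_fields(p)
q = rebuild(map(convert_leaf, nt))
```

Because the closure pins only `V`, the rebuilt `q` picks up the converted array types for both `mapping` and `σ`. Whether a matching four-argument constructor is the right fix for CMeanGaussian itself is not confirmed here.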
Trying to solve #11.