Skip to content

Commit

Permalink
gpu fix (#12)
Browse files Browse the repository at this point in the history
make dists work on gpu :)
  • Loading branch information
vitskvara authored and nmheim committed Jan 5, 2020
1 parent e7b87da commit d7ca4fb
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 12 deletions.
8 changes: 4 additions & 4 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ version = "1.1.0"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "a1b652fb77ae8ca7ea328fa7ba5aa151036e5c10"
git-tree-sha1 = "f784254f428fb8fd7ac15982e5862a38a44523d3"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.17.6"
version = "0.17.7"

[[Dates]]
deps = ["Printf"]
Expand Down Expand Up @@ -180,9 +180,9 @@ version = "0.7.2"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "74fe444b8b6d1ac01d639b2f9eaf395bcc2e24fc"
git-tree-sha1 = "1d08d7e4250f452f6cb20e4574daaebfdbee0ff7"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "1.3.2"
version = "1.3.3"

[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
Expand Down
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Niklas Heim <[email protected]>"]
version = "0.1.0"

[deps]
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down
11 changes: 7 additions & 4 deletions src/cmean_gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,21 @@ julia> rand(p, ones(2))
0.0767166501426535
```
"""
struct CMeanGaussian{V<:AbstractVar,S<:AbstractArray,M} <: AbstractCGaussian
struct CMeanGaussian{V<:AbstractVar,M,S<:AbstractArray} <: AbstractCGaussian
mapping::M
σ::S
xlength::Int
_nograd::Dict{Symbol,Bool}
end

CMeanGaussian{V}(m::M, σ::S, xlength::Int, d::Dict{Symbol,Bool}) where {V,M,S} =
CMeanGaussian{V,M,S}(m,σ,xlength,d)

function CMeanGaussian{V}(m::M, σ, xlength::Int) where {V,M}
_nograd = Dict( => σ isa NoGradArray)
σ = _nograd[] ? σ.data : σ
S = typeof(σ)
CMeanGaussian{V,S,M}(m, σ, xlength, _nograd)
CMeanGaussian{V,M,S}(m, σ, xlength, _nograd)
end

CMeanGaussian{DiagVar}(m, σ) = CMeanGaussian{DiagVar}(m, σ, size(σ,1))
Expand All @@ -62,10 +65,10 @@ end
mean_var(p::CMeanGaussian, z::AbstractArray) = (mean(p, z), variance(p, z))

# make sure that parameteric constructor is called...
function Flux.functor(p::CMeanGaussian{V,S,M}) where {V,S,M}
function Flux.functor(p::CMeanGaussian{V,M,S}) where {V,M,S}
fs = fieldnames(typeof(p))
nt = (; (name=>getfield(p, name) for name in fs)...)
nt, y -> CMeanGaussian{V,S,M}(y...)
nt, y -> CMeanGaussian{V}(y...)
end

function Flux.trainable(p::CMeanGaussian)
Expand Down
4 changes: 2 additions & 2 deletions src/cmeanvar_gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ function mean_var(p::CMeanVarGaussian{ScalarVar}, z::AbstractArray)
end

# make sure that parameteric constructor is called...
function Flux.functor(p::CMeanVarGaussian{V,M}) where {V,M}
function Flux.functor(p::CMeanVarGaussian{V}) where V
fs = fieldnames(typeof(p))
nt = (; (name=>getfield(p, name) for name in fs)...)
nt, y -> CMeanVarGaussian{V,M}(y...)
nt, y -> CMeanVarGaussian{V}(y...)
end

function Base.show(io::IO, p::CMeanVarGaussian{V}) where V
Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ if Flux.use_cuda[] using CuArrays end

include("abstract_pdf.jl")
include("gaussian.jl")
include("cmean_gaussian.jl")
include("nogradarray.jl")
include("cmeanvar_gaussian.jl")
include("cmean_gaussian.jl")
include("constspec_gaussian.jl")
include("nogradarray.jl")

0 comments on commit d7ca4fb

Please sign in to comment.