diff --git a/Manifest.toml b/Manifest.toml index 0c50fa2..dbb15ab 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -21,12 +21,6 @@ version = "1.0.0" [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" -[[BinDeps]] -deps = ["Libdl", "Pkg", "SHA", "URIParser", "Unicode"] -git-tree-sha1 = "66158ad56b4bf6cc8413b37d0b7bc52402682764" -uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee" -version = "1.0.0" - [[BinaryProvider]] deps = ["Libdl", "SHA"] git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c" @@ -40,21 +34,21 @@ version = "0.2.0" [[CUDAapi]] deps = ["Libdl", "Logging"] -git-tree-sha1 = "6eee47385c81ed3b3f716b745697869c712c2df3" +git-tree-sha1 = "56a813440ac98a1aa64672ab460a1512552211a7" uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3" -version = "2.0.0" +version = "2.1.0" [[CUDAdrv]] deps = ["CEnum", "CUDAapi", "Printf"] -git-tree-sha1 = "0f39fddace3324707469ace7fbcbc7b28d5cf921" +git-tree-sha1 = "1fce616fa0806c67c133eb1d2f68f0f1a7504665" uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde" -version = "4.0.4" +version = "5.0.1" [[CUDAnative]] deps = ["Adapt", "CEnum", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Printf", "TimerOutputs"] -git-tree-sha1 = "a67b38619d1fa131027bac1c4a81f0012254d1fd" +git-tree-sha1 = "6e11d5c2c91fc623952e94c4fb73f9c4db74795a" uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17" -version = "2.6.0" +version = "2.7.0" [[CodecZlib]] deps = ["BinaryProvider", "Libdl", "TranscodingStreams"] @@ -82,9 +76,9 @@ version = "0.2.0" [[CuArrays]] deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "Libdl", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"] -git-tree-sha1 = "e99db1397ce85975203a9d736ab37534730996ca" +git-tree-sha1 = "51fbe053dea29ed2513e02d38380007310cf4c4b" uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" -version = "1.5.0" +version = "1.6.0" [[DataAPI]] git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252" @@ -93,9 +87,9 @@ version = "1.1.0" [[DataStructures]] deps = ["InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "a1b652fb77ae8ca7ea328fa7ba5aa151036e5c10" +git-tree-sha1 = "f784254f428fb8fd7ac15982e5862a38a44523d3" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.17.6" +version = "0.17.7" [[Dates]] deps = ["Printf"] @@ -107,15 +101,15 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" [[DiffResults]] deps = ["StaticArrays"] -git-tree-sha1 = "b5b37c47c5cee040a47d02cf65144ab7c5d8aef6" +git-tree-sha1 = "da24935df8e0c6cf28de340b958f6aac88eaa0cc" uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.0.1" +version = "1.0.2" [[DiffRules]] deps = ["NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "f734b5f6bc9c909027ef99f6d91d5d9e4b111eed" +git-tree-sha1 = "10dca52cf6d4a62d82528262921daf63b99704a2" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "0.1.0" +version = "1.0.0" [[Distributed]] deps = ["Random", "Serialization", "Sockets"] @@ -140,9 +134,9 @@ uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" version = "0.8.2" [[FixedPointNumbers]] -git-tree-sha1 = "bd1386f890e172ef38e1c735cda58cbf004a7c9a" +git-tree-sha1 = "d14a6fa5890ea3a7e5dcab6811114f132fec2b4b" uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.7.0" +version = "0.6.1" [[Flux]] deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "CuArrays", "DelimitedFiles", "Juno", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "SHA", "Statistics", "StatsBase", "Test", "ZipFile", "Zygote"] @@ -186,9 +180,9 @@ version = "0.7.2" [[LLVM]] deps = ["CEnum", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "74fe444b8b6d1ac01d639b2f9eaf395bcc2e24fc" +git-tree-sha1 = "1d08d7e4250f452f6cb20e4574daaebfdbee0ff7" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "1.3.2" +version = "1.3.3" [[LibGit2]] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" @@ -235,16 +229,22 @@ version = "0.4.3" uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[NNlib]] -deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"] -git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8" +deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"] +git-tree-sha1 = "135c0de4794d5e214b06f1fb4787af4a72896e61" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.6.0" +version = "0.6.2" [[NaNMath]] git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" version = "0.3.3" +[[OpenSpecFun_jll]] +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "65f672edebf3f4e613ddf37db9dcbd7a407e5e90" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.3+1" + [[OrderedCollections]] deps = ["Random", "Serialization", "Test"] git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1" @@ -278,10 +278,10 @@ uuid = "189a3867-3050-52da-a836-e630ba90ab69" version = "0.2.0" [[Requires]] -deps = ["Test"] -git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1" +deps = ["UUIDs"] +git-tree-sha1 = "999513b7dea8ac17359ed50ae8ea089e4464e35e" uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "0.5.2" +version = "1.0.0" [[SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -303,10 +303,10 @@ deps = ["LinearAlgebra", "Random"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[SpecialFunctions]] -deps = ["BinDeps", "BinaryProvider", "Libdl"] -git-tree-sha1 = "3bdd374b6fd78faf0119b8c5d538788dbf910c6e" +deps = ["OpenSpecFun_jll"] +git-tree-sha1 = "268052ee908b2c086cc0011f528694f02f3e2408" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "0.8.0" +version = "0.9.0" [[StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] @@ -340,12 +340,6 @@ git-tree-sha1 = "7c53c35547de1c5b9d46a4797cf6d8253807108c" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" version = "0.9.5" -[[URIParser]] -deps = ["Test", "Unicode"] -git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69" -uuid = "30578b45-9adc-5946-b283-645ec420af67" -version = "0.4.0" - [[UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" @@ -361,9 +355,9 @@ version = "0.8.3" [[Zygote]] deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "e4245b9c5362346e154b62842a89a18e0210b92b" +git-tree-sha1 = "e353adc2b1026114c7ba2a112f8317dd054644c2" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.4.1" +version = "0.4.3" [[ZygoteRules]] deps = ["MacroTools"] diff --git a/Project.toml b/Project.toml index 8ef652a..c5205fc 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Niklas Heim "] version = "0.1.0" [deps] +CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/src/cmean_gaussian.jl b/src/cmean_gaussian.jl index 9de3ed5..5eec96c 100644 --- a/src/cmean_gaussian.jl +++ b/src/cmean_gaussian.jl @@ -29,18 +29,21 @@ julia> rand(p, ones(2)) 0.0767166501426535 ``` """ -struct CMeanGaussian{V<:AbstractVar,S<:AbstractArray,M} <: AbstractCGaussian +struct CMeanGaussian{V<:AbstractVar,M,S<:AbstractArray} <: AbstractCGaussian mapping::M σ::S xlength::Int _nograd::Dict{Symbol,Bool} end +CMeanGaussian{V}(m::M, σ::S, xlength::Int, d::Dict{Symbol,Bool}) where {V,M,S} = + CMeanGaussian{V,M,S}(m,σ,xlength,d) + function CMeanGaussian{V}(m::M, σ, xlength::Int) where {V,M} _nograd = Dict(:σ => σ isa NoGradArray) σ = _nograd[:σ] ? σ.data : σ S = typeof(σ) - CMeanGaussian{V,S,M}(m, σ, xlength, _nograd) + CMeanGaussian{V,M,S}(m, σ, xlength, _nograd) end CMeanGaussian{DiagVar}(m, σ) = CMeanGaussian{DiagVar}(m, σ, size(σ,1)) @@ -62,10 +65,10 @@ end mean_var(p::CMeanGaussian, z::AbstractArray) = (mean(p, z), variance(p, z)) # make sure that parameteric constructor is called... -function Flux.functor(p::CMeanGaussian{V,S,M}) where {V,S,M} +function Flux.functor(p::CMeanGaussian{V,M,S}) where {V,M,S} fs = fieldnames(typeof(p)) nt = (; (name=>getfield(p, name) for name in fs)...) - nt, y -> CMeanGaussian{V,S,M}(y...) + nt, y -> CMeanGaussian{V}(y...) end function Flux.trainable(p::CMeanGaussian) diff --git a/src/cmeanvar_gaussian.jl b/src/cmeanvar_gaussian.jl index cb794b0..24783d1 100644 --- a/src/cmeanvar_gaussian.jl +++ b/src/cmeanvar_gaussian.jl @@ -54,10 +54,10 @@ function mean_var(p::CMeanVarGaussian{ScalarVar}, z::AbstractArray) end # make sure that parameteric constructor is called... -function Flux.functor(p::CMeanVarGaussian{V,M}) where {V,M} +function Flux.functor(p::CMeanVarGaussian{V}) where V fs = fieldnames(typeof(p)) nt = (; (name=>getfield(p, name) for name in fs)...) - nt, y -> CMeanVarGaussian{V,M}(y...) + nt, y -> CMeanVarGaussian{V}(y...) end function Base.show(io::IO, p::CMeanVarGaussian{V}) where V diff --git a/test/runtests.jl b/test/runtests.jl index fe38ab5..d0788a8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,7 +10,7 @@ if Flux.use_cuda[] using CuArrays end include("abstract_pdf.jl") include("gaussian.jl") -include("cmean_gaussian.jl") +include("nogradarray.jl") include("cmeanvar_gaussian.jl") +include("cmean_gaussian.jl") include("constspec_gaussian.jl") -include("nogradarray.jl")