diff --git a/Project.toml b/Project.toml index 6255f3c..a1f0b13 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ConditionalDists" uuid = "c648c4dd-c1e0-49a6-84b9-144ae7fd2468" authors = ["Niklas Heim "] -version = "0.4.4" +version = "0.4.5" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/ConditionalDists.jl b/src/ConditionalDists.jl index 38cfc71..6d130d9 100644 --- a/src/ConditionalDists.jl +++ b/src/ConditionalDists.jl @@ -12,6 +12,7 @@ export condition export ConditionalDistribution export ConditionalMvNormal +export SplitLayer include("cond_dist.jl") @@ -25,6 +26,11 @@ function __init__() function SplitLayer(in::Int, outs::Vector{Int}, acts::Vector) SplitLayer(Tuple(Dense(in,o,a) for (o,a) in zip(outs,acts))) end + + function SplitLayer(in::Int, outs::Vector{Int}, act=identity) + acts = [act for _ in 1:length(outs)] + SplitLayer(in, outs, acts) + end end end diff --git a/src/utils.jl b/src/utils.jl index f1332b2..a29a032 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,11 +1,8 @@ -struct SplitLayer - layers::Tuple +struct SplitLayer{T<:Tuple} + layers::T end -function SplitLayer(in::Int, outs::Vector{Int}, act=identity) - acts = [act for _ in 1:length(outs)] - SplitLayer(in, outs, acts) -end +SplitLayer(xs...) = SplitLayer(xs) function (m::SplitLayer)(x) Tuple(layer(x) for layer in m.layers) diff --git a/test/runtests.jl b/test/runtests.jl index d1d2dea..dafb0e7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,6 +11,7 @@ using ConditionalDists: BatchMvNormal, SplitLayer include("cond_dist.jl") include("cond_mvnormal.jl") +include("utils.jl") # for the BatchMvNormal tests to work BatchMvNormals have to be functors! include("batch_mvnormal.jl") diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 0000000..7006db3 --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,12 @@ +@testset "SplitLayer" begin + l = SplitLayer(x->x .+ 1, _->1) + x = rand(3) + (a,b) = l(x) + @test all(a .≈ x .+ 1) + @test b == 1 + + l = SplitLayer(3,[2,4]) + (a,b) = l(x) + @test size(a) == (2,) + @test size(b) == (4,) +end