From 5d0c37bc119b16140e626811a92116b3e3260cf6 Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Sun, 1 Nov 2020 21:37:54 +0100 Subject: [PATCH] Shared variance via SplitLayer (#43) implement learned/shared/fixed/unit variance and properly test all cases --- Project.toml | 2 +- README.md | 16 ++-- src/cond_mvnormal.jl | 58 ++++++++++--- src/utils.jl | 63 ++++++++++++++- test/cond_mvnormal.jl | 183 ++++++++++++++++++------------------------ test/utils.jl | 21 +++++ 6 files changed, 219 insertions(+), 124 deletions(-) diff --git a/Project.toml b/Project.toml index a1f0b13..080daa1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ConditionalDists" uuid = "c648c4dd-c1e0-49a6-84b9-144ae7fd2468" authors = ["Niklas Heim "] -version = "0.4.5" +version = "0.4.6" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/README.md b/README.md index e80fd67..ee0805c 100644 --- a/README.md +++ b/README.md @@ -3,17 +3,16 @@ # ConditionalDists.jl -Conditional probability distributions powered by Flux.jl and Distributions.jl. +Conditional probability distributions powered by Flux.jl and DistributionsAD.jl. The conditional PDFs that are defined in this package can be used in conjunction with Flux models to provide trainable mappings. As an example, assume you want to learn the mapping from a conditional to an MvNormal. The mapping `m` takes a vector `x` and maps it to a mean `μ` and a variance `σ`, which can be achieved by using a `ConditionalDists.SplitLayer` as the last -layer in your network like the one below: The `SplitLayer` is constructed from -`N` `Dense` layers (with same input size) and outputs `N` vectors: +layer in the network. ```julia -julia> m = SplitLayer(2, [3,4]) +julia> m = Chain(Dense(2,2,σ), SplitLayer(2, [3,4])) julia> m(rand(2)) (Float32[0.07946974, 0.13797458, 0.03939067], Float32[0.7006321, 0.37641272, 0.3586885, 0.82230335]) ``` @@ -40,8 +39,13 @@ julia> z = rand(zlength, batchsize) julia> logpdf(p,x,z) julia> rand(p, randn(zlength, 10)) ``` -The trainable parameters (of the `SplitLayer`) are accessible as usual -through `Flux.params`. The next few lines show how to optimize `p` to match a +The trainable parameters (of the `SplitLayer`) are accessible as usual through +`Flux.params`. For different variance configurations (i.e. fixed/unit variance, +etc) check the doc strings with `julia>? ConditionalMvNormal`/`julia>? +SplitLayer`. + + +The next few lines show how to optimize `p` to match a given Gaussian by using the `kl_divergence` defined in [IPMeasures.jl](https://github.com/aicenter/IPMeasures.jl). diff --git a/src/cond_mvnormal.jl b/src/cond_mvnormal.jl index 6ec086e..f1b15f4 100644 --- a/src/cond_mvnormal.jl +++ b/src/cond_mvnormal.jl @@ -1,18 +1,13 @@ """ ConditionalMvNormal(m) -Specialization of ConditionalDistribution for `MvNormal`s for performance. -Does the same as ConditionalDistribution(MvNormal,m) for vector inputs (to e.g. -mean/logpdf). For batches of inputs a `BatchMvNormal` is constructed that does +Specialization of ConditionalDistribution for `MvNormal`s (for performance with +batches of inputs). Does the same as ConditionalDistribution(MvNormal,m) +but for batches of inputs a `BatchMvNormal` is constructed that does not just map over the batch but uses faster matrix multiplications. -The mapping `m` must return either a `Tuple` with mean and variance, or just a -mean vector. If the output of `m` is just a vector, the variance is assumed to -be a fixed unit variance. - -# Examples ```julia-repl -julia> m = ConditionalDists.SplitLayer(100,[100,100]) +julia> m = SplitLayer(100,[100,100]) julia> p = ConditionalMvNormal(m) julia> @time rand(p, rand(100,10000); julia> @time rand(p, rand(100,10000); @@ -26,6 +21,51 @@ julia> @time rand(p, rand(100,10000); 3.626042 seconds (159.97 k allocations: 18.681 GiB, 34.92% gc time) ``` +The mapping `m` must return a `Tuple` with mean and variance. +For a convenient way of doing this you can use a `SplitLayer`. + + +# Examples + +`ConditionalMvNormal` and `SplitLayer` together support 3 different variance +configurations: fixed/unit variance, shared variance, and trained variance. The +three different configurations are explained below. + +## Fixed/unit variance + +Pass a function to the `SplitLayer` that returns the fixed variance with +appropriate batch dimensions +```julia-repl +julia> σ(x::Vector) = 2 +julia> σ(x::Matrix) = ones(Float32,size(x,2)) .* 2 +julia> m = SplitLayer(Dense(2,3), σ) +julia> p = ConditionalMvNormal(m) +julia> condition(p,rand(Float32,2)) isa DistributionsAD.TuringScalMvNormal +``` +Passing a mapping with a single output array assumes unit variance. + +## Shared variance + +For a learned variance that is the same across the the whole batch, simply pass +a vector (or scalar) to the `SplitLayer`. The `SplitLayer` wraps vectors/scalars +into a `TrainableVector`s/`TrainableScalar`s. +```julia-repl +julia> m = SplitLayer(Dense(2,3), ones(Float32,3)) +julia> p = ConditionalMvNormal(m) +julia> condition(p,rand(Float32,2)) isa DistributionsAD.TuringDiagMvNormal +``` + +## Trained variance + +Simply pass another trainable mapping for the variance. By just supplying input +sizes to `SplitLayer` you can automatically create `Dense` layers with given +activation functions. In this example the second activation function makes sure +that the variance is always positive +```julia-repl +julia> m = SplitLayer(2,[3,1],[identity,abs]) +julia> p = ConditionalMvNormal(m) +julia> condition(p,rand(Float32,2)) isa DistributionsAD.TuringScalMvNormal +``` """ struct ConditionalMvNormal{Tm} <: AbstractConditionalDistribution mapping::Tm diff --git a/src/utils.jl b/src/utils.jl index a29a032..0d946c3 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,8 +1,45 @@ +""" + SplitLayer(xs...) + +A layer that calls a number of sublayers/mappings with the same input and +returns a tuple of their outputs. Can be used in a regular Flux model: + +```julia-repl +julia> m = Chain(Dense(2,3), SplitLayer(Dense(3,2), x->x .* 2)) +julia> length(params(m)) == 4 +julia> (x,y) = m(rand(2)) +(Float32[-1.0541434, 1.1694773], Float32[-3.1472511, -0.86115724, -0.39665926]) +``` +Comes with a convenient constructor for a SplitLayer built from Dense layers +with given activation(s): +```julia-repl +julia> m = Chain(Dense(2,3), SplitLayer(3, [2,5], σ)) +julia> (x,y) = m(rand(2)) +(Float32[0.3069554, 0.3362006], Float32[0.437131, 0.4982477, 0.6465078, 0.4523438, 0.5068563]) +``` + +You can also provide just a vector / scalar that should be trained but have the +same value for all inputs (like a lonely bias vector). This functionality is +provided by the `TrainableVector`/`TrainableScalar` types. For vector inputs +they simply return the array they are wrapping. For matrix (i.e. batch) inputs +they return appropriately repeated arrays: +```julia-repl +julia> m = SplitLayer(Dense(2,3), ones(Float32,3)) +julia> length(params(m)) == 3 +julia> (x,y) = m(rand(2,5)) +julia> size(y) == (3,5) +julia> y +3×3 Array{Float32,2}: + 1.0 1.0 1.0 + 1.0 1.0 1.0 + 1.0 1.0 1.0 +``` +""" struct SplitLayer{T<:Tuple} layers::T end -SplitLayer(xs...) = SplitLayer(xs) +SplitLayer(layers...) = SplitLayer(map(maybe_trainable, layers)) function (m::SplitLayer)(x) Tuple(layer(x) for layer in m.layers) @@ -10,6 +47,30 @@ end @functor SplitLayer + +# for use as e.g. shared variance +struct TrainableVector{T<:AbstractArray} + v::T +end +(v::TrainableVector)(x::AbstractVector) = v.v +(v::TrainableVector)(x::AbstractMatrix) = v.v .* reshape(fillsimilar(v.v,size(x,ndims(x)),1),1,:) +(v::TrainableVector)() = v.v +@functor TrainableVector + +# for use as e.g. shared variance +struct TrainableScalar{T<:Real} + s::AbstractVector{T} + TrainableScalar{T}(x::T) where T<:Real = new{T}([x]) +end +TrainableScalar(x::T) where T<:Real = TrainableScalar{T}(x) +(s::TrainableScalar)(x::AbstractVector) = s.s[1] +(s::TrainableScalar)(x::AbstractMatrix) = fillsimilar(x,size(x,ndims(x)),1) .* s.s +@functor TrainableScalar + +maybe_trainable(x) = x +maybe_trainable(x::AbstractArray) = TrainableVector(x) +maybe_trainable(x::Real) = TrainableScalar(x) + fillsimilar(x::AbstractArray, s::Tuple, value::Real) = fill!(similar(x, s...), value) fillsimilar(x::AbstractArray, s, value::Real) = fill!(similar(x, s), value) @non_differentiable fillsimilar(::Any, ::Any, ::Any) diff --git a/test/cond_mvnormal.jl b/test/cond_mvnormal.jl index 2bbab3f..42b4c42 100644 --- a/test/cond_mvnormal.jl +++ b/test/cond_mvnormal.jl @@ -3,115 +3,84 @@ xlength = 3 zlength = 2 batchsize = 10 - m = SplitLayer(zlength, [xlength,xlength], [identity,abs]) - p = ConditionalMvNormal(m) |> gpu - # MvNormal - res = condition(p, rand(zlength) |> gpu) - μ = mean(res) - σ2 = var(res) - @test res isa TuringDiagMvNormal - @test size(μ) == (xlength,) - @test size(σ2) == (xlength,) + σvector(x::AbstractVector) = ones(Float32,xlength) .* 3 + σvector(x::AbstractMatrix) = ones(Float32,xlength,size(x,2)) .* 3 + σscalar(x::AbstractVector) = 2 + σscalar(x::AbstractMatrix) = ones(Float32,size(x,2)) .* 2 x = rand(Float32, xlength) |> gpu z = rand(Float32, zlength) |> gpu - loss() = logpdf(p,x,z) - ps = Flux.params(p) - @test_broken loss() isa Float32 - @test_nowarn Flux.gradient(loss, ps) - - f() = sum(rand(p,z)) - @test_broken Flux.gradient(f, ps) - - # BatchDiagMvNormal - res = condition(p, rand(zlength,batchsize)|>gpu) - μ = mean(res) - σ2 = var(res) - @test res isa ConditionalDists.BatchDiagMvNormal - @test size(μ) == (xlength,batchsize) - @test size(σ2) == (xlength,batchsize) - - x = rand(Float32, xlength, batchsize) |> gpu - z = rand(Float32, zlength, batchsize) |> gpu - loss() = sum(logpdf(p,x,z)) - ps = Flux.params(p) - @test length(ps) == 4 - @test loss() isa Float32 - @test_nowarn gs = Flux.gradient(loss, ps) - - f() = sum(rand(p,z)) - @test_nowarn Flux.gradient(f, ps) - - - # BatchScalMvNormal - m = SplitLayer(zlength, [xlength,1]) - p = ConditionalMvNormal(m) |> gpu - - res = condition(p, rand(zlength,batchsize)|>gpu) - μ = mean(res) - σ2 = var(res) - @test res isa ConditionalDists.BatchScalMvNormal - @test size(μ) == (xlength,batchsize) - @test size(σ2) == (xlength,batchsize) - - x = rand(Float32, xlength, batchsize) |> gpu - z = rand(Float32, zlength, batchsize) |> gpu - loss() = sum(logpdf(p,x,z)) - ps = Flux.params(p) - @test length(ps) == 4 - @test loss() isa Float32 - @test_nowarn gs = Flux.gradient(loss, ps) - - f() = sum(rand(p,z)) - @test_nowarn Flux.gradient(f, ps) - - - # Unit variance - m = Dense(zlength,xlength) - p = ConditionalMvNormal(m) |> gpu - - res = condition(p, rand(zlength,batchsize)|>gpu) - μ = mean(res) - σ2 = var(res) - @test res isa ConditionalDists.BatchScalMvNormal - @test size(μ) == (xlength,batchsize) - @test size(σ2) == (xlength,batchsize) - - x = rand(Float32, xlength, batchsize) |> gpu - z = rand(Float32, zlength, batchsize) |> gpu - loss() = sum(logpdf(p,x,z)) - ps = Flux.params(p) - @test length(ps) == 2 - @test loss() isa Float32 - @test_nowarn gs = Flux.gradient(loss, ps) - - f() = sum(rand(p,z)) - @test_nowarn Flux.gradient(f, ps) - - - # Fixed scalar variance - m = Dense(zlength,xlength) - σ(x::AbstractVector) = 2 - σ(x::AbstractMatrix) = ones(Float32,size(x,2)) .* 2 - p = ConditionalMvNormal(SplitLayer(m,σ)) |> gpu - - res = condition(p, rand(zlength,batchsize)|>gpu) - μ = mean(res) - σ2 = var(res) - @test res isa ConditionalDists.BatchScalMvNormal - @test size(μ) == (xlength,batchsize) - @test size(σ2) == (xlength,batchsize) - - x = rand(Float32, xlength, batchsize) |> gpu - z = rand(Float32, zlength, batchsize) |> gpu - loss() = sum(logpdf(p,x,z)) - ps = Flux.params(p) - @test length(ps) == 2 - @test loss() isa Float32 - @test_nowarn gs = Flux.gradient(loss, ps) - - f() = sum(rand(p,z)) - @test_nowarn Flux.gradient(f, ps) - + X = rand(Float32, xlength, batchsize) |> gpu + Z = rand(Float32, zlength, batchsize) |> gpu + + cases = [ + ("vector μ / vector σ", + SplitLayer(zlength, [xlength,xlength], [identity,abs]), Vector, 4), + ("vector μ / scalar σ", + SplitLayer(zlength, [xlength,1], [identity,abs]), Real, 4), + ("vector μ / fixed vector σ", + SplitLayer(Dense(zlength,xlength), σvector), Vector, 2), + ("vector μ / fixed scalar σ", + SplitLayer(Dense(zlength,xlength), σscalar), Real, 2), + ("vector μ / unit σ", + Dense(zlength,xlength), Real, 2), + ("vector μ / shared, trainable vector σ", + SplitLayer(Dense(zlength,xlength), ones(Float32,xlength)), Vector, 3), + ("vector μ / shared, trainable scalar σ", + SplitLayer(Dense(zlength,xlength), 1f0), Real, 3) + ] + + disttypes(::Type{<:Vector}) = (TuringDiagMvNormal,ConditionalDists.BatchDiagMvNormal) + disttypes(::Type{<:Real}) = (TuringScalMvNormal,ConditionalDists.BatchScalMvNormal) + σsize(::Type{<:Vector}) = (xlength,) + σsize(::Type{<:Real}) = () + + + for (name,mapping,T,nrps) in cases + @testset "$name" begin + p = ConditionalMvNormal(mapping) |> gpu + (Texample,Tbatch) = disttypes(T) + + res = condition(p,z) + μ = mean(res) + σ2 = var(res) + @test res isa Texample + @test size(μ) == (xlength,) + @test size(σ2) == σsize(T) + + loss() = logpdf(p,x,z) + ps = Flux.params(p) + @test length(ps) == nrps + @test loss() isa Float32 + @test_nowarn Flux.gradient(loss, ps) + + f() = sum(rand(p,z)) + gs = Flux.gradient(f,ps) + for p in ps + g = gs[p] + @test all(isfinite.(g)) && all(g .!= 0) + end + + + # batch tests + res = condition(p,Z) + μ = mean(res) + σ2 = var(res) + @test res isa Tbatch + @test size(μ) == (xlength,batchsize) + @test size(σ2) == (xlength,batchsize) + + loss() = sum(logpdf(p,X,Z)) + @test loss() isa Float32 + @test_nowarn Flux.gradient(loss, ps) + + f() = sum(rand(p,Z)) + gs = Flux.gradient(f,ps) + for p in ps + g = gs[p] + @test all(isfinite.(g)) && all(g .!= 0) + end + end + end end diff --git a/test/utils.jl b/test/utils.jl index 7006db3..562cfe4 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,4 +1,5 @@ @testset "SplitLayer" begin + # constant variance l = SplitLayer(x->x .+ 1, _->1) x = rand(3) (a,b) = l(x) @@ -9,4 +10,24 @@ (a,b) = l(x) @test size(a) == (2,) @test size(b) == (4,) + + # shared but learned variance (vector) + l = SplitLayer(x->x, ones(4)) + (a,b) = l(x) + @test size(a) == (3,) + @test size(b) == (4,) + + (a,b) = l(rand(3,10)) + @test size(a) == (3,10) + @test size(b) == (4,10) + + # shared but learned variance (scalar) + l = SplitLayer(x->x,1) + (a,b) = l(x) + @test size(a) == (3,) + @test size(b) == () + + (a,b) = l(rand(3,10)) + @test size(a) == (3,10) + @test size(b) == (10,) end