Shared variance via SplitLayer (#43)
implement learned/shared/fixed/unit variance and properly test all cases
nmheim authored Nov 1, 2020
1 parent e0cd6bc commit 5d0c37b
Showing 6 changed files with 219 additions and 124 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "ConditionalDists"
uuid = "c648c4dd-c1e0-49a6-84b9-144ae7fd2468"
authors = ["Niklas Heim <[email protected]>"]
version = "0.4.5"
version = "0.4.6"

ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16 changes: 10 additions & 6 deletions
@@ -3,17 +3,16 @@

# ConditionalDists.jl

Conditional probability distributions powered by Flux.jl and Distributions.jl.
Conditional probability distributions powered by Flux.jl and DistributionsAD.jl.

The conditional PDFs that are defined in this package can be used in
conjunction with Flux models to provide trainable mappings. As an example,
assume you want to learn the mapping from a conditional to an MvNormal. The
mapping `m` takes a vector `x` and maps it to a mean `μ` and a variance `σ`,
which can be achieved by using a `ConditionalDists.SplitLayer` as the last
layer in your network like the one below: The `SplitLayer` is constructed from
`N` `Dense` layers (with same input size) and outputs `N` vectors:
layer in the network.
julia> m = SplitLayer(2, [3,4])
julia> m = Chain(Dense(2,2,σ), SplitLayer(2, [3,4]))
julia> m(rand(2))
(Float32[0.07946974, 0.13797458, 0.03939067], Float32[0.7006321, 0.37641272, 0.3586885, 0.82230335])
@@ -40,8 +39,13 @@ julia> z = rand(zlength, batchsize)
julia> logpdf(p,x,z)
julia> rand(p, randn(zlength, 10))
The trainable parameters (of the `SplitLayer`) are accessible as usual
through `Flux.params`. The next few lines show how to optimize `p` to match a
The trainable parameters (of the `SplitLayer`) are accessible as usual through
`Flux.params`. For different variance configurations (i.e. fixed/unit variance,
etc) check the doc strings with `julia>? ConditionalMvNormal`/`julia>?

The next few lines show how to optimize `p` to match a
given Gaussian by using the `kl_divergence` defined in

58 changes: 49 additions & 9 deletions src/cond_mvnormal.jl
@@ -1,18 +1,13 @@
Specialization of ConditionalDistribution for `MvNormal`s for performance.
Does the same as ConditionalDistribution(MvNormal,m) for vector inputs (to e.g.
mean/logpdf). For batches of inputs a `BatchMvNormal` is constructed that does
Specialization of ConditionalDistribution for `MvNormal`s (for performance with
batches of inputs). Does the same as ConditionalDistribution(MvNormal,m)
but for batches of inputs a `BatchMvNormal` is constructed that does
not just map over the batch but uses faster matrix multiplications.
The mapping `m` must return either a `Tuple` with mean and variance, or just a
mean vector. If the output of `m` is just a vector, the variance is assumed to
be a fixed unit variance.
# Examples
julia> m = ConditionalDists.SplitLayer(100,[100,100])
julia> m = SplitLayer(100,[100,100])
julia> p = ConditionalMvNormal(m)
julia> @time rand(p, rand(100,10000));
julia> @time rand(p, rand(100,10000));
@@ -26,6 +21,51 @@ julia> @time rand(p, rand(100,10000);
3.626042 seconds (159.97 k allocations: 18.681 GiB, 34.92% gc time)
The mapping `m` must return a `Tuple` with mean and variance.
For a convenient way of doing this you can use a `SplitLayer`.
# Examples
`ConditionalMvNormal` and `SplitLayer` together support 3 different variance
configurations: fixed/unit variance, shared variance, and trained variance. The
three different configurations are explained below.
## Fixed/unit variance
Pass a function to the `SplitLayer` that returns the fixed variance with
appropriate batch dimensions:
julia> σ(x::Vector) = 2
julia> σ(x::Matrix) = ones(Float32,size(x,2)) .* 2
julia> m = SplitLayer(Dense(2,3), σ)
julia> p = ConditionalMvNormal(m)
julia> condition(p,rand(Float32,2)) isa DistributionsAD.TuringScalMvNormal
Passing a mapping with a single output array assumes unit variance.
## Shared variance
For a learned variance that is the same across the whole batch, simply pass
a vector (or scalar) to the `SplitLayer`. The `SplitLayer` wraps vectors/scalars
into `TrainableVector`s/`TrainableScalar`s.
julia> m = SplitLayer(Dense(2,3), ones(Float32,3))
julia> p = ConditionalMvNormal(m)
julia> condition(p,rand(Float32,2)) isa DistributionsAD.TuringDiagMvNormal
## Trained variance
Simply pass another trainable mapping for the variance. By just supplying input
sizes to `SplitLayer` you can automatically create `Dense` layers with given
activation functions. In this example the second activation function (`abs`)
makes sure that the variance is never negative:
julia> m = SplitLayer(2,[3,1],[identity,abs])
julia> p = ConditionalMvNormal(m)
julia> condition(p,rand(Float32,2)) isa DistributionsAD.TuringScalMvNormal
struct ConditionalMvNormal{Tm} <: AbstractConditionalDistribution
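
The docstring above says that batches are handled with matrix operations rather than by mapping over columns. The following is only a minimal sketch of that idea in plain Julia, not the package's `BatchMvNormal` implementation. The function name `logpdf_diag` is hypothetical, and the sketch assumes that the second output of the mapping is a variance (not a standard deviation).

```julia
# Hypothetical helper for illustration; not part of ConditionalDists.jl.
# Log-density of N(μ, diag(σ2)) for a single example.
function logpdf_diag(x::AbstractVector, μ::AbstractVector, σ2::AbstractVector)
    d = length(x)
    -(d * log(2π) + sum(log, σ2) + sum(abs2.(x .- μ) ./ σ2)) / 2
end

# Batched version: columns are examples, and all columns are reduced at once
# with array operations instead of a loop over the batch.
function logpdf_diag(X::AbstractMatrix, M::AbstractMatrix, S2::AbstractMatrix)
    d = size(X, 1)
    quad = sum(abs2.(X .- M) ./ S2, dims=1)   # 1 × batchsize
    logdet = sum(log.(S2), dims=1)            # 1 × batchsize
    vec(-(d * log(2π) .+ logdet .+ quad) ./ 2)
end

X  = [0.0 1.0; 0.0 1.0]
M  = zeros(2, 2)
S2 = ones(2, 2)
batch = logpdf_diag(X, M, S2)
# Per-column evaluation, used here only to check the batched result.
single = [logpdf_diag(X[:, i], M[:, i], S2[:, i]) for i in 1:size(X, 2)]
```

Both paths should agree up to floating point error; the batched path avoids constructing one distribution object per column.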
63 changes: 62 additions & 1 deletion src/utils.jl
@@ -1,15 +1,76 @@
A layer that calls a number of sublayers/mappings with the same input and
returns a tuple of their outputs. Can be used in a regular Flux model:
julia> m = Chain(Dense(2,3), SplitLayer(Dense(3,2), x->x .* 2))
julia> length(params(m)) == 4
julia> (x,y) = m(rand(2))
(Float32[-1.0541434, 1.1694773], Float32[-3.1472511, -0.86115724, -0.39665926])
Comes with a convenient constructor for a SplitLayer built from Dense layers
with given activation(s):
julia> m = Chain(Dense(2,3), SplitLayer(3, [2,5], σ))
julia> (x,y) = m(rand(2))
(Float32[0.3069554, 0.3362006], Float32[0.437131, 0.4982477, 0.6465078, 0.4523438, 0.5068563])
You can also provide just a vector / scalar that should be trained but have the
same value for all inputs (like a lonely bias vector). This functionality is
provided by the `TrainableVector`/`TrainableScalar` types. For vector inputs
they simply return the array they are wrapping. For matrix (i.e. batch) inputs
they return appropriately repeated arrays:
julia> m = SplitLayer(Dense(2,3), ones(Float32,3))
julia> length(params(m)) == 3
julia> (x,y) = m(rand(2,5))
julia> size(y) == (3,5)
julia> y
3×5 Array{Float32,2}:
 1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0
struct SplitLayer{T<:Tuple}

SplitLayer(xs...) = SplitLayer(xs)
SplitLayer(layers...) = SplitLayer(map(maybe_trainable, layers))

function (m::SplitLayer)(x)
    Tuple(layer(x) for layer in m.layers)
end

@functor SplitLayer

# for use as e.g. shared variance
struct TrainableVector{T<:AbstractArray}
(v::TrainableVector)(x::AbstractVector) = v.v
(v::TrainableVector)(x::AbstractMatrix) = v.v .* reshape(fillsimilar(v.v,size(x,ndims(x)),1),1,:)
(v::TrainableVector)() = v.v
@functor TrainableVector

# for use as e.g. shared variance
struct TrainableScalar{T<:Real}
TrainableScalar{T}(x::T) where T<:Real = new{T}([x])
TrainableScalar(x::T) where T<:Real = TrainableScalar{T}(x)
(s::TrainableScalar)(x::AbstractVector) = s.s[1]
(s::TrainableScalar)(x::AbstractMatrix) = fillsimilar(x,size(x,ndims(x)),1) .* s.s
@functor TrainableScalar

maybe_trainable(x) = x
maybe_trainable(x::AbstractArray) = TrainableVector(x)
maybe_trainable(x::Real) = TrainableScalar(x)

fillsimilar(x::AbstractArray, s::Tuple, value::Real) = fill!(similar(x, s...), value)
fillsimilar(x::AbstractArray, s, value::Real) = fill!(similar(x, s), value)
@non_differentiable fillsimilar(::Any, ::Any, ::Any)
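
To illustrate the shared-variance behaviour described in the docstring (a vector input returns the wrapped value, a batch input returns one copy per column), here is a small sketch in plain Julia. `SharedVector` and `SharedScalar` are hypothetical stand-ins written for this example. They are not the package's `TrainableVector`/`TrainableScalar` and carry no Flux or `@functor` integration.

```julia
# Hypothetical stand-ins for illustration only.
struct SharedVector{T<:AbstractVector}
    v::T
end

# Single example: return the wrapped vector unchanged.
(s::SharedVector)(x::AbstractVector) = s.v
# Batch (features × batchsize): broadcast against a row of ones so the
# vector is repeated once per column.
(s::SharedVector)(x::AbstractMatrix) = s.v .* ones(eltype(s.v), 1, size(x, 2))

struct SharedScalar{T<:Real}
    s::T
end

# Single example: return the scalar itself.
(s::SharedScalar)(x::AbstractVector) = s.s
# Batch: one copy of the scalar per column.
(s::SharedScalar)(x::AbstractMatrix) = fill(s.s, size(x, 2))

sv = SharedVector(Float32[1, 2, 3])
single = sv(rand(Float32, 2))        # the wrapped 3-element vector
batch  = sv(rand(Float32, 2, 5))     # 3×5 matrix, same column repeated
scal   = SharedScalar(2f0)(rand(Float32, 2, 5))
```

The input `x` is only inspected for its shape, which is why the same parameter can serve as a variance for every element of a batch.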
183 changes: 76 additions & 107 deletions test/cond_mvnormal.jl
@@ -3,115 +3,84 @@
xlength = 3
zlength = 2
batchsize = 10
m = SplitLayer(zlength, [xlength,xlength], [identity,abs])
p = ConditionalMvNormal(m) |> gpu

# MvNormal
res = condition(p, rand(zlength) |> gpu)
μ = mean(res)
σ2 = var(res)
@test res isa TuringDiagMvNormal
@test size(μ) == (xlength,)
@test size(σ2) == (xlength,)
σvector(x::AbstractVector) = ones(Float32,xlength) .* 3
σvector(x::AbstractMatrix) = ones(Float32,xlength,size(x,2)) .* 3
σscalar(x::AbstractVector) = 2
σscalar(x::AbstractMatrix) = ones(Float32,size(x,2)) .* 2

x = rand(Float32, xlength) |> gpu
z = rand(Float32, zlength) |> gpu
loss() = logpdf(p,x,z)
ps = Flux.params(p)
@test_broken loss() isa Float32
@test_nowarn Flux.gradient(loss, ps)

f() = sum(rand(p,z))
@test_broken Flux.gradient(f, ps)

# BatchDiagMvNormal
res = condition(p, rand(zlength,batchsize)|>gpu)
μ = mean(res)
σ2 = var(res)
@test res isa ConditionalDists.BatchDiagMvNormal
@test size(μ) == (xlength,batchsize)
@test size(σ2) == (xlength,batchsize)

x = rand(Float32, xlength, batchsize) |> gpu
z = rand(Float32, zlength, batchsize) |> gpu
loss() = sum(logpdf(p,x,z))
ps = Flux.params(p)
@test length(ps) == 4
@test loss() isa Float32
@test_nowarn gs = Flux.gradient(loss, ps)

f() = sum(rand(p,z))
@test_nowarn Flux.gradient(f, ps)

# BatchScalMvNormal
m = SplitLayer(zlength, [xlength,1])
p = ConditionalMvNormal(m) |> gpu

res = condition(p, rand(zlength,batchsize)|>gpu)
μ = mean(res)
σ2 = var(res)
@test res isa ConditionalDists.BatchScalMvNormal
@test size(μ) == (xlength,batchsize)
@test size(σ2) == (xlength,batchsize)

x = rand(Float32, xlength, batchsize) |> gpu
z = rand(Float32, zlength, batchsize) |> gpu
loss() = sum(logpdf(p,x,z))
ps = Flux.params(p)
@test length(ps) == 4
@test loss() isa Float32
@test_nowarn gs = Flux.gradient(loss, ps)

f() = sum(rand(p,z))
@test_nowarn Flux.gradient(f, ps)

# Unit variance
m = Dense(zlength,xlength)
p = ConditionalMvNormal(m) |> gpu

res = condition(p, rand(zlength,batchsize)|>gpu)
μ = mean(res)
σ2 = var(res)
@test res isa ConditionalDists.BatchScalMvNormal
@test size(μ) == (xlength,batchsize)
@test size(σ2) == (xlength,batchsize)

x = rand(Float32, xlength, batchsize) |> gpu
z = rand(Float32, zlength, batchsize) |> gpu
loss() = sum(logpdf(p,x,z))
ps = Flux.params(p)
@test length(ps) == 2
@test loss() isa Float32
@test_nowarn gs = Flux.gradient(loss, ps)

f() = sum(rand(p,z))
@test_nowarn Flux.gradient(f, ps)

# Fixed scalar variance
m = Dense(zlength,xlength)
σ(x::AbstractVector) = 2
σ(x::AbstractMatrix) = ones(Float32,size(x,2)) .* 2
p = ConditionalMvNormal(SplitLayer(m,σ)) |> gpu

res = condition(p, rand(zlength,batchsize)|>gpu)
μ = mean(res)
σ2 = var(res)
@test res isa ConditionalDists.BatchScalMvNormal
@test size(μ) == (xlength,batchsize)
@test size(σ2) == (xlength,batchsize)

x = rand(Float32, xlength, batchsize) |> gpu
z = rand(Float32, zlength, batchsize) |> gpu
loss() = sum(logpdf(p,x,z))
ps = Flux.params(p)
@test length(ps) == 2
@test loss() isa Float32
@test_nowarn gs = Flux.gradient(loss, ps)

f() = sum(rand(p,z))
@test_nowarn Flux.gradient(f, ps)

X = rand(Float32, xlength, batchsize) |> gpu
Z = rand(Float32, zlength, batchsize) |> gpu

cases = [
    ("vector μ / vector σ",
     SplitLayer(zlength, [xlength,xlength], [identity,abs]), Vector, 4),
    ("vector μ / scalar σ",
     SplitLayer(zlength, [xlength,1], [identity,abs]), Real, 4),
    ("vector μ / fixed vector σ",
     SplitLayer(Dense(zlength,xlength), σvector), Vector, 2),
    ("vector μ / fixed scalar σ",
     SplitLayer(Dense(zlength,xlength), σscalar), Real, 2),
    ("vector μ / unit σ",
     Dense(zlength,xlength), Real, 2),
    ("vector μ / shared, trainable vector σ",
     SplitLayer(Dense(zlength,xlength), ones(Float32,xlength)), Vector, 3),
    ("vector μ / shared, trainable scalar σ",
     SplitLayer(Dense(zlength,xlength), 1f0), Real, 3)
]

disttypes(::Type{<:Vector}) = (TuringDiagMvNormal,ConditionalDists.BatchDiagMvNormal)
disttypes(::Type{<:Real}) = (TuringScalMvNormal,ConditionalDists.BatchScalMvNormal)
σsize(::Type{<:Vector}) = (xlength,)
σsize(::Type{<:Real}) = ()

for (name,mapping,T,nrps) in cases
    @testset "$name" begin
        p = ConditionalMvNormal(mapping) |> gpu
        (Texample,Tbatch) = disttypes(T)

        # single-example tests
        res = condition(p,z)
        μ = mean(res)
        σ2 = var(res)
        @test res isa Texample
        @test size(μ) == (xlength,)
        @test size(σ2) == σsize(T)

        loss() = logpdf(p,x,z)
        ps = Flux.params(p)
        @test length(ps) == nrps
        @test loss() isa Float32
        @test_nowarn Flux.gradient(loss, ps)

        f() = sum(rand(p,z))
        gs = Flux.gradient(f,ps)
        for p in ps
            g = gs[p]
            @test all(isfinite.(g)) && all(g .!= 0)
        end

        # batch tests
        res = condition(p,Z)
        μ = mean(res)
        σ2 = var(res)
        @test res isa Tbatch
        @test size(μ) == (xlength,batchsize)
        @test size(σ2) == (xlength,batchsize)

        loss() = sum(logpdf(p,X,Z))
        @test loss() isa Float32
        @test_nowarn Flux.gradient(loss, ps)

        f() = sum(rand(p,Z))
        gs = Flux.gradient(f,ps)
        for p in ps
            g = gs[p]
            @test all(isfinite.(g)) && all(g .!= 0)
        end
    end
end

2 comments on commit 5d0c37b


@nmheim nmheim commented on 5d0c37b Nov 1, 2020


Registration pull request created: JuliaRegistries/General/24013

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.6 -m "<description of version>" 5d0c37bc119b16140e626811a92116b3e3260cf6
git push origin v0.4.6
