Skip to content

Commit

Permalink
BatchMvNormal variance shape (#32)
Browse files Browse the repository at this point in the history
* let BatchMvNormal always return variance of shape (xlength,batchsize)
  • Loading branch information
nmheim authored Aug 28, 2020
1 parent ace6c63 commit d27b840
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ConditionalDists"
uuid = "c648c4dd-c1e0-49a6-84b9-144ae7fd2468"
authors = ["Niklas Heim <[email protected]>"]
version = "0.4.1"
version = "0.4.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
15 changes: 3 additions & 12 deletions src/batch_mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@ BatchMvNormal(μ::AbstractMatrix{T}, σ::AbstractMatrix{T}) where T<:Real = Batc
Base.eltype(d::BMN) = eltype(d.μ)
Distributions.params(d::BMN) = (d.μ, d.σ)
Distributions.mean(d::BMN) = d.μ
Distributions.var(d::BMN) = d.σ .^2

#Distributions.var(d::BatchScalMvNormal) = fill(similar(d.σ,size(d.μ,1)),1) .* reshape(d.σ .^2,1,:)
#Distributions.var(d::BatchScalMvNormal) = ones(eltype(d), size(d.μ,1), 1) .* reshape(d.σ .^2, 1, :)
Distributions.var(d::BatchDiagMvNormal) = d.σ .^2
Distributions.var(d::BatchScalMvNormal) = fillsimilar(d.σ,size(d.μ,1),1) .* reshape(d.σ .^2,1,:)

function Distributions.rand(d::BatchDiagMvNormal)
μ, σ = d.μ, d.σ
Expand All @@ -34,16 +32,9 @@ function Distributions.rand(d::BatchScalMvNormal)
μ .+ σ .* r
end

function Distributions.logpdf(d::BatchDiagMvNormal, x::AbstractMatrix{T}) where T<:Real
function Distributions.logpdf(d::BMN, x::AbstractMatrix{T}) where T<:Real
n = size(d.μ,1)
μ = mean(d)
σ2 = var(d)
-(vec(sum(((x - μ).^2) ./ σ2 .+ log.(σ2), dims=1)) .+ n*log(T(2π))) / 2
end

function Distributions.logpdf(d::BatchScalMvNormal, x::AbstractMatrix{T}) where T<:Real
n = size(d.μ,1)
μ = mean(d)
σ2 = reshape(var(d), 1, :)
-(vec(sum(((x - μ).^2) ./ σ2 .+ log.(σ2), dims=1)) .+ n*log(T(2π))) / 2
end
4 changes: 2 additions & 2 deletions test/cond_mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
σ2 = var(res)
@test res isa ConditionalDists.BatchScalMvNormal
@test size(μ) == (xlength,batchsize)
@test size(σ2) == (batchsize,)
@test size(σ2) == (xlength,batchsize)

x = rand(Float32, xlength, batchsize) |> gpu
z = rand(Float32, zlength, batchsize) |> gpu
Expand All @@ -78,7 +78,7 @@
σ2 = var(res)
@test res isa ConditionalDists.BatchScalMvNormal
@test size(μ) == (xlength,batchsize)
@test size(σ2) == (batchsize,)
@test size(σ2) == (xlength,batchsize)

x = rand(Float32, xlength, batchsize) |> gpu
z = rand(Float32, zlength, batchsize) |> gpu
Expand Down

2 comments on commit d27b840

@nmheim
Copy link
Member Author

@nmheim nmheim commented on d27b840 Aug 28, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/20447

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.2 -m "<description of version>" d27b840e3d457007c679aa979352d6f6f6e4ef38
git push origin v0.4.2

Please sign in to comment.