From 7d76305a2934f0759e63ea0e3e729956182686c0 Mon Sep 17 00:00:00 2001 From: Martin Rektoris Date: Mon, 14 Aug 2023 16:43:00 +0200 Subject: [PATCH 01/40] Add ChainRulesCore --- Manifest.toml | 2 +- Project.toml | 1 + src/SumProductSet.jl | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Manifest.toml b/Manifest.toml index 8e7154a..b8e96ad 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.0" manifest_format = "2.0" -project_hash = "ca1c2813f0cd02f0f6821aca836d8a9a16fa6a33" +project_hash = "d6bdd6b622f8be80c61877539abf1791b1adbc12" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra", "Test"] diff --git a/Project.toml b/Project.toml index 9030227..b5b5841 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["rektomar "] version = "0.0.0" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" HierarchicalUtils = "f9ccea15-0695-44b9-8113-df7c26ae4fa9" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/SumProductSet.jl b/src/SumProductSet.jl index 0540175..e9c5663 100644 --- a/src/SumProductSet.jl +++ b/src/SumProductSet.jl @@ -1,6 +1,7 @@ module SumProductSet using Flux +using ChainRulesCore using NNlib using StatsBase using HierarchicalUtils From 0552ba68f6afd74a13307d0092915041e8a852da Mon Sep 17 00:00:00 2001 From: Martin Rektoris Date: Mon, 14 Aug 2023 16:44:40 +0200 Subject: [PATCH 02/40] Improve Geo logpdf and add rrule --- src/distributions/geometric.jl | 60 ++++++++++++++++++++++++++++++---- 1 file changed, 54 insertions(+), 6 deletions(-) diff --git a/src/distributions/geometric.jl b/src/distributions/geometric.jl index 8c27f16..5f87f64 100644 --- a/src/distributions/geometric.jl +++ b/src/distributions/geometric.jl @@ -28,7 +28,7 @@ julia> logpdf(m, x) """ -mutable struct Geometric{T} <: Distribution +struct Geometric{T} <: Distribution logitp::Array{T, 1} end @@ -40,12 +40,60 @@ Geometric(n::Int; 
dtype::Type{<:Real}=Float32) = Geometric(dtype(0.01)*randn(dty # Functions for calculating the likelihood #### -_logpdf(m::Geometric, k) = _logpdf(m, SparseMatrixCSC(k)) -_logpdf(m::Geometric, k::SparseMatrixCSC) = k .*logsigmoid.(-m.logitp) .+ logsigmoid.(m.logitp) +# TODO: precompute logsigmoid.(m.logitp) +function _logpdf(m::Geometric{T}, x::SparseMatrixCSC) where {T<:Real} + ndims, nobs = size(x) + # linit = sum(logsigmoid, m.logitp) + linit = T(0e0) + @inbounds for r in eachindex(m.logitp) + linit += logsigmoid(m.logitp[r]) + end + l = fill(linit, 1, nobs) + + # I, J, K = findnz(x) + # @inbounds for n in 1:length(I) + # l[J[n]] += K[n]*logsigmoid(-m.logitp[I[n]]) + # end + for (i, j, k) in zip(findnz(x)...) + l[j] += k*logsigmoid(-m.logitp[i]) + end + l +end + +function _logpdf_back(logitp::Vector{T}, x, Δy) where {T<:Real} + ndims, nobs = size(x) + sum_Δy = sum(Δy) + + Δp = fill(sum_Δy, ndims) + @inbounds for r in eachindex(Δp) + Δp[r] *= sigmoid(-logitp[r]) + end + + for (i, j, k) in zip(findnz(x)...) 
+ Δp[i] -= k*sigmoid(logitp[i])*Δy[j] + end + Δp +end + +function ChainRulesCore.rrule(::typeof(_logpdf), m::Geometric{T}, x::SparseMatrixCSC) where {T<:Real} + y = _logpdf(m, x) + p = m.logitp + function _logpdf_pullback(Δy) + Δlogitp = _logpdf_back(p, x, Δy) + Δm = Tangent{Geometric{T}}(; logitp=Δlogitp) + return NoTangent(), Δm, NoTangent() + end + return y, _logpdf_pullback +end + +logpdf(m::Geometric, x::SparseMatrixCSC) = _logpdf(m, x) +logpdf(m::Geometric, k::NGramMatrix) = _logpdf(m, SparseMatrixCSC(k)) + +# _logpdf(m::Geometric, k::SparseMatrixCSC) = k .*logsigmoid.(-m.logitp) .+ logsigmoid.(m.logitp) -logpdf(m::Geometric, x::NGramMatrix{T}) where {T<:Sequence} = sum(_logpdf(m, x); dims=1) -logpdf(m::Geometric{Tm}, x::NGramMatrix{Maybe{Tx}}) where {Tm<:Real, Tx<:Sequence} = sum(coalesce.(_logpdf(m, x), Tm(0e0)); dims=1) -logpdf(m::Geometric, x::SparseMatrixCSC) = sum(_logpdf(m, x); dims=1) +# logpdf(m::Geometric, x::NGramMatrix{T}) where {T<:Sequence} = sum(_logpdf(m, x); dims=1) +# logpdf(m::Geometric{Tm}, x::NGramMatrix{Maybe{Tx}}) where {Tm<:Real, Tx<:Sequence} = sum(coalesce.(_logpdf(m, x), Tm(0e0)); dims=1) +# logpdf(m::Geometric, x::SparseMatrixCSC) = sum(_logpdf(m, x); dims=1) #### # Functions for generating random samples From d8b32a22fa3983321698259f9892efc814e4f0a2 Mon Sep 17 00:00:00 2001 From: Martin Rektoris Date: Thu, 17 Aug 2023 14:23:15 +0200 Subject: [PATCH 03/40] Fix Zygote issue with Geometrix --- src/distributions/geometric.jl | 38 ++++++++++++++-------------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/src/distributions/geometric.jl b/src/distributions/geometric.jl index 5f87f64..3e95b87 100644 --- a/src/distributions/geometric.jl +++ b/src/distributions/geometric.jl @@ -41,21 +41,18 @@ Geometric(n::Int; dtype::Type{<:Real}=Float32) = Geometric(dtype(0.01)*randn(dty #### # TODO: precompute logsigmoid.(m.logitp) -function _logpdf(m::Geometric{T}, x::SparseMatrixCSC) where {T<:Real} + + +function 
_logpdf_geometric(logitp::Vector{T}, x::SparseMatrixCSC) where {T<:Real} ndims, nobs = size(x) - # linit = sum(logsigmoid, m.logitp) - linit = T(0e0) - @inbounds for r in eachindex(m.logitp) - linit += logsigmoid(m.logitp[r]) + linit = T(0e0) # linit = sum(logsigmoid, m.logitp) + @inbounds for r in eachindex(logitp) + linit += logsigmoid(logitp[r]) end l = fill(linit, 1, nobs) - # I, J, K = findnz(x) - # @inbounds for n in 1:length(I) - # l[J[n]] += K[n]*logsigmoid(-m.logitp[I[n]]) - # end - for (i, j, k) in zip(findnz(x)...) - l[j] += k*logsigmoid(-m.logitp[i]) + @inbounds for (i, j, k) in zip(findnz(x)...) + l[j] += k*logsigmoid(-logitp[i]) end l end @@ -69,25 +66,20 @@ function _logpdf_back(logitp::Vector{T}, x, Δy) where {T<:Real} Δp[r] *= sigmoid(-logitp[r]) end - for (i, j, k) in zip(findnz(x)...) + @inbounds for (i, j, k) in zip(findnz(x)...) Δp[i] -= k*sigmoid(logitp[i])*Δy[j] end - Δp + Δp, NoTangent() end -function ChainRulesCore.rrule(::typeof(_logpdf), m::Geometric{T}, x::SparseMatrixCSC) where {T<:Real} - y = _logpdf(m, x) - p = m.logitp - function _logpdf_pullback(Δy) - Δlogitp = _logpdf_back(p, x, Δy) - Δm = Tangent{Geometric{T}}(; logitp=Δlogitp) - return NoTangent(), Δm, NoTangent() - end +function ChainRulesCore.rrule(::typeof(_logpdf_geometric), logitp::Vector{T}, x::SparseMatrixCSC) where {T<:Real} + y = _logpdf_geometric(logitp, x) + _logpdf_pullback = Δy -> (NoTangent(), _logpdf_back(logitp, x, Δy)...) 
return y, _logpdf_pullback end -logpdf(m::Geometric, x::SparseMatrixCSC) = _logpdf(m, x) -logpdf(m::Geometric, k::NGramMatrix) = _logpdf(m, SparseMatrixCSC(k)) +logpdf(m::Geometric, x::SparseMatrixCSC) = _logpdf_geometric(m.logitp, x) +logpdf(m::Geometric, x::NGramMatrix) = logpdf(m, SparseMatrixCSC(x)) # _logpdf(m::Geometric, k::SparseMatrixCSC) = k .*logsigmoid.(-m.logitp) .+ logsigmoid.(m.logitp) From 3f0a57521be78f6cb65bcf3322499233ed8ac050 Mon Sep 17 00:00:00 2001 From: Martin Rektoris Date: Tue, 22 Aug 2023 11:38:00 +0200 Subject: [PATCH 04/40] Optimized SetNode, Geo, Pois --- src/distributions/geometric.jl | 21 +++++--------- src/distributions/poisson.jl | 51 +++++++++++++++++++++++++++++----- src/modelnodes/setnode.jl | 35 +++++++++++++++++++++-- 3 files changed, 84 insertions(+), 23 deletions(-) diff --git a/src/distributions/geometric.jl b/src/distributions/geometric.jl index 3e95b87..8963ac0 100644 --- a/src/distributions/geometric.jl +++ b/src/distributions/geometric.jl @@ -44,13 +44,8 @@ Geometric(n::Int; dtype::Type{<:Real}=Float32) = Geometric(dtype(0.01)*randn(dty function _logpdf_geometric(logitp::Vector{T}, x::SparseMatrixCSC) where {T<:Real} - ndims, nobs = size(x) - linit = T(0e0) # linit = sum(logsigmoid, m.logitp) - @inbounds for r in eachindex(logitp) - linit += logsigmoid(logitp[r]) - end - l = fill(linit, 1, nobs) - + linit = sum(logsigmoid, logitp) + l = fill(linit, 1, size(x, 2)) @inbounds for (i, j, k) in zip(findnz(x)...) l[j] += k*logsigmoid(-logitp[i]) end @@ -58,18 +53,16 @@ function _logpdf_geometric(logitp::Vector{T}, x::SparseMatrixCSC) where {T<:Real end function _logpdf_back(logitp::Vector{T}, x, Δy) where {T<:Real} - ndims, nobs = size(x) - sum_Δy = sum(Δy) + Δlogitp = fill(sum(Δy), size(x, 1)) - Δp = fill(sum_Δy, ndims) - @inbounds for r in eachindex(Δp) - Δp[r] *= sigmoid(-logitp[r]) + @inbounds for r in eachindex(Δlogitp) + Δlogitp[r] *= sigmoid(-logitp[r]) end @inbounds for (i, j, k) in zip(findnz(x)...) 
- Δp[i] -= k*sigmoid(logitp[i])*Δy[j] + Δlogitp[i] -= k*sigmoid(logitp[i])*Δy[j] end - Δp, NoTangent() + Δlogitp, NoTangent() end function ChainRulesCore.rrule(::typeof(_logpdf_geometric), logitp::Vector{T}, x::SparseMatrixCSC) where {T<:Real} diff --git a/src/distributions/poisson.jl b/src/distributions/poisson.jl index 7739ba6..e731561 100644 --- a/src/distributions/poisson.jl +++ b/src/distributions/poisson.jl @@ -32,23 +32,60 @@ end Flux.@functor Poisson Poisson(lograte::AbstractFloat) = Poisson([lograte]) -Poisson(n::Int) = Poisson(Float64.(log.(rand(2:10, n)))) # pois_rand does not work with Float64 +Poisson(n::Int) = Poisson(Float64.(log.(rand(2:10, n)))) # pois_rand does not work with Float32 Poisson() = Poisson(1) #### # Functions for calculating the likelihood #### -_logpdf(lograte, x) = x .* lograte .- exp.(lograte) .- _logfactorial.(x) -logpdf(m::Poisson, x::Matrix{<:Real}) = sum(_logpdf(m.lograte, x), dims=1) -logpdf(m::Poisson, x::Vector{<:Real}) = sum(_logpdf(m.lograte, hcat(x...)), dims=1) -logpdf(m::Poisson, x::Real) = hcat(_logpdf(m.lograte, x)) # for consistency -logpdf(m::Poisson, x::SparseMatrixCSC) = sum(_logpdf(m.lograte, x), dims=1) + +function _logpdf_poisson(lograte::Vector{T}, x::Matrix{<:Real}) where {T<:Real} + ndims, nobs = size(x) + linit = -sum(exp, lograte) + + l = fill(linit, 1, nobs) + for j in 1:nobs + for i in 1:ndims + l[j] += x[i, j] * lograte[i] - logfactorial(x[i, j]) + end + end + l +end + +function _logpdf_poisson_back(lograte::Vector{T}, x, Δy) where {T<:Real} + ndims, nobs = size(x) + sum_Δy = sum(Δy) + Δp = zeros(T, ndims) + + @inbounds for i in 1:ndims + for j in 1:nobs + Δp[i] += x[i, j] * Δy[j] + end + Δp[i] -= sum_Δy * exp(lograte[i]) + end + Δp, NoTangent() +end + + +function ChainRulesCore.rrule(::typeof(_logpdf_poisson), args...) + _logpdf_pullback = Δy -> (NoTangent(), _logpdf_poisson_back(args..., Δy)...) 
+ _logpdf_poisson(args...), _logpdf_pullback +end + +logpdf(m::Poisson, x::Matrix{<:Real}) = _logpdf_poisson(m.lograte, x) +logpdf(m::Poisson, x::Union{T, Vector{T}} where T<:Real) = logpdf(m, hcat(x...)) + +# old logpdfs +# _logpdf(lograte, x) = x .* lograte .- exp.(lograte) .- _logfactorial.(x) +# logpdf(m::Poisson, x::Matrix{<:Real}) = sum(_logpdf(m.lograte, x), dims=1) +# logpdf(m::Poisson, x::Vector{<:Real}) = sum(_logpdf(m.lograte, hcat(x...)), dims=1) +# logpdf(m::Poisson, x::Real) = hcat(_logpdf(m.lograte, x)) # for consistency +# logpdf(m::Poisson, x::SparseMatrixCSC) = sum(_logpdf(m.lograte, x), dims=1) #### # Functions for generating random samples #### -# Base.rand(m::Poisson, n::Int) = mapreduce(logλ -> map(_->pois_rand(exp(logλ)), 1:n)', vcat, m.logλ) Base.rand(m::Poisson, n::Int) = Mill.ArrayNode([pois_rand(exp(logr)) for logr in m.lograte, _ in 1:n]) Base.rand(m::Poisson) = rand(m, 1) diff --git a/src/modelnodes/setnode.jl b/src/modelnodes/setnode.jl index d465b08..d28848d 100644 --- a/src/modelnodes/setnode.jl +++ b/src/modelnodes/setnode.jl @@ -31,10 +31,41 @@ Flux.@functor SetNode #### function logpdf(m::SetNode, x::Mill.BagNode) - l = logpdf(m.feature, x.data) - mapreduce(b->logpdf(m.cardinality, length(b)) .+ sum(l[b]) .+ logfactorial(length(b)), hcat, x.bags.bags) + bags = x.bags.bags + logp_f = logpdf(m.feature, x.data) + logp_c = SumProductSet.logpdf(m.cardinality, hcat(length.(bags)...)) + # logp_c = ones(eltype(logp_f), 1, length(bags)) + _logpdf_set(logp_f, logp_c, bags) end +function _logpdf_set(logp_f, logp_c, bags) + lb = copy(logp_c) + @inbounds for (bi, b) in enumerate(bags) + lb[bi] += sum(logp_f[b]) + logfactorial(length(b)) + end + lb +end + +function _logpdf_set_back(logp_f, logp_c, bags, Δy) + Δlogp_f = zero(logp_f) + @inbounds for (bi, b) in enumerate(bags) + for i in b + Δlogp_f[i] += Δy[bi] + end + end + Δlogp_f, Δy, NoTangent() +end + +function ChainRulesCore.rrule(::typeof(_logpdf_set), args...) 
+ _logpdf_set_pullback = Δy -> (NoTangent(), _logpdf_set_back(args..., Δy)...) + _logpdf_set(args...), _logpdf_set_pullback +end + +# function logpdf(m::SetNode, x::Mill.BagNode) +# l = logpdf(m.feature, x.data) +# mapreduce(b->logpdf(m.cardinality, length(b)) .+ sum(l[b]) .+ logfactorial(length(b)), hcat, x.bags.bags) +# end + #### # Functions for generating random samples #### From 3dffd21c5a3bbe8c1fe333e8ca2e37d914073912 Mon Sep 17 00:00:00 2001 From: Martin Rektoris Date: Tue, 22 Aug 2023 11:39:06 +0200 Subject: [PATCH 05/40] Update tests for optimized nodes --- test/Manifest.toml | 12 +++++++++--- test/Project.toml | 2 ++ test/distributions/geometric.jl | 11 +++++++++++ test/distributions/poisson.jl | 14 ++++++++++++-- test/modelbuilders.jl | 12 ++++++------ test/runtests.jl | 2 +- test/setnode.jl | 14 ++++++++++++++ 7 files changed, 55 insertions(+), 12 deletions(-) diff --git a/test/Manifest.toml b/test/Manifest.toml index 410f2bc..3c812c6 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.0" manifest_format = "2.0" -project_hash = "3c60c52f975d708cb23025b3b1afee7b35abe203" +project_hash = "003d8634b10a3893dd3a46cc8e6aa8dcf4b582cc" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] @@ -91,9 +91,15 @@ version = "1.44.7" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "e7ff6cadf743c098e08fca25c91103ee4303c9bb" +git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.15.6" +version = "1.16.0" + +[[deps.ChainRulesTestUtils]] +deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"] +git-tree-sha1 = "5ab2a7bc21ecc3eb0226478ff8f87e9685b11818" +uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" +version = "1.11.0" [[deps.ChangesOfVariables]] deps = ["ChainRulesCore", "LinearAlgebra", "Test"] diff --git a/test/Project.toml b/test/Project.toml index ab39da8..5c9cfdc 100644 --- 
a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,6 @@ [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Mill = "1d0525e4-8992-11e8-313c-e310e1f6ddea" diff --git a/test/distributions/geometric.jl b/test/distributions/geometric.jl index aa936db..bb206e3 100644 --- a/test/distributions/geometric.jl +++ b/test/distributions/geometric.jl @@ -21,6 +21,17 @@ end @test size(rand(m, nobs)) == (ndims, nobs) end +@testset "Geometric --- rrule test" begin + ndims = 10 + nobs = 100 + dtype = Float64 + m = SumProductSet.Geometric(ndims; dtype=dtype) + x = rand(m, nobs) + + test_rrule(SumProductSet._logpdf_geometric, m.logitp, x.data ⊢ NoTangent(); + check_inferred=true, rtol = 1.0e-9, atol = 1.0e-9) +end + @testset "Geometric --- integration with Flux" begin ndims = 10 nobs = 100 diff --git a/test/distributions/poisson.jl b/test/distributions/poisson.jl index 18c9dfd..11d363c 100644 --- a/test/distributions/poisson.jl +++ b/test/distributions/poisson.jl @@ -2,7 +2,7 @@ @testset "Poisson --- logpdf forward" begin ndim = 1 m = SumProductSet.Poisson(ndim) - xs = [rand(0:20, ndim, 100), 2, 100., 0] + xs = [rand(0:20, ndim, 100), 2, 100, 0] for x in xs @test !isnothing(SumProductSet.logpdf(m, x)) @@ -21,9 +21,19 @@ end @test length(rand(m).data) == length(m.lograte) end +@testset "Poisson --- rrule test" begin + ndims = 10 + nobs = 100 + m = SumProductSet.Poisson(ndims) + x = rand(m, nobs) + + test_rrule(SumProductSet._logpdf_poisson, m.lograte, x.data ⊢ NoTangent(); + check_inferred=true, rtol = 1.0e-9, atol = 1.0e-9) +end + @testset "Poisson --- integration with Flux" begin ndim = 10 - m = SumProductSet.Poisson() + m = SumProductSet.Poisson(ndim) ps = Flux.params(m) @test !isempty(ps) diff --git a/test/modelbuilders.jl b/test/modelbuilders.jl index 891b34f..92c8540 100644 --- a/test/modelbuilders.jl 
+++ b/test/modelbuilders.jl @@ -51,16 +51,16 @@ end @testset "hierarchical model -- logpdf forward" begin d1 = 9 d2 = 11 - pdist = ()->(SumProductSet.MvNormal(d1), SumProductSet.MvNormal(d2)) + pdist = ()->(:a=SumProductSet.MvNormal(d1), :b=SumProductSet.MvNormal(d2)) cdist = ()-> SumProductSet.Poisson() - prodmodel = ()->ProductNode(pdist()) - setmodel = ()->SetNode(prodmodel(), cdist()) + prodmodel = ()->SumProductSet.ProductNode(pdist()) + setmodel = ()->SumProductSet.SetNode(prodmodel(), cdist()) - m = SumNode([setmodel() for _ in 1:3]) + m = SumProductSet.SumNode([setmodel() for _ in 1:3]) - n = 30 - pn = Mill.ProductNode((randn(d1, n), randn(d2, n))) + nobs = 100 + pn = Mill.ProductNode(a=randn(Float32, d1, nobs), b=randn(Float32, d2, nobs)) bn = Mill.BagNode(pn, [1:5, 6:15, 16:16, 17:30]) @test !isnothing(SumProductSet.logpdf(m, bn)) diff --git a/test/runtests.jl b/test/runtests.jl index ab0d68a..dac36a3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using Test -using SumProductSet, Flux, Distributions, SparseArrays +using SumProductSet, Flux, Distributions, SparseArrays, ChainRulesCore, ChainRulesTestUtils import Mill include("distributions/mvnormal.jl") diff --git a/test/setnode.jl b/test/setnode.jl index 5707f7f..e115414 100644 --- a/test/setnode.jl +++ b/test/setnode.jl @@ -29,6 +29,20 @@ end @test Mill.numobs(rand(m, nobs)) == nobs end +@testset "SetNode --- rrule test" begin + ndims = 2 + nobs = 100 + dtype = Float64 + m = SumProductSet.SetNode(SumProductSet.MvNormal(ndims;dtype=dtype), SumProductSet.Poisson()) + x = rand(m, nobs) + bags = x.bags.bags + logp_f = SumProductSet.logpdf(m.feature, x.data) + logp_c = SumProductSet.logpdf(m.cardinality, hcat(length.(bags)...)) + + test_rrule(SumProductSet._logpdf_set, logp_f, logp_c, bags ⊢ NoTangent(); + check_inferred=true, rtol = 1.0e-9, atol = 1.0e-9) +end + @testset "SetNode --- integration with Flux" begin ndims = 2 From 338e0d6887731e435fe7426a6ae70249c7d6d441 Mon Sep 17 00:00:00 
2001 From: Martin Rektoris Date: Tue, 22 Aug 2023 16:35:17 +0200 Subject: [PATCH 06/40] Improve SetNode forward --- src/modelnodes/setnode.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/modelnodes/setnode.jl b/src/modelnodes/setnode.jl index d28848d..e181cff 100644 --- a/src/modelnodes/setnode.jl +++ b/src/modelnodes/setnode.jl @@ -34,14 +34,16 @@ function logpdf(m::SetNode, x::Mill.BagNode) bags = x.bags.bags logp_f = logpdf(m.feature, x.data) logp_c = SumProductSet.logpdf(m.cardinality, hcat(length.(bags)...)) - # logp_c = ones(eltype(logp_f), 1, length(bags)) _logpdf_set(logp_f, logp_c, bags) end function _logpdf_set(logp_f, logp_c, bags) lb = copy(logp_c) @inbounds for (bi, b) in enumerate(bags) - lb[bi] += sum(logp_f[b]) + logfactorial(length(b)) + for i in b + lb[bi] += logp_f[i] + end + lb[bi] += logfactorial(length(b)) end lb end From 0ac1291e0b14509b696970c6457e84ee958322de Mon Sep 17 00:00:00 2001 From: Martin Rektoris Date: Tue, 22 Aug 2023 16:36:00 +0200 Subject: [PATCH 07/40] Fix Poisson type ambiguity --- src/distributions/poisson.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/distributions/poisson.jl b/src/distributions/poisson.jl index e731561..5791a69 100644 --- a/src/distributions/poisson.jl +++ b/src/distributions/poisson.jl @@ -32,7 +32,7 @@ end Flux.@functor Poisson Poisson(lograte::AbstractFloat) = Poisson([lograte]) -Poisson(n::Int) = Poisson(Float64.(log.(rand(2:10, n)))) # pois_rand does not work with Float32 +Poisson(n::Int) = Poisson(Float32.(log.(rand(2:10, n)))) Poisson() = Poisson(1) #### @@ -86,7 +86,8 @@ logpdf(m::Poisson, x::Union{T, Vector{T}} where T<:Real) = logpdf(m, hcat(x...)) # Functions for generating random samples #### -Base.rand(m::Poisson, n::Int) = Mill.ArrayNode([pois_rand(exp(logr)) for logr in m.lograte, _ in 1:n]) +# pois_rand does not work with Float32 +Base.rand(m::Poisson, n::Int) = Mill.ArrayNode([pois_rand(exp(Float64.(logr))) for logr in m.lograte, 
_ in 1:n]) Base.rand(m::Poisson) = rand(m, 1) #### From 4b77e4153a99dec23625b0308a140055ef55ca65 Mon Sep 17 00:00:00 2001 From: Martin Rektoris Date: Sat, 26 Aug 2023 22:50:21 +0200 Subject: [PATCH 08/40] Make Geo dist ecen faster --- src/distributions/geometric.jl | 23 ++++++++++++++++++----- test/distributions/geometric.jl | 27 +++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/src/distributions/geometric.jl b/src/distributions/geometric.jl index 8963ac0..a204949 100644 --- a/src/distributions/geometric.jl +++ b/src/distributions/geometric.jl @@ -46,21 +46,34 @@ Geometric(n::Int; dtype::Type{<:Real}=Float32) = Geometric(dtype(0.01)*randn(dty function _logpdf_geometric(logitp::Vector{T}, x::SparseMatrixCSC) where {T<:Real} linit = sum(logsigmoid, logitp) l = fill(linit, 1, size(x, 2)) - @inbounds for (i, j, k) in zip(findnz(x)...) - l[j] += k*logsigmoid(-logitp[i]) + + rows = rowvals(x) + vals = nonzeros(x) + @inbounds for j in 1:size(x, 2) + for i in nzrange(x, j) + row = rows[i] + val = vals[i] + l[j] += val * logsigmoid(-logitp[row]) + end end l end function _logpdf_back(logitp::Vector{T}, x, Δy) where {T<:Real} - Δlogitp = fill(sum(Δy), size(x, 1)) + Δlogitp = fill(sum(Δy), length(logitp)) @inbounds for r in eachindex(Δlogitp) Δlogitp[r] *= sigmoid(-logitp[r]) end - @inbounds for (i, j, k) in zip(findnz(x)...) 
- Δlogitp[i] -= k*sigmoid(logitp[i])*Δy[j] + rows = rowvals(x) + vals = nonzeros(x) + @inbounds for j in 1:size(x, 2) + for i in nzrange(x, j) + row = rows[i] + val = vals[i] + Δlogitp[row] -= val * sigmoid(logitp[row]) * Δy[j] + end end Δlogitp, NoTangent() end diff --git a/test/distributions/geometric.jl b/test/distributions/geometric.jl index bb206e3..a63bbc9 100644 --- a/test/distributions/geometric.jl +++ b/test/distributions/geometric.jl @@ -43,3 +43,30 @@ end @test !isnothing(gradient(() -> sum(SumProductSet.logpdf(m, x)), ps)) end + +@testset "Geometric --- correctness" begin + p = 0.7 + logitp = log(p/(1-p)) + n = 100 + m1 = Distributions.Geometric(p) + m2 = SumProductSet.Geometric([logitp]) + x1 = rand(m1, n) + x2 = rand(m2, n) + + @test Distributions.logpdf.(m1, x1)[:] ≈ + SumProductSet.logpdf(m2, hcat(x1...) |> SparseMatrixCSC)[:] + @test Distributions.logpdf.(m1, x2.data |> Matrix)[:] ≈ + SumProductSet.logpdf(m2, x2)[:] + + + p1, p2 = 0.7, 0.1 + logitp1, logitp2 = log(p1/(1-p1)), log(p2/(1-p2)) + m3 = SumProductSet.Geometric([logitp1, logitp2]) + m31 = SumProductSet.Geometric([logitp1]) + m32 = SumProductSet.Geometric([logitp2]) + + x3 = rand(m3, n) + @test SumProductSet.logpdf(m31, x3.data[1:1, :]) + SumProductSet.logpdf(m32, x3.data[2:2, :]) ≈ + SumProductSet.logpdf(m3, x3) + +end From 7a8ef4692aa691cb359ca6aceabe7552a2476c76 Mon Sep 17 00:00:00 2001 From: Martin Rektoris Date: Sat, 26 Aug 2023 22:51:16 +0200 Subject: [PATCH 09/40] Add support for SparseArrays to Poisson --- src/distributions/poisson.jl | 48 ++++++++++++++++++++++++++++++++--- test/distributions/poisson.jl | 12 ++++++--- 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/src/distributions/poisson.jl b/src/distributions/poisson.jl index 5791a69..31e44cf 100644 --- a/src/distributions/poisson.jl +++ b/src/distributions/poisson.jl @@ -32,8 +32,8 @@ end Flux.@functor Poisson Poisson(lograte::AbstractFloat) = Poisson([lograte]) -Poisson(n::Int) = 
Poisson(Float32.(log.(rand(2:10, n)))) -Poisson() = Poisson(1) +Poisson(n::Int; dtype::Type{<:Real}=Float32) = Poisson(dtype.(log.(rand(2:10, n)))) +Poisson(dtype::Type{<:Real}=Float32) = Poisson(1; dtype=dtype) #### # Functions for calculating the likelihood @@ -44,7 +44,7 @@ function _logpdf_poisson(lograte::Vector{T}, x::Matrix{<:Real}) where {T<:Real} linit = -sum(exp, lograte) l = fill(linit, 1, nobs) - for j in 1:nobs + @inbounds for j in 1:nobs for i in 1:ndims l[j] += x[i, j] * lograte[i] - logfactorial(x[i, j]) end @@ -52,7 +52,24 @@ function _logpdf_poisson(lograte::Vector{T}, x::Matrix{<:Real}) where {T<:Real} l end -function _logpdf_poisson_back(lograte::Vector{T}, x, Δy) where {T<:Real} + +function _logpdf_poisson(lograte::Vector{T}, x::SparseMatrixCSC) where {T<:Real} + linit = -sum(exp, lograte) + l = fill(linit, 1, size(x, 2)) + + rows = rowvals(x) + vals = nonzeros(x) + @inbounds for j in 1:size(x, 2) + for i in nzrange(x, j) + row = rows[i] + val = vals[i] + l[j] += val * lograte[row] - logfactorial(val) + end + end + l +end + +function _logpdf_poisson_back(lograte::Vector{T}, x::Matrix{<:Real}, Δy) where {T<:Real} ndims, nobs = size(x) sum_Δy = sum(Δy) Δp = zeros(T, ndims) @@ -66,6 +83,28 @@ function _logpdf_poisson_back(lograte::Vector{T}, x, Δy) where {T<:Real} Δp, NoTangent() end +function _logpdf_poisson_back(lograte::Vector{T}, x::SparseMatrixCSC, Δy) where {T<:Real} + ndims, nobs = size(x) + sum_Δy = sum(Δy) + Δp = zeros(T, ndims) + + @inbounds for i in eachindex(lograte) + Δp[i] -= sum_Δy * exp(lograte[i]) + end + + rows = rowvals(x) + vals = nonzeros(x) + @inbounds for j in 1:size(x, 2) + for i in nzrange(x, j) + row = rows[i] + val = vals[i] + Δp[row] += val * Δy[j] + end + end + + Δp, NoTangent() +end + function ChainRulesCore.rrule(::typeof(_logpdf_poisson), args...) _logpdf_pullback = Δy -> (NoTangent(), _logpdf_poisson_back(args..., Δy)...) @@ -73,6 +112,7 @@ function ChainRulesCore.rrule(::typeof(_logpdf_poisson), args...) 
end logpdf(m::Poisson, x::Matrix{<:Real}) = _logpdf_poisson(m.lograte, x) +logpdf(m::Poisson, x::SparseMatrixCSC) = _logpdf_poisson(m.lograte, x) logpdf(m::Poisson, x::Union{T, Vector{T}} where T<:Real) = logpdf(m, hcat(x...)) # old logpdfs diff --git a/test/distributions/poisson.jl b/test/distributions/poisson.jl index 11d363c..9b81aa2 100644 --- a/test/distributions/poisson.jl +++ b/test/distributions/poisson.jl @@ -10,8 +10,9 @@ ndim = 10 m = SumProductSet.Poisson(ndim) - xs = rand(0:20, ndim, 100) - @test !isnothing(SumProductSet.logpdf(m, xs)) + x = rand(0:20, ndim, 100) + @test !isnothing(SumProductSet.logpdf(m, x)) + @test !isnothing(SumProductSet.logpdf(m, SparseMatrixCSC(x))) end @testset "Poisson --- rand sampling" begin @@ -24,11 +25,13 @@ end @testset "Poisson --- rrule test" begin ndims = 10 nobs = 100 - m = SumProductSet.Poisson(ndims) + m = SumProductSet.Poisson(ndims; dtype=Float64) x = rand(m, nobs) test_rrule(SumProductSet._logpdf_poisson, m.lograte, x.data ⊢ NoTangent(); check_inferred=true, rtol = 1.0e-9, atol = 1.0e-9) + test_rrule(SumProductSet._logpdf_poisson, m.lograte, SparseMatrixCSC(x.data) ⊢ NoTangent(); + check_inferred=true, rtol = 1.0e-9, atol = 1.0e-9) end @testset "Poisson --- integration with Flux" begin @@ -39,6 +42,8 @@ end @test !isempty(ps) x = rand(0:20, ndim, 100) @test !isnothing(gradient(() -> sum(SumProductSet.logpdf(m, x)), ps)) + xs = SparseMatrixCSC(x) + @test !isnothing(gradient(() -> sum(SumProductSet.logpdf(m, xs)), ps)) end @testset "Poisson --- correctness" begin @@ -51,4 +56,5 @@ end @test Distributions.logpdf.(m1, x1)[:] ≈ SumProductSet.logpdf(m2, x1)[:] @test Distributions.logpdf.(m1, x2.data)[:] ≈ SumProductSet.logpdf(m2, x2)[:] + @test SumProductSet.logpdf(m2, x2.data) ≈ SumProductSet.logpdf(m2, SparseMatrixCSC(x2.data)) end \ No newline at end of file From 341a82e2c417d53c1e78e82118b9fcd1fc32bbb7 Mon Sep 17 00:00:00 2001 From: Martin Rektoris Date: Mon, 28 Aug 2023 14:53:03 +0200 Subject: [PATCH 10/40] Update 
typing in SetNode test --- test/setnode.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/setnode.jl b/test/setnode.jl index e115414..d0cb082 100644 --- a/test/setnode.jl +++ b/test/setnode.jl @@ -33,7 +33,7 @@ end ndims = 2 nobs = 100 dtype = Float64 - m = SumProductSet.SetNode(SumProductSet.MvNormal(ndims;dtype=dtype), SumProductSet.Poisson()) + m = SumProductSet.SetNode(SumProductSet.MvNormal(ndims;dtype=dtype), SumProductSet.Poisson(dtype)) x = rand(m, nobs) bags = x.bags.bags logp_f = SumProductSet.logpdf(m.feature, x.data) From bee396ab7b4228ff04b1677bca609bdb4bd84c78 Mon Sep 17 00:00:00 2001 From: Martin Rektoris Date: Wed, 30 Aug 2023 14:19:09 +0200 Subject: [PATCH 11/40] Optimize Categorical, add supp for missing vals --- src/distributions/categorical.jl | 69 +++++++++++++++++++++++++++----- 1 file changed, 58 insertions(+), 11 deletions(-) diff --git a/src/distributions/categorical.jl b/src/distributions/categorical.jl index 3cf1a7a..8b31a78 100644 --- a/src/distributions/categorical.jl +++ b/src/distributions/categorical.jl @@ -25,7 +25,7 @@ julia> logpdf(m, x) """ struct Categorical{T} <: Distribution - logp::Array{T, 1} + logp::Vector{T} end Flux.@functor Categorical @@ -36,20 +36,67 @@ Categorical(n::Integer; dtype::Type{<:Real}=Float32) = Categorical(ones(dtype, n # Functions for calculating the likelihood #### -function logpdf(m::Categorical, x::Union{Int, Vector{Int}}) - logp = logsoftmax(m.logp) - logp[x] +# function logpdf(m::Categorical, x::Union{Int, Vector{Int}}) +# logp = logsoftmax(m.logp) +# logp[x] +# end + +# logpdf(m::Categorical, x::Matrix) = logpdf(m, vec(x)) +# logpdf(m::Categorical, x::Real) = logpdf(m, convert.(Int64, x)) +# logpdf(m::Categorical, x::Vector{<:Real}) = logpdf(m, convert.(Int64, x)) +# logpdf(m::Categorical, x::Matrix{<:Real}) = logpdf(m, convert.(Int64, vec(x))) + +# _logpdf(m::Categorical, x::OneHotArray) = reshape(logsoftmax(m.logp), 1, :) * x +# _logpdf(m::Categorical, x::MaybeHotArray) = 
PostImputingMatrix(reshape(logsoftmax(m.logp), 1, :)) * x + +# logpdf(m::Categorical, x::Union{OneHotArray, MaybeHotArray}) = _logpdf(m, x) + + +logpdf(m::Categorical, x::Union{OneHotArray, MaybeHotArray}) = _logpdf_cat(m.logp, x) + +_get_inidices(x::OneHotArray) = x.indices +_get_inidices(x::MaybeHotMatrix) = x.I +_get_inidices(x::MaybeHotVector) = x.i + +l_cat!(_ , _, _, idx::Missing) = nothing +l_cat!(l, logp, i, idx) = l[i] += logp[idx] + +function _logpdf_cat(logp::Vector{T}, x) where T + logp = logsoftmax(logp) + l = zeros(T, 1, size(x, 2)) + + @inbounds for (j, idx) in enumerate(_get_inidices(x)) + l_cat!(l, logp, j, idx) + end + l end -logpdf(m::Categorical, x::Matrix) = logpdf(m, vec(x)) -logpdf(m::Categorical, x::Real) = logpdf(m, convert.(Int64, x)) -logpdf(m::Categorical, x::Vector{<:Real}) = logpdf(m, convert.(Int64, x)) -logpdf(m::Categorical, x::Matrix{<:Real}) = logpdf(m, convert.(Int64, vec(x))) +Δlogp_cat!(_, _, idx::Missing, _) = nothing +Δlogp_cat!(Δlogp, Δy, idx, j) = Δlogp[idx] += Δy[j] + +_weight(i::Missing, _, ::Type{T}) where T = zero(T) +_weight(_, _, ::Type{T}) where T = one(T) + +function _logpdf_cat_back(logp::Vector{T}, x, Δy) where {T <: Real} + sum_Δy = zero(T) + Δlogp = zero(logp) -_logpdf(m::Categorical, x::OneHotArray) = reshape(logsoftmax(m.logp), 1, :) * x -_logpdf(m::Categorical, x::MaybeHotArray) = PostImputingMatrix(reshape(logsoftmax(m.logp), 1, :)) * x + @inbounds for (j, idx) in enumerate(_get_inidices(x)) + Δlogp_cat!(Δlogp, Δy, idx, j) + sum_Δy += _weight(idx, j, T) * Δy[j] + end + p = softmax(logp) + @inbounds for i in eachindex(logp) + Δlogp[i] -= sum_Δy * p[i] + end -logpdf(m::Categorical, x::Union{OneHotArray, MaybeHotArray}) = _logpdf(m, x) + Δlogp, NoTangent() +end + +function ChainRulesCore.rrule(::typeof(_logpdf_cat), args...) + _logpdf_cat_pullback = Δy -> (NoTangent(), _logpdf_cat_back(args..., Δy)...) 
+ _logpdf_cat(args...), _logpdf_cat_pullback +end #### # Functions for generating random samples From 93b2044691b2d534881ad14066153609a4735e4a Mon Sep 17 00:00:00 2001 From: Martin Rektoris Date: Wed, 30 Aug 2023 14:19:54 +0200 Subject: [PATCH 12/40] Update Categorical tests --- test/distributions/categorical.jl | 87 +++++++++++++++++++++++-------- 1 file changed, 66 insertions(+), 21 deletions(-) diff --git a/test/distributions/categorical.jl b/test/distributions/categorical.jl index 9b1f076..4e0ec2a 100644 --- a/test/distributions/categorical.jl +++ b/test/distributions/categorical.jl @@ -1,13 +1,13 @@ -@testset "Categorical --- logpdf forward" begin - ndims = 10 - m = SumProductSet.Categorical(ndims) - x1 = rand(1:ndims, 1, 100) - x2 = 6 - - @test length(SumProductSet.logpdf(m, x1)) == length(x1) - @test length(SumProductSet.logpdf(m, x2)) == length(x2) -end +# @testset "Categorical --- logpdf forward" begin +# ndims = 10 +# m = SumProductSet.Categorical(ndims) +# x1 = rand(1:ndims, 1, 100) +# x2 = 6 + +# @test length(SumProductSet.logpdf(m, x1)) == length(x1) +# @test length(SumProductSet.logpdf(m, x2)) == length(x2) +# end @testset "Categorical --- rand sampling" begin ncat = 10 @@ -19,30 +19,46 @@ end @test size(rand(m, nobs)) == (ncat, nobs) end -@testset "Categorical --- integration with Flux" begin - ndims = 10 - m = SumProductSet.Categorical(ndims) - ps = Flux.params(m) - - @test !isempty(ps) - x = rand(1:ndims, 1, 100) - @test !isnothing(gradient(() -> sum(SumProductSet.logpdf(m, x)), ps)) -end - @testset "Categorical --- correctness" begin p = [0.1, 0.3, 0.15, 0.45] # sum(p)=1 + ncat = length(p) n = 100 m1 = Distributions.Categorical(p) m2 = SumProductSet.Categorical(log.(p)) x1 = rand(m1, n) x2 = rand(m2, n) - @test Distributions.logpdf.(m1, x1)[:] ≈ SumProductSet.logpdf(m2, x1)[:] + @test Distributions.logpdf.(m1, x1)[:] ≈ SumProductSet.logpdf(m2, Flux.onehotbatch(x1, 1:ncat))[:] @test Distributions.logpdf.(m1, Flux.onecold(x2.data))[:] ≈ 
SumProductSet.logpdf(m2, x2)[:] end +@testset "Categorical --- rrule test" begin + nobs = 100 + ncat = 10 + m = SumProductSet.Categorical(randn(Float64, ncat)) + x = rand(m, nobs) + xm = rand(1:ncat, nobs) + xm = Mill.maybehotbatch([missing; xm; missing], 1:ncat) + + test_rrule(SumProductSet._logpdf_cat, m.logp, x.data ⊢ NoTangent(); + check_inferred=true, rtol = 1.0e-9, atol = 1.0e-9) + test_rrule(SumProductSet._logpdf_cat, m.logp, xm ⊢ NoTangent(); + check_inferred=true, rtol = 1.0e-9, atol = 1.0e-9) +end + +@testset "Categorical --- integration with Flux" begin + nobs = 100 + ncat = 10 + m = SumProductSet.Categorical(ncat) + ps = Flux.params(m) + + @test !isempty(ps) + x = rand(m, nobs) + @test !isnothing(gradient(() -> sum(SumProductSet.logpdf(m, x)), ps)) +end + @testset "Categorical --- integration with OneHotArrays" begin - nobs = 20 + nobs = 100 ncat = 10 m = SumProductSet.Categorical(ncat) ps = Flux.params(m) @@ -55,3 +71,32 @@ end @test !isnothing(SumProductSet.logpdf(m, x_oh)) @test !isnothing(gradient(() -> sum(SumProductSet.logpdf(m, x_oh)), ps)) end + +@testset "Categorical --- integration with MaybeHotArrays" begin + nobs = 100 + ncat = 10 + m = SumProductSet.Categorical(ncat) + ps = Flux.params(m) + + xm = rand(1:ncat, nobs) + xm = Mill.maybehotbatch([missing; xm; missing], 1:ncat) + + @test !isnothing(SumProductSet.logpdf(m, xm)) + @test !isnothing(gradient(() -> sum(SumProductSet.logpdf(m, xm)), ps)) +end + + +@testset "Categorical --- marginalization" begin + nobs = 100 + ncat = 10 + m = SumProductSet.Categorical(ncat) + ps = Flux.params(m) + + xm = rand(1:ncat, nobs) + xm1 = Mill.maybehotbatch(xm, 1:ncat) + xm2 = Mill.maybehotbatch([missing; xm; missing], 1:ncat) + + @test sum(SumProductSet.logpdf(m, xm1)) == sum(SumProductSet.logpdf(m, xm2)) + @test gradient(() -> sum(SumProductSet.logpdf(m, xm1)), ps).grads == + gradient(() -> sum(SumProductSet.logpdf(m, xm2)), ps).grads +end \ No newline at end of file From 
725088a462767e5d78d172680c2fb5114adf5176 Mon Sep 17 00:00:00 2001 From: Martin Rektoris Date: Thu, 7 Sep 2023 21:42:00 +0200 Subject: [PATCH 13/40] Optimize Categorical distribution --- src/distributions/categorical.jl | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/src/distributions/categorical.jl b/src/distributions/categorical.jl index 8b31a78..47c84ee 100644 --- a/src/distributions/categorical.jl +++ b/src/distributions/categorical.jl @@ -36,23 +36,8 @@ Categorical(n::Integer; dtype::Type{<:Real}=Float32) = Categorical(ones(dtype, n # Functions for calculating the likelihood #### -# function logpdf(m::Categorical, x::Union{Int, Vector{Int}}) -# logp = logsoftmax(m.logp) -# logp[x] -# end - -# logpdf(m::Categorical, x::Matrix) = logpdf(m, vec(x)) -# logpdf(m::Categorical, x::Real) = logpdf(m, convert.(Int64, x)) -# logpdf(m::Categorical, x::Vector{<:Real}) = logpdf(m, convert.(Int64, x)) -# logpdf(m::Categorical, x::Matrix{<:Real}) = logpdf(m, convert.(Int64, vec(x))) - -# _logpdf(m::Categorical, x::OneHotArray) = reshape(logsoftmax(m.logp), 1, :) * x -# _logpdf(m::Categorical, x::MaybeHotArray) = PostImputingMatrix(reshape(logsoftmax(m.logp), 1, :)) * x - -# logpdf(m::Categorical, x::Union{OneHotArray, MaybeHotArray}) = _logpdf(m, x) - - logpdf(m::Categorical, x::Union{OneHotArray, MaybeHotArray}) = _logpdf_cat(m.logp, x) +lgpdf(m::Categorical{T}, x::MaybeHotMatrix{Missing}) where {T<:Real} = zeros(T, 1, size(x, 2)) _get_inidices(x::OneHotArray) = x.indices _get_inidices(x::MaybeHotMatrix) = x.I From 2ba9fad32df1e8bab91acaca275ee0ee638a6a42 Mon Sep 17 00:00:00 2001 From: Martin Rektoris Date: Thu, 7 Sep 2023 21:42:35 +0200 Subject: [PATCH 14/40] Add support for full missing data in MvNormal --- src/distributions/mvnormal.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/distributions/mvnormal.jl b/src/distributions/mvnormal.jl index 197ace0..b4b35f8 100644 --- a/src/distributions/mvnormal.jl +++ 
b/src/distributions/mvnormal.jl @@ -55,10 +55,12 @@ _logpdf(x::Union{Array{T, 2}, Array{Maybe{T}, 2}}) where {T<:Real} = -T(5e-1)*(l logpdf(m::MvNormal{T, 2}, x::Array{T, 2}) where {T<:Real} = log(abs(det(m.A + m.r * I))) .+ sum(_logpdf((m.A + m.r * I) * x .+ m.b), dims=1) logpdf(m::MvNormal{T, 1}, x::Array{T, 2}) where {T<:Real} = sum(log.(abs.(m.A .+ m.r))) .+ sum(_logpdf((m.A .+ m.r) .* x .+ m.b), dims=1) logpdf(m::MvNormal{T, 1}, x::Array{Maybe{T}, 2}) where {T<:Real} = sum(coalesce.(log.(abs.(m.A .+ m.r)) .+ _logpdf((m.A .+ m.r) .* x .+ m.b), T(0e0)); dims=1) +logpdf(m::MvNormal{T, 1}, x::Array{Missing, 2}) where {T<:Real} = zeros(T, 1, size(x, 2)) logpdf(m::MvNormal{T, 2}, x::Array{T, 1}) where {T<:Real} = logpdf(m, hcat(x)) logpdf(m::MvNormal{T, 1}, x::Array{T, 1}) where {T<:Real} = logpdf(m, hcat(x)) logpdf(m::MvNormal{T, 1}, x::Array{Maybe{T}, 1}) where {T<:Real} = logpdf(m, hcat(x)) +logpdf(m::MvNormal{T, 1}, x::Array{Missing, 1}) where {T<:Real} = zeros(T, 1, 1) #### # Functions for generating random samples From 318056ee5ca67f1107b2d049802555c988c92ff3 Mon Sep 17 00:00:00 2001 From: Martin Rektoris Date: Thu, 7 Sep 2023 21:46:47 +0200 Subject: [PATCH 15/40] Add straight support for NgramM in Geometric --- src/distributions/geometric.jl | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/distributions/geometric.jl b/src/distributions/geometric.jl index a204949..621dc6c 100644 --- a/src/distributions/geometric.jl +++ b/src/distributions/geometric.jl @@ -34,7 +34,7 @@ end Flux.@functor Geometric -Geometric(n::Int; dtype::Type{<:Real}=Float32) = Geometric(dtype(0.01)*randn(dtype, n)) +Geometric(n::Int; dtype::Type{<:Real}=Float32) = Geometric(dtype(0.1)*randn(dtype, n)) #### # Functions for calculating the likelihood @@ -78,6 +78,20 @@ function _logpdf_back(logitp::Vector{T}, x, Δy) where {T<:Real} Δlogitp, NoTangent() end +function _logpdf2_geo(logitp::Vector, x::NGramMatrix) + + linit = sum(logsigmoid, logitp) + l = 
fill(linit, 1, size(x, 2)) + + mlogp = logsigmoid.(-logitp) # unnecessary memory allocation, saves computing time + @inbounds for j in 1:size(x, 2) + for i in NGramIterator(x, j) + l[j] += mlogp[i+1] + end + end + l +end + function ChainRulesCore.rrule(::typeof(_logpdf_geometric), logitp::Vector{T}, x::SparseMatrixCSC) where {T<:Real} y = _logpdf_geometric(logitp, x) _logpdf_pullback = Δy -> (NoTangent(), _logpdf_back(logitp, x, Δy)...) From ee5fa94e7bf0baefba8052a61cc6fa27c82b37c9 Mon Sep 17 00:00:00 2001 From: Martin Rektoris Date: Thu, 7 Sep 2023 21:47:27 +0200 Subject: [PATCH 16/40] Update reflectinmodel to accept dtype --- src/reflectinmodel.jl | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/src/reflectinmodel.jl b/src/reflectinmodel.jl index 912b299..e19b7d2 100644 --- a/src/reflectinmodel.jl +++ b/src/reflectinmodel.jl @@ -36,10 +36,10 @@ function reflectinmodel( homo_ns::Int = 1, hete_nl::Int = 1, hete_ns::Int = 1, - dist_cont = d->gmm(2, d), - dist_disc = d->Categorical(d), - dist_gram = d->Geometric(d), - dist_card = ()->Poisson(), + dist_cont = (d, dtype)->gmm(2, d; dtype=dtype), + dist_disc = (d, dtype)->Categorical(d; dtype=dtype), + dist_gram = (d, dtype)->Geometric(d; dtype=dtype), + dist_card = (dtype)->Poisson(dtype), data_type::Type{<:Real} = Float32, seed::Int=1 ) @@ -48,7 +48,7 @@ function reflectinmodel( Random.seed!(seed) - root_ns > 1 ? SumNode(map(_->_reflectinmodel(x, settings), 1:root_ns)) : _reflectinmodel(x, settings) + root_ns > 1 ? 
SumNode(map(_->_reflectinmodel(x, settings), 1:root_ns); dtype=settings.data_type) : _reflectinmodel(x, settings) end function _reflectinmodel(x::Mill.ProductNode, settings::ModelSettings) @@ -60,25 +60,27 @@ function _reflectinmodel(x::Mill.ProductNode, settings::ModelSettings) end function _reflectinmodel(x::Mill.BagNode, settings::ModelSettings) if settings.homo_ns == 1 - SetNode(_reflectinmodel(x.data, settings), settings.dist_card()) + SetNode(_reflectinmodel(x.data, settings), settings.dist_card(settings.data_type)) else - SumNode(map(_->SetNode(_reflectinmodel(x.data, settings), settings.dist_card()), 1:settings.homo_ns)) + SumNode( + map(_->SetNode(_reflectinmodel(x.data, settings), settings.dist_card(settings.data_type)), 1:settings.homo_ns); + dtype=settings.data_type) end end _reflectinmodel(x::Mill.ArrayNode, settings::ModelSettings) = _reflectinmodel(x.data, settings) -_reflectinmodel(x::OneHotArray, settings) = settings.dist_disc(size(x, 1)) -_reflectinmodel(x::MaybeHotArray, settings) = settings.dist_disc(size(x, 1)) -_reflectinmodel(x::Array{T}, settings) where T <: Real = settings.dist_cont(size(x, 1)) -_reflectinmodel(x::Array{Maybe{T}}, settings) where T <: Real = settings.dist_cont(size(x, 1)) -_reflectinmodel(x::NGramMatrix{T}, settings) where T <: Sequence = settings.dist_gram(size(x, 1)) -_reflectinmodel(x::NGramMatrix{Maybe{T}}, settings) where T <: Sequence = settings.dist_gram(size(x, 1)) +_reflectinmodel(x::OneHotArray, settings) = settings.dist_disc(size(x, 1), settings.data_type) +_reflectinmodel(x::MaybeHotArray, settings) = settings.dist_disc(size(x, 1), settings.data_type) +_reflectinmodel(x::Array{T}, settings) where T <: Real = settings.dist_cont(size(x, 1), settings.data_type) +_reflectinmodel(x::Array{Maybe{T}}, settings) where T <: Real = settings.dist_cont(size(x, 1), settings.data_type) +_reflectinmodel(x::NGramMatrix{T}, settings) where T <: Sequence = settings.dist_gram(size(x, 1), settings.data_type) 
+_reflectinmodel(x::NGramMatrix{Maybe{T}}, settings) where T <: Sequence = settings.dist_gram(size(x, 1), settings.data_type) function _productmodel(x, n::Int, settings::ModelSettings) k = keys(x.data) c = map(_->ProductNode(mapreduce(k->_reflectinmodel(x.data[k], settings), vcat, k), reduce(vcat, k)), 1:n) - n == 1 ? first(c) : SumNode(c) + n == 1 ? first(c) : SumNode(c; dtype=settings.data_type) end function _productmodel(x, scope::Vector{Symbol}, l::Int, n::Int, settings::ModelSettings) where N d = length(scope) @@ -92,5 +94,5 @@ function _productmodel(x, scope::Vector{Symbol}, l::Int, n::Int, settings::Model comps_r = _productmodel(x[scope_r], scope_r, l-1, n, settings) ProductNode([comps_l, comps_r], [scope_l, scope_r]) end - n == 1 ? first(c) : SumNode(c) + n == 1 ? first(c) : SumNode(c; dtype=settings.data_type) end From fae6fd82d007329608a006acbbd08b99e4ee85d8 Mon Sep 17 00:00:00 2001 From: Martin Rektoris Date: Thu, 7 Sep 2023 21:48:04 +0200 Subject: [PATCH 17/40] Add dummy distribution --- src/distributions/dummy.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 src/distributions/dummy.jl diff --git a/src/distributions/dummy.jl b/src/distributions/dummy.jl new file mode 100644 index 0000000..b7797cd --- /dev/null +++ b/src/distributions/dummy.jl @@ -0,0 +1,10 @@ +struct Dummy{T<:Real} <: Distribution end + +Dummy(_; dtype::Type{<:Real}=Float32) = Dummy{dtype} +Dummy(dtype::Type{<:Real}=Float32) = Dummy{dtype} + +#### +# Functions for calculating the likelihood +#### + +logpdf(m::Dummy{T}, x::AbstractMatrix) where T = zeros(T, size(x, 2)) From 9ba24a37c1ab86368aefef4f351707c74f86cb67 Mon Sep 17 00:00:00 2001 From: Martin Rektoris Date: Thu, 7 Sep 2023 21:48:34 +0200 Subject: [PATCH 18/40] Update mutagenesis example in the light of recent changes --- examples/mutagenesis/Manifest.toml | 6 +- examples/mutagenesis/Project.toml | 5 - examples/mutagenesis/base_model.html | 284 ++++++++++---------------- examples/mutagenesis/base_model.ipynb 
| 263 ++++++++++-------------- 4 files changed, 223 insertions(+), 335 deletions(-) diff --git a/examples/mutagenesis/Manifest.toml b/examples/mutagenesis/Manifest.toml index 7a62036..d55f792 100644 --- a/examples/mutagenesis/Manifest.toml +++ b/examples/mutagenesis/Manifest.toml @@ -834,9 +834,9 @@ version = "1.8.0" [[deps.PoissonRandom]] deps = ["Random"] -git-tree-sha1 = "45f9da1ceee5078267eb273d065e8aa2f2515790" +git-tree-sha1 = "a0f1159c33f846aa77c3f30ebbc69795e5327152" uuid = "e409e4f3-bfea-5376-8464-e040bb5c01ab" -version = "0.4.3" +version = "0.4.4" [[deps.PooledArrays]] deps = ["DataAPI", "Future"] @@ -1037,7 +1037,7 @@ uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" version = "1.10.0" [[deps.SumProductSet]] -deps = ["Flux", "HierarchicalUtils", "LinearAlgebra", "Mill", "NNlib", "OneHotArrays", "PoissonRandom", "Random", "SparseArrays", "SpecialFunctions", "StatsBase"] +deps = ["ChainRulesCore", "Flux", "HierarchicalUtils", "LinearAlgebra", "Mill", "NNlib", "OneHotArrays", "PoissonRandom", "Random", "SparseArrays", "SpecialFunctions", "StatsBase"] path = "../.." uuid = "d0366596-3556-49ae-b3ef-851ab4ad1106" version = "0.0.0" diff --git a/examples/mutagenesis/Project.toml b/examples/mutagenesis/Project.toml index d0c048f..2cfe6fd 100644 --- a/examples/mutagenesis/Project.toml +++ b/examples/mutagenesis/Project.toml @@ -1,8 +1,3 @@ -name = "mutagenesis" -uuid = "72acebe5-962d-406a-88dd-574473246d5e" -authors = ["Martin Rektoris "] -version = "0.1.0" - [deps] Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" HierarchicalUtils = "f9ccea15-0695-44b9-8113-df7c26ae4fa9" diff --git a/examples/mutagenesis/base_model.html b/examples/mutagenesis/base_model.html index 2130b84..d24bf91 100644 --- a/examples/mutagenesis/base_model.html +++ b/examples/mutagenesis/base_model.html @@ -14611,68 +14611,6 @@ - -