diff --git a/examples/base_spsn/set_classification.jl b/examples/base_spsn/set_classification.jl
index 18a5ae2..652b772 100644
--- a/examples/base_spsn/set_classification.jl
+++ b/examples/base_spsn/set_classification.jl
@@ -29,14 +29,14 @@ end
 function knownsetmixture(μs, Σs, λs, prior)
     components = map(zip(μs, Σs, λs)) do ps
         pc = Poisson(log(ps[3]))
-        pf = MvNormalParams(ps[1], ps[2])
+        pf = MvNormalParams(Float32.(ps[1]), Float32.(ps[2]))
         SetNode(pf, pc)
     end
     SumNode(components, prior)
 end
 
 ### Creating and sampling known model
-m1 = knownsetmixture(μ, Σ, λ, [1., 1., 1.])
+m1 = knownsetmixture(μ, Σ, λ, [1f0, 1, 1])
 nbags = 100
 bags, baglabels = randwithlabel(m1, nbags)
 
diff --git a/examples/base_spsn/set_clustering.jl b/examples/base_spsn/set_clustering.jl
index 24f35e5..7d10d61 100644
--- a/examples/base_spsn/set_clustering.jl
+++ b/examples/base_spsn/set_clustering.jl
@@ -31,13 +31,13 @@ end
 function knownsetmixture(μs, Σs, λs, prior)
     components = map(zip(μs, Σs, λs)) do ps
         pc = Poisson(log(ps[3]))
-        pf = MvNormalParams(ps[1], ps[2])
+        pf = MvNormalParams(Float32.(ps[1]), Float32.(ps[2]))
         SetNode(pf, pc)
     end
     SumNode(components, prior)
 end
 
 ### Creating and sampling known model
-m1 = knownsetmixture(μ, Σ, λ, [1., 1., 1.])
+m1 = knownsetmixture(μ, Σ, λ, [1f0, 1, 1])
 nbags = 300
 bags, baglabels = randwithlabel(m1, nbags)
diff --git a/examples/mutagenesis/base_model.ipynb b/examples/mutagenesis/base_model.ipynb
index a396823..d44ed5e 100644
--- a/examples/mutagenesis/base_model.ipynb
+++ b/examples/mutagenesis/base_model.ipynb
@@ -5,7 +5,15 @@
    "execution_count": 1,
    "id": "bbb741d7",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39mPrecompiling SumProductSet [d0366596-3556-49ae-b3ef-851ab4ad1106]\n"
+     ]
+    }
+   ],
    "source": [
     "using DrWatson\n",
     "@quickactivate \n",
@@ -354,7 +362,7 @@
     "Epoch 98 - acc: | 0.890 0.773 | \n",
     "Epoch 99 - acc: | 0.890 0.773 | \n",
     "Epoch 100 - acc: | 0.890 0.773 | \n",
-    "108.356103 seconds (247.62 M allocations: 59.257 GiB, 10.02% gc time, 60.86% compilation time)\n"
+    "159.992875 seconds (247.45 M allocations: 59.246 GiB, 8.70% gc time, 72.44% compilation time)\n"
    ]
   }
  ],
@@ -372,6 +380,28 @@
   {
    "cell_type": "code",
    "execution_count": 10,
+   "id": "e18eaadc",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "1×100 Matrix{Float64}:\n",
+       " -189.62 -208.523 -248.18 -266.854 … -259.407 -218.131 -210.526"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "logpdf(m, ds_train)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
    "id": "d5fa4303",
    "metadata": {},
    "outputs": [
    {
     "data": {
      "text/plain": [
       "516"
      ]
     },
-    "execution_count": 10,
+    "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 11,
+  "execution_count": 12,
   "id": "f23baf1c",
   "metadata": {},
   "outputs": [
diff --git a/src/SumProductSet.jl b/src/SumProductSet.jl
index 6f94583..3c6c7d1 100644
--- a/src/SumProductSet.jl
+++ b/src/SumProductSet.jl
@@ -21,18 +21,17 @@
 include("modelnodes/modelnode.jl")
 include("distributions/distributions.jl")
 include("modelbuilders.jl")
 include("util.jl")
+include("loss.jl")
 include("reflectinmodel.jl")
 
-export ZIPoisson, Poisson, Geometric, Categorical, MvNormal, MvNormalParams
+export Poisson, Geometric, Categorical, MvNormal, MvNormalParams
 export logpdf, logjnt
 export SumNode, ProductNode, SetNode
 export rand, randwithlabel
-export setmixture, gmm, sharedsetmixture, spn
+export gmm, setmixture, sharedsetmixture, spn
 export reflectinmodel
-export em_loss, ce_loss, ssl_loss
-
-export VAE, Encoder, Decoder, SplitLayer, elbo, reconstruct_loss
+export em_loss, ce_loss
 
 Base.show(io::IO, ::MIME"text/plain", n::AbstractModelNode) = HierarchicalUtils.printtree(io, n)
 
diff --git a/src/loss.jl b/src/loss.jl
new file mode 100644
index 0000000..99c6d1e
--- /dev/null
+++ b/src/loss.jl
@@ -0,0 +1,20 @@
+
+"""
+    em_loss(m::SumNode, xu)
+Negated expectation-maximization objective of the top `SumNode` layer for unlabeled data `xu`.
+"""
+function em_loss(m::SumNode, xu)
+    # E-step for unlabeled data: responsibilities are computed outside the gradient
+    p = []
+    Flux.Zygote.ignore() do
+        p = softmax(logjnt(m, xu); dims=1)
+    end
+    # negated M-step objective
+    -mean(p .* logjnt(m, xu))
+end
+
+"""
+    ce_loss(m::SumNode, xl, yl::Vector{Int})
+Cross-entropy (negative log-likelihood) loss for labeled data `xl` with corresponding labels `yl`.
+"""
+ce_loss(m::SumNode, xl, yl::Vector{Int}) = -mean(logjnt(m, xl)[CartesianIndex.(yl, 1:length(yl))])
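
Usage note: a minimal semi-supervised training sketch built on the two newly exported losses. It assumes a `SumNode` mixture `m` (e.g. from `setmixture`), unlabeled bags `xu`, and labeled bags `xl` with integer labels `yl`; the combined objective `ssl_objective` and the weight `α` are hypothetical names introduced here for illustration, not part of the patch.

# Hypothetical combined objective: EM term on unlabeled data plus
# cross-entropy term on labeled data, weighted by an illustrative α.
using Flux

ssl_objective(m, xu, xl, yl; α = 0.5f0) =
    α * em_loss(m, xu) + (1 - α) * ce_loss(m, xl, yl)

# Implicit-parameter training loop, matching the Flux/Zygote style the
# package already uses (`Flux.Zygote.ignore` inside `em_loss`).
ps = Flux.params(m)
opt = ADAM(0.01)
for _ in 1:100
    gs = Flux.gradient(() -> ssl_objective(m, xu, xl, yl), ps)
    Flux.Optimise.update!(opt, ps, gs)
end

Because the E-step responsibilities in `em_loss` are wrapped in `Flux.Zygote.ignore`, they act as constants during differentiation, so each gradient step optimizes the M-step surrogate rather than backpropagating through the posterior itself.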