Skip to content

Commit

Permalink
Merge pull request #4 from aicenter/dev_tpm_2023
Browse files Browse the repository at this point in the history
Fix minor bugs
  • Loading branch information
rektomar authored Aug 22, 2023
2 parents 6f5b5ee + 142e02a commit 27f5767
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 13 deletions.
4 changes: 2 additions & 2 deletions examples/base_spsn/set_classification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ end
function knownsetmixture(μs, Σs, λs, prior)
components = map(zip(μs, Σs, λs)) do ps
pc = Poisson(log(ps[3]))
pf = MvNormalParams(ps[1], ps[2])
pf = MvNormalParams(Float32.(ps[1]), Float32.(ps[2]))
SetNode(pf, pc)
end
SumNode(components, prior)
end

### Creating and sampling known model
m1 = knownsetmixture(μ, Σ, λ, [1., 1., 1.])
m1 = knownsetmixture(μ, Σ, λ, [1f0, 1, 1])
nbags = 100
bags, baglabels = randwithlabel(m1, nbags)

Expand Down
4 changes: 2 additions & 2 deletions examples/base_spsn/set_clustering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ end
function knownsetmixture(μs, Σs, λs, prior)
components = map(zip(μs, Σs, λs)) do ps
pc = Poisson(log(ps[3]))
pf = MvNormalParams(ps[1], ps[2])
pf = MvNormalParams(Float32.(ps[1]), Float32.(ps[2]))
SetNode(pf, pc)
end
SumNode(components, prior)
end
### Creating and sampling known model
m1 = knownsetmixture(μ, Σ, λ, [1., 1., 1.])
m1 = knownsetmixture(μ, Σ, λ, [1f0, 1, 1])
nbags = 300
bags, baglabels = randwithlabel(m1, nbags)

Expand Down
38 changes: 34 additions & 4 deletions examples/mutagenesis/base_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
"execution_count": 1,
"id": "bbb741d7",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39mPrecompiling SumProductSet [d0366596-3556-49ae-b3ef-851ab4ad1106]\n"
]
}
],
"source": [
"using DrWatson\n",
"@quickactivate \n",
Expand Down Expand Up @@ -354,7 +362,7 @@
"Epoch 98 - acc: | 0.890 0.773 | \n",
"Epoch 99 - acc: | 0.890 0.773 | \n",
"Epoch 100 - acc: | 0.890 0.773 | \n",
"108.356103 seconds (247.62 M allocations: 59.257 GiB, 10.02% gc time, 60.86% compilation time)\n"
"159.992875 seconds (247.45 M allocations: 59.246 GiB, 8.70% gc time, 72.44% compilation time)\n"
]
}
],
Expand All @@ -372,6 +380,28 @@
{
"cell_type": "code",
"execution_count": 10,
"id": "e18eaadc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1×100 Matrix{Float64}:\n",
" -189.62 -208.523 -248.18 -266.854 … -259.407 -218.131 -210.526"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"logpdf(m, ds_train)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "d5fa4303",
"metadata": {},
"outputs": [
Expand All @@ -381,7 +411,7 @@
"516"
]
},
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -392,7 +422,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"id": "f23baf1c",
"metadata": {},
"outputs": [
Expand Down
9 changes: 4 additions & 5 deletions src/SumProductSet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,17 @@ include("modelnodes/modelnode.jl")
include("distributions/distributions.jl")
include("modelbuilders.jl")
include("util.jl")
include("loss.jl")
include("reflectinmodel.jl")

export ZIPoisson, Poisson, Geometric, Categorical, MvNormal, MvNormalParams
export Poisson, Geometric, Categorical, MvNormal, MvNormalParams
export logpdf, logjnt
export SumNode, ProductNode, SetNode
export rand, randwithlabel
export setmixture, gmm, sharedsetmixture, spn
export gmm, setmixture, sharedsetmixture, spn

export reflectinmodel
export em_loss, ce_loss, ssl_loss

export VAE, Encoder, Decoder, SplitLayer, elbo, reconstruct_loss
export em_loss, ce_loss

Base.show(io::IO, ::MIME"text/plain", n::AbstractModelNode) = HierarchicalUtils.printtree(io, n)

Expand Down
20 changes: 20 additions & 0 deletions src/loss.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

"""
    em_loss(m::SumNode, xu)

Expectation-maximization objective (multiplied by minus one, so it can be
minimized) for the top `SumNode` layer on unlabeled data `xu`.
"""
function em_loss(m::SumNode, xu)
    # E-step: posterior responsibilities of the mixture components.
    # Zygote.ignore returns the block's value while blocking gradient flow,
    # so `p` is treated as a constant by AD (as EM requires). Returning the
    # value directly avoids the Any-typed `p = []` placeholder that would
    # otherwise be boxed by the closure.
    p = Flux.Zygote.ignore() do
        softmax(logjnt(m, xu); dims=1)
    end
    # Minus the M-step objective: expected complete-data log-likelihood.
    return -mean(p .* logjnt(m, xu))
end

"""
    ce_loss(m::SumNode, xl, yl::Vector{Int})

Cross-entropy loss (negative log-likelihood) for labeled data `xl` with
corresponding labels `yl`.
"""
function ce_loss(m::SumNode, xl, yl::Vector{Int})
    lj = logjnt(m, xl)
    # Select, for each sample, the joint log-likelihood of its true label,
    # then average and negate.
    return -mean(lj[CartesianIndex.(yl, eachindex(yl))])
end

0 comments on commit 27f5767

Please sign in to comment.