From 036a24d88cc0fb1033680055b69b23e2784f3b55 Mon Sep 17 00:00:00 2001
From: Alex Arslan
Date: Tue, 25 Jun 2024 19:30:17 -0700
Subject: [PATCH 1/3] Add mean et al. for truncated log normal

Fixes 709
---
 src/truncate.jl             |  1 +
 src/truncated/lognormal.jl  | 50 +++++++++++++++++++++++++++++++++++++
 src/truncated/normal.jl     | 30 ++++++++++++++++++++++
 test/runtests.jl            |  1 +
 test/testutils.jl           | 11 ++++++++
 test/truncated/lognormal.jl | 36 ++++++++++++++++++++++++++
 test/truncated/normal.jl    | 34 +++++++++++++++++++++++++
 7 files changed, 163 insertions(+)
 create mode 100644 src/truncated/lognormal.jl
 create mode 100644 test/truncated/lognormal.jl

diff --git a/src/truncate.jl b/src/truncate.jl
index 48d62b015..ca04d5d7a 100644
--- a/src/truncate.jl
+++ b/src/truncate.jl
@@ -261,6 +261,7 @@ include(joinpath("truncated", "exponential.jl"))
 include(joinpath("truncated", "uniform.jl"))
 include(joinpath("truncated", "loguniform.jl"))
 include(joinpath("truncated", "discrete_uniform.jl"))
+include(joinpath("truncated", "lognormal.jl"))
 
 #### Utilities
 
diff --git a/src/truncated/lognormal.jl b/src/truncated/lognormal.jl
new file mode 100644
index 000000000..03d15930c
--- /dev/null
+++ b/src/truncated/lognormal.jl
@@ -0,0 +1,50 @@
+# Moments of the truncated log-normal can be computed directly from the moment generating
+# function of the truncated normal:
+#   Let Y ~ LogNormal(μ, σ) truncated to (a, b). Then log(Y) ~ Normal(μ, σ) truncated
+#   to (log(a), log(b)), and E[Y^n] = E[(e^log(Y))^n] = E[e^(nlog(Y))] = mgf(log(Y), n).
+
+# Given `truncate(LogNormal(μ, σ), a, b)`, return `truncate(Normal(μ, σ), log(a), log(b))`
+function _truncnorm(d::Truncated{<:LogNormal})
+    μ, σ = params(d.untruncated)
+    T = partype(d)
+    a = d.lower === nothing ? nothing : log(T(d.lower))
+    b = d.upper === nothing ? nothing : log(T(d.upper))
+    return truncated(Normal(μ, σ), a, b)
+end
+
+mean(d::Truncated{<:LogNormal}) = mgf(_truncnorm(d), 1)
+
+function var(d::Truncated{<:LogNormal})
+    tn = _truncnorm(d)
+    # Ensure the variance doesn't end up negative, which can occur due to numerical issues
+    return max(mgf(tn, 2) - mgf(tn, 1)^2, 0)
+end
+
+function skewness(d::Truncated{<:LogNormal})
+    tn = _truncnorm(d)
+    m1 = mgf(tn, 1)
+    m2 = mgf(tn, 2)
+    m3 = mgf(tn, 3)
+    sqm1 = m1^2
+    v = m2 - sqm1
+    return (m3 + m1 * (-3 * m2 + 2 * sqm1)) / (v * sqrt(v))
+end
+
+function kurtosis(d::Truncated{<:LogNormal})
+    tn = _truncnorm(d)
+    m1 = mgf(tn, 1)
+    m2 = mgf(tn, 2)
+    m3 = mgf(tn, 3)
+    m4 = mgf(tn, 4)
+    v = m2 - m1^2
+    return @horner(m1, m4, -4m3, 6m2, 0, -3) / v^2 - 3
+end
+
+# TODO: The entropy can be written "directly" as well, according to Mathematica, but
+# the expression for it fills me with regret. There are some recognizable components,
+# so a sufficiently motivated person could try to manually simplify it into something
+# comprehensible. For reference, you can obtain the entropy with Mathematica like so:
+#
+#   d = TruncatedDistribution[{a, b}, LogNormalDistribution[m, s]];
+#   Expectation[-LogLikelihood[d, {x}], Distributed[x, d],
+#     Assumptions -> Element[x | m | s | a | b, Reals] && s > 0 && 0 < a < x < b]
diff --git a/src/truncated/normal.jl b/src/truncated/normal.jl
index 6fb334273..3db774206 100644
--- a/src/truncated/normal.jl
+++ b/src/truncated/normal.jl
@@ -118,6 +118,36 @@ function entropy(d::Truncated{<:Normal{<:Real},Continuous})
     0.5 * (log2π + 1.) + log(σ * z) + (aφa - bφb) / (2.0 * z)
 end
 
+function mgf(d::Truncated{<:Normal{<:Real},Continuous}, t::Real)
+    T = float(promote_type(partype(d), typeof(t)))
+    a = T(minimum(d))
+    b = T(maximum(d))
+    if isnan(a) || isnan(b)  # TODO: Disallow constructing `Truncated` with a `NaN` bound?
+        return T(NaN)
+    elseif isinf(a) && isinf(b) && a != b
+        # Distribution is `Truncated`-wrapped but not actually truncated
+        return T(mgf(d.untruncated, t))
+    elseif a == b
+        # Truncated to a Dirac distribution; this is `mgf(Dirac(a), t)`
+        return exp(a * t)
+    end
+    d0 = d.untruncated
+    μ = mean(d0)
+    σ = std(d0)
+    σt = σ * t  # standardized bounds shift by σt (not σ²t): (x - (μ + σ²t)) / σ = x′ - σt
+    a′ = (a - μ) / σ
+    b′ = (b - μ) / σ
+    stdnorm = Normal{T}(zero(T), one(T))
+    # log((Φ(b′ - σt) - Φ(a′ - σt)) / (Φ(b′) - Φ(a′)))
+    logratio = if isfinite(a) && isfinite(b)  # doubly truncated
+        logdiffcdf(stdnorm, b′ - σt, a′ - σt) - logdiffcdf(stdnorm, b′, a′)
+    elseif isfinite(a)  # left truncated: b = ∞, Φ(b′) = Φ(b′ - σt) = 1
+        logccdf(stdnorm, a′ - σt) - logccdf(stdnorm, a′)
+    else  # isfinite(b), right truncated: a = -∞, Φ(a′) = Φ(a′ - σt) = 0
+        logcdf(stdnorm, b′ - σt) - logcdf(stdnorm, b′)
+    end
+    return exp(t * (μ + σ^2 * t / 2) + logratio)
+end
 
 ### sampling
 
diff --git a/test/runtests.jl b/test/runtests.jl
index 583132c53..851ff480f 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -20,6 +20,7 @@ const tests = [
     "truncated/exponential",
     "truncated/uniform",
     "truncated/discrete_uniform",
+    "truncated/lognormal",
     "censored",
     "univariate/continuous/normal",
     "univariate/continuous/laplace",
diff --git a/test/testutils.jl b/test/testutils.jl
index 2859856a7..6eb7ab5ad 100644
--- a/test/testutils.jl
+++ b/test/testutils.jl
@@ -18,6 +18,17 @@ function _linspace(a::Float64, b::Float64, n::Int)
     return r
 end
 
+# Enables testing against values computed at high precision by transforming an expression
+# that uses numeric literals and constants to wrap those in `big()`, similar to how the
+# high-precision values for irrational constants are defined with `Base.@irrational` and
+# in IrrationalConstants.jl. See e.g. `test/truncated/normal.jl` for example use.
+bigly(x) = x
+bigly(x::Symbol) = x in (:π, :ℯ, :Inf, :NaN) ? Expr(:call, :big, x) : x
+bigly(x::Real) = Expr(:call, :big, x)
+bigly(x::Expr) = (map!(bigly, x.args, x.args); x)
+macro bigly(ex)
+    return esc(bigly(ex))
+end
 
 #################################################
 #
diff --git a/test/truncated/lognormal.jl b/test/truncated/lognormal.jl
new file mode 100644
index 000000000..e5bf2b2a2
--- /dev/null
+++ b/test/truncated/lognormal.jl
@@ -0,0 +1,36 @@
+using Distributions, Test
+using Distributions: expectation
+
+naive_moment(d, n, μ, σ²) = (σ = sqrt(σ²); expectation(x -> ((x - μ) / σ)^n, d))
+
+@testset "Truncated log normal" begin
+    @testset "truncated(LogNormal{$T}(0, 1), ℯ⁻², ℯ²)" for T in (Float32, Float64, BigFloat)
+        d = truncated(LogNormal{T}(zero(T), one(T)), exp(T(-2)), exp(T(2)))
+        tn = truncated(Normal{BigFloat}(big(0.0), big(1.0)), -2, 2)
+        bigmean = mgf(tn, 1)
+        bigvar = mgf(tn, 2) - bigmean^2
+        @test @inferred(mean(d)) ≈ bigmean
+        @test @inferred(var(d)) ≈ bigvar
+        @test @inferred(median(d)) ≈ one(T)
+        @test @inferred(skewness(d)) ≈ naive_moment(d, 3, bigmean, bigvar)
+        @test @inferred(kurtosis(d)) ≈ naive_moment(d, 4, bigmean, bigvar) - big(3)
+        @test mean(d) isa T
+    end
+    @testset "Bound with no effect" begin
+        # Uses the example distribution from issue #709, though what's tested here is
+        # mostly unrelated to that issue (aside from `mean` not erroring).
+        # The specified left truncation at 0 has no effect for `LogNormal`
+        d1 = truncated(LogNormal(1, 5), 0, 1e5)
+        @test 0 < mean(d1) < 1e5
+        v1 = var(d1)
+        @test v1 < 1e10
+        # `var` clamps its result with `max(_, 0)`: cancellation in `E[Y²] - E[Y]²` could
+        # otherwise produce a numerically negative value, which could cause downstream
+        # issues that assume a nonnegative variance
+        @test v1 >= 0
+        # Compare results with not specifying a lower bound at all
+        d2 = truncated(LogNormal(1, 5); upper=1e5)
+        @test mean(d1) == mean(d2)
+        @test var(d1) == var(d2)
+    end
+end
diff --git a/test/truncated/normal.jl b/test/truncated/normal.jl
index 9d287c026..43f853076 100644
--- a/test/truncated/normal.jl
+++ b/test/truncated/normal.jl
@@ -69,3 +69,37 @@ end
         @test isfinite(pdf(trunc, x))
     end
 end
+
+@testset "Truncated normal MGF" begin
+    two = big(2)
+    sqrt2 = sqrt(two)
+    invsqrt2 = inv(sqrt2)
+    inv2sqrt2 = inv(two * sqrt2)
+    twoerfsqrt2 = two * erf(sqrt2)
+
+    for T in (Float32, Float64, BigFloat)
+        d = truncated(Normal{T}(zero(T), one(T)), -2, 2)
+        @test @inferred(mgf(d, 0)) == 1
+        @test @inferred(mgf(d, 1)) ≈ @bigly sqrt(ℯ) * (erf(invsqrt2) + erf(3 * invsqrt2)) / twoerfsqrt2
+        @test @inferred(mgf(d, 2.5)) ≈ @bigly exp(25//8) * (erf(9 * inv2sqrt2) - erf(inv2sqrt2)) / twoerfsqrt2
+    end
+
+    d = truncated(Normal(3, 10), 7, 8)
+    @test mgf(d, 0) == 1
+    @test exp(7) < mgf(d, 1) < exp(8)  # support is [7, 8], so E[e^X] must lie in [e⁷, e⁸]
+
+    d = truncated(Normal(27, 3); lower=0)
+    @test mgf(d, 0) == 1
+    @test mgf(d, 1) ≈ @bigly 2 * exp(63//2) / (1 + erf(9 * invsqrt2))
+    @test mgf(d, 2.5) ≈ @bigly 2 * exp(765//8) / (1 + erf(9 * invsqrt2))
+
+    d = truncated(Normal(-5, 1); upper=-10)
+    @test mgf(d, 0) == 1
+    @test mgf(d, 1) ≈ @bigly erfc(3 * sqrt2) / (exp(9//2) * erfc(5 * invsqrt2))
+
+    @test isnan(mgf(truncated(Normal(); upper=NaN), 0))
+
+    @test mgf(truncated(Normal(), -Inf, Inf), 1) == mgf(Normal(), 1)
+
+    @test mgf(truncated(Normal(), 2, 2), 1) == exp(2)
+end

From c73a6e1ce28daaf7ba93592038b6123f21c11cf5 Mon Sep 17 00:00:00 2001
From: Alex Arslan
Date: Fri, 28 Jun 2024 17:34:02 -0700
Subject: [PATCH 2/3] Handle truncation outside support

---
 src/truncated/lognormal.jl  | 4 ++--
 test/truncated/lognormal.jl | 4 ++++
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/truncated/lognormal.jl b/src/truncated/lognormal.jl
index 03d15930c..1f6b45382 100644
--- a/src/truncated/lognormal.jl
+++ b/src/truncated/lognormal.jl
@@ -7,8 +7,8 @@
 function _truncnorm(d::Truncated{<:LogNormal})
     μ, σ = params(d.untruncated)
     T = partype(d)
-    a = d.lower === nothing ? nothing : log(T(d.lower))
-    b = d.upper === nothing ? nothing : log(T(d.upper))
+    a = d.lower === nothing || d.lower <= 0 ? nothing : log(T(d.lower))
+    b = d.upper === nothing || isinf(d.upper) ? nothing : log(T(d.upper))
     return truncated(Normal(μ, σ), a, b)
 end
 
diff --git a/test/truncated/lognormal.jl b/test/truncated/lognormal.jl
index e5bf2b2a2..ee9163e2b 100644
--- a/test/truncated/lognormal.jl
+++ b/test/truncated/lognormal.jl
@@ -32,5 +32,9 @@ naive_moment(d, n, μ, σ²) = (σ = sqrt(σ²); expectation(x -> ((x - μ) / σ
         d2 = truncated(LogNormal(1, 5); upper=1e5)
         @test mean(d1) == mean(d2)
         @test var(d1) == var(d2)
+
+        # Truncated outside of support where taking a log would error
+        d3 = truncated(LogNormal(); lower=-1)
+        @test mean(d3) == mean(d3.untruncated)
     end
 end

From 70c78100d0472843574b8eba85ee0d874403f524 Mon Sep 17 00:00:00 2001
From: Alex Arslan
Date: Fri, 28 Jun 2024 18:28:59 -0700
Subject: [PATCH 3/3] Blind guess at fixing type inference on 1.3

---
 src/truncated/lognormal.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/truncated/lognormal.jl b/src/truncated/lognormal.jl
index 1f6b45382..69c5297bf 100644
--- a/src/truncated/lognormal.jl
+++ b/src/truncated/lognormal.jl
@@ -6,10 +6,10 @@
 # Given `truncate(LogNormal(μ, σ), a, b)`, return `truncate(Normal(μ, σ), log(a), log(b))`
 function _truncnorm(d::Truncated{<:LogNormal})
     μ, σ = params(d.untruncated)
-    T = partype(d)
+    T = float(partype(d))
     a = d.lower === nothing || d.lower <= 0 ? nothing : log(T(d.lower))
     b = d.upper === nothing || isinf(d.upper) ? nothing : log(T(d.upper))
-    return truncated(Normal(μ, σ), a, b)
+    return truncated(Normal{T}(T(μ), T(σ)), a, b)
 end
 
 mean(d::Truncated{<:LogNormal}) = mgf(_truncnorm(d), 1)