From b5e13af9d40c90077ceb8c6b28d8610a487704fc Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 2 Oct 2024 14:25:18 +0200 Subject: [PATCH 1/2] Fix the constructor of `DiscreteNonParametric` --- .../discrete/discretenonparametric.jl | 22 ++++++++++++------ src/utils.jl | 17 ++++++++++++++ test/univariate/discrete/categorical.jl | 17 ++++++++++++++ .../discrete/discretenonparametric.jl | 23 ++++++++++++++++++- 4 files changed, 71 insertions(+), 8 deletions(-) diff --git a/src/univariate/discrete/discretenonparametric.jl b/src/univariate/discrete/discretenonparametric.jl index 8e1eefab6e..8f242ec234 100644 --- a/src/univariate/discrete/discretenonparametric.jl +++ b/src/univariate/discrete/discretenonparametric.jl @@ -23,21 +23,29 @@ struct DiscreteNonParametric{T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractV function DiscreteNonParametric{T,P,Ts,Ps}(xs::Ts, ps::Ps; check_args::Bool=true) where { T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractVector{P}} - check_args || return new{T,P,Ts,Ps}(xs, ps) @check_args( DiscreteNonParametric, (length(xs) == length(ps), "length of support and probability vector must be equal"), (ps, isprobvec(ps), "vector is not a probability vector"), - (xs, allunique(xs), "support must contain only unique elements"), + (xs, issorted_allunique(xs), "support must be sorted and contain only unique elements"), ) - sort_order = sortperm(xs) - new{T,P,Ts,Ps}(xs[sort_order], ps[sort_order]) + new{T,P,Ts,Ps}(xs, ps) end end -DiscreteNonParametric(vs::AbstractVector{T}, ps::AbstractVector{P}; check_args::Bool=true) where { - T<:Real,P<:Real} = - DiscreteNonParametric{T,P,typeof(vs),typeof(ps)}(vs, ps; check_args=check_args) +function DiscreteNonParametric(xs::AbstractVector{T}, ps::AbstractVector{P}; check_args::Bool=true) where {T<:Real,P<:Real} + # We always sort the support unless it can be deduced from the type of the support that it is sorted. + # Sorting can be skipped for all inputs by using the inner constructor. + if xs isa AbstractUnitRange + sortedxs = xs + sortedps = ps + else + sort_order = sortperm(xs) + sortedxs = xs[sort_order] + sortedps = ps[sort_order] + end + return DiscreteNonParametric{T,P,typeof(sortedxs),typeof(sortedps)}(sortedxs, sortedps; check_args=check_args) +end Base.eltype(::Type{<:DiscreteNonParametric{T}}) where T = T diff --git a/src/utils.jl b/src/utils.jl index a2c9aaffa9..442b863d7e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -97,6 +97,23 @@ isunitvec(v::AbstractVector) = (norm(v) - 1.0) < 1.0e-12 isprobvec(p::AbstractVector{<:Real}) = all(x -> x ≥ zero(x), p) && isapprox(sum(p), one(eltype(p))) +issorted_allunique(xs::AbstractUnitRange{<:Real}) = true +function issorted_allunique(xs::AbstractVector{<:Real}) + xi_state = iterate(xs) + if xi_state === nothing + return true + end + xi, state = xi_state + while (xj_state = iterate(xs, state)) !== nothing + xj, state = xj_state + if xj <= xi + return false + end + xi = xj + end + return true +end + # get a type wide enough to represent all a distributions's parameters # (if the distribution is parametric) # if the distribution is not parametric, we need this to be a float so that diff --git a/test/univariate/discrete/categorical.jl b/test/univariate/discrete/categorical.jl index 6d87d4dc86..96425a204f 100644 --- a/test/univariate/discrete/categorical.jl +++ b/test/univariate/discrete/categorical.jl @@ -137,4 +137,21 @@ end @test count(==(1e8), priorities[iat]) >= 13 end +@testset "AbstractVector" begin + # issue #1084 + P = abs.(randn(5,4,2)) + p = view(P,:,1,1) + p ./= sum(p) + d = @inferred(Categorical(p)) + @test d isa Categorical{Float64, typeof(p)} + @test d.p === p + + # #1832 + x = rand(3,5) + x ./= sum(x; dims=1) + c = Categorical.(eachcol(x)) + @test c isa Vector{<:Categorical} + @test all(ci.p isa SubArray for ci in c) +end + end diff --git a/test/univariate/discrete/discretenonparametric.jl b/test/univariate/discrete/discretenonparametric.jl index 68354a064a..b81d5b52a9 100644 --- a/test/univariate/discrete/discretenonparametric.jl +++ b/test/univariate/discrete/discretenonparametric.jl @@ -213,4 +213,25 @@ end # Different types @test DiscreteNonParametric(1:2, [0.5, 0.5]) == DiscreteNonParametric([1, 2], [0.5f0, 0.5f0]) @test DiscreteNonParametric(1:2, [0.5, 0.5]) ≈ DiscreteNonParametric([1, 2], [0.5f0, 0.5f0]) -end \ No newline at end of file +end + +@testset "AbstractVector (issue #1084)" begin + P = abs.(randn(5,4,2)) + p = view(P,:,1,1) + p ./= sum(p) + + d = @inferred(DiscreteNonParametric(Base.OneTo(5), p)) + @test d isa DiscreteNonParametric + @test d.p === p + d = @inferred(DiscreteNonParametric(1:5, p)) + @test d isa DiscreteNonParametric + @test d.p === p + d = @inferred(DiscreteNonParametric(1:1:5, p)) + @test d isa DiscreteNonParametric + @test d.p !== p + @test d.p == p + d = @inferred(DiscreteNonParametric([1, 2, 3, 4, 5], p)) + @test d isa DiscreteNonParametric + @test d.p !== p + @test d.p == p +end From b7283e732500f6ffa86bc61abe3fe4fe1d2c0596 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 2 Oct 2024 16:22:19 +0200 Subject: [PATCH 2/2] Improve errors and test them --- .../discrete/discretenonparametric.jl | 29 +++++++++++++++---- .../discrete/discretenonparametric.jl | 22 +++++++++++++- 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/src/univariate/discrete/discretenonparametric.jl b/src/univariate/discrete/discretenonparametric.jl index 8f242ec234..cc339fc075 100644 --- a/src/univariate/discrete/discretenonparametric.jl +++ b/src/univariate/discrete/discretenonparametric.jl @@ -23,17 +23,27 @@ struct DiscreteNonParametric{T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractV function DiscreteNonParametric{T,P,Ts,Ps}(xs::Ts, ps::Ps; check_args::Bool=true) where { T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractVector{P}} + let xs = xs, ps = ps + @check_args( + DiscreteNonParametric, + (length(xs) == length(ps), "length of support and probability vector must be equal"), + (ps, isprobvec(ps), "vector is not a probability vector"), + (xs, issorted_allunique(xs), "support must be sorted and contain only unique elements"), + ) + end + new{T,P,Ts,Ps}(xs, ps) + end +end + +function DiscreteNonParametric(xs::AbstractVector{T}, ps::AbstractVector{P}; check_args::Bool=true) where {T<:Real,P<:Real} + # These checks are performed before sorting the support since we do not want to throw a `BoundsError` when the lengths do not match + let xs = xs, ps = ps @check_args( DiscreteNonParametric, (length(xs) == length(ps), "length of support and probability vector must be equal"), (ps, isprobvec(ps), "vector is not a probability vector"), - (xs, issorted_allunique(xs), "support must be sorted and contain only unique elements"), ) - new{T,P,Ts,Ps}(xs, ps) end -end - -function DiscreteNonParametric(xs::AbstractVector{T}, ps::AbstractVector{P}; check_args::Bool=true) where {T<:Real,P<:Real} # We always sort the support unless it can be deduced from the type of the support that it is sorted. # Sorting can be skipped for all inputs by using the inner constructor. if xs isa AbstractUnitRange @@ -43,8 +53,15 @@ function DiscreteNonParametric(xs::AbstractVector{T}, ps::AbstractVector{P}; che sort_order = sortperm(xs) sortedxs = xs[sort_order] sortedps = ps[sort_order] + # It is more efficient to perform this check once the array is sorted + let sortedxs = sortedxs + @check_args( + DiscreteNonParametric, + (sortedxs, issorted_allunique(sortedxs), "support must contain only unique elements"), + ) + end end - return DiscreteNonParametric{T,P,typeof(sortedxs),typeof(sortedps)}(sortedxs, sortedps; check_args=check_args) + return DiscreteNonParametric{T,P,typeof(sortedxs),typeof(sortedps)}(sortedxs, sortedps; check_args=false) end Base.eltype(::Type{<:DiscreteNonParametric{T}}) where T = T diff --git a/test/univariate/discrete/discretenonparametric.jl b/test/univariate/discrete/discretenonparametric.jl index b81d5b52a9..128207d74b 100644 --- a/test/univariate/discrete/discretenonparametric.jl +++ b/test/univariate/discrete/discretenonparametric.jl @@ -14,7 +14,8 @@ rng = MersenneTwister(123) d = DiscreteNonParametric([40., 80., 120., -60.], [.4, .3, .1, .2]) -@test !(d ≈ DiscreteNonParametric([40., 80, 120, -60], [.4, .3, .1, .2], check_args=false)) +# In the outer constructor, the support is always sorted, regardless of whether `check_args = false` or `check_args = true` +@test d ≈ DiscreteNonParametric([40., 80., 120., -60.], [.4, .3, .1, .2], check_args=false) @test d ≈ DiscreteNonParametric([-60., 40., 80, 120], [.2, .4, .3, .1], check_args=false) # Invalid probability @@ -23,6 +24,25 @@ d = DiscreteNonParametric([40., 80., 120., -60.], [.4, .3, .1, .2]) # Invalid probability, but no arg check DiscreteNonParametric([40., 80, 120, -60], [.5, .3, .1, .2], check_args=false) +# Invalid support +@test_throws DomainError DiscreteNonParametric{Float64,Float64,Vector{Float64},Vector{Float64}}([40., 80, 120, -60], [.4, .3, .1, .2]) +@test_throws DomainError DiscreteNonParametric([-60., 40., 40., 120], [.2, .4, .3, .1]) +@test_throws DomainError DiscreteNonParametric{Float64,Float64,Vector{Float64},Vector{Float64}}([-60., 40., 40., 120], [.2, .4, .3, .1]) + +# Invalid support but no arg check +DiscreteNonParametric{Float64,Float64,Vector{Float64},Vector{Float64}}([40., 80, 120, -60], [.4, .3, .1, .2], check_args=false) +DiscreteNonParametric([-60., 40., 40., 120], [.2, .4, .3, .1], check_args=false) +DiscreteNonParametric{Float64,Float64,Vector{Float64},Vector{Float64}}([-60., 40., 40., 120], [.2, .4, .3, .1], check_args=false) + +# Mismatch between support and probabilities +@test_throws ArgumentError DiscreteNonParametric([-60., 40., 40., 120], [.2, .4, .3]) +@test_throws ArgumentError DiscreteNonParametric{Float64,Float64,Vector{Float64},Vector{Float64}}([-60., 40., 40., 120], [.2, .4, .3]) + +# Mismatch between support and probabilities but no arg check +@test_throws BoundsError DiscreteNonParametric([-60., 40., 40., 120], [.2, .4, .3], check_args=false) # sorting errors +DiscreteNonParametric(1:4, [.2, .4, .3], check_args=false) # no sorting, hence no `BoundsError` +DiscreteNonParametric{Float64,Float64,Vector{Float64},Vector{Float64}}([-60., 40., 40., 120], [.2, .4, .3], check_args=false) + test_range(d) vs = Distributions.get_evalsamples(d, 0.00001) test_evaluation(d, vs, true)