Segfault when calling rand with Bijectors #2074

Open
Red-Portal opened this issue Nov 8, 2024 · 4 comments

Comments

@Red-Portal

Red-Portal commented Nov 8, 2024

Hi!

The following code segfaults on Julia 1.10 (1.10.6 per the stack trace below):

using Bijectors
using LinearAlgebra
using Functors
using Optimisers
using Enzyme
using Random, StableRNGs

struct TestProb1 end

logdensity(::TestProb1, θ) = sum(abs2, θ)

function Bijectors.bijector(::TestProb1)
    return Bijectors.Stacked(
        [Base.Fix1(broadcast, log), identity],
        [1:1, 2:3],
    )
end

struct TestProb2 end

logdensity(::TestProb2, θ) = sum(abs2, θ)

struct MvLocationScale{S,D<:ContinuousDistribution,L,E<:Real} <:
       ContinuousMultivariateDistribution
    location::L
    scale::S
    dist::D
    scale_eps::E
end

Base.length(q::MvLocationScale) = length(q.location)

Functors.@functor MvLocationScale (location, scale)

# This specialization improves AD performance of the sampling path
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{<:Diagonal,L}, num_samples::Int
) where {L}
    (; location, scale) = q
    n_dims = length(location)
    scale_diag = diag(scale)
    return scale_diag .* randn(rng, n_dims, num_samples) .+ location
end

function restructure_ad_forward(restructure, params)
    return restructure(params)::typeof(restructure.model)
end

function estimate_repgradelbo_ad_forward(params′, aux)
    (; rng, problem, restructure) = aux
    q = restructure_ad_forward(restructure, params′)
    zs = rand(rng, q, 10)
    return mean(Base.Fix1(logdensity, problem), eachcol(zs))
end

function main()
    d = 5

    seed = (0x38bef07cf9cc549d)
    rng = StableRNG(seed)

    for prob in [TestProb1(), TestProb2()]
        q = if prob isa TestProb1
            MvLocationScale(zeros(d), Diagonal(ones(d)), Normal(), 1e-5)
        else
            Bijectors.TransformedDistribution(
                MvLocationScale(zeros(d), Diagonal(ones(d)), Normal(), 1e-5),
                inverse(
                    Bijectors.Stacked(
                        [Base.Fix1(broadcast, log), identity],
                        [1:1, 2:d],
                    )
                )
            )
        end

        params, re = Optimisers.destructure(q)
        buf = zero(params)
        aux = (rng=rng, problem=prob, restructure=re)
        Enzyme.autodiff(
            set_runtime_activity(Enzyme.ReverseWithPrimal, true),
            estimate_repgradelbo_ad_forward,
            Enzyme.Active,
            Enzyme.Duplicated(params, buf),
            Enzyme.Const(aux),
        )
    end
end

This bug is very sensitive: seemingly minor changes (like swapping the order of TestProb1 and TestProb2) immediately make it go away. As a result it was hard to isolate, but the snippet above seems to reproduce it. Below is the segfault error message.

[9153] signal (11.128): Segmentation fault
in expression starting at REPL[2]:1
runtime_generic_augfwd at /home/krkim/.julia/packages/Enzyme/RvNgp/src/rules/jitrules.jl:486
unknown function (ip: 0x7f0e60997750)
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
estimate_repgradelbo_ad_forward at /home/krkim/.julia/dev/AdvancedVI/test/scratch2.jl:64 [inlined]
estimate_repgradelbo_ad_forward at /home/krkim/.julia/dev/AdvancedVI/test/scratch2.jl:0 [inlined]
augmented_julia_estimate_repgradelbo_ad_forward_7697_inner_1wrap at /home/krkim/.julia/dev/AdvancedVI/test/scratch2.jl:0
macro expansion at /home/krkim/.julia/packages/Enzyme/RvNgp/src/compiler.jl:8305 [inlined]
enzyme_call at /home/krkim/.julia/packages/Enzyme/RvNgp/src/compiler.jl:7868 [inlined]
AugmentedForwardThunk at /home/krkim/.julia/packages/Enzyme/RvNgp/src/compiler.jl:7705 [inlined]
autodiff at /home/krkim/.julia/packages/Enzyme/RvNgp/src/Enzyme.jl:384
unknown function (ip: 0x7f0e60db7dd1)
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
autodiff at /home/krkim/.julia/packages/Enzyme/RvNgp/src/Enzyme.jl:512 [inlined]
main at /home/krkim/.julia/dev/AdvancedVI/test/scratch2.jl:94
unknown function (ip: 0x7f0ed2d17c92)
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
do_call at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/interpreter.c:126
eval_value at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/interpreter.c:223
eval_stmt_value at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/interpreter.c:174 [inlined]
eval_body at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/interpreter.c:617
jl_interpret_toplevel_thunk at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/interpreter.c:775
jl_toplevel_eval_flex at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/toplevel.c:934
jl_toplevel_eval_flex at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/toplevel.c:877
jl_toplevel_eval_flex at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/toplevel.c:877
ijl_toplevel_eval_in at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/toplevel.c:985
eval at ./boot.jl:385 [inlined]
eval_user_input at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:150
repl_backend_loop at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:246
#start_repl_backend#46 at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:231
start_repl_backend at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:228
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
#run_repl#59 at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:389
run_repl at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:375
jfptr_run_repl_91949.1 at /home/krkim/.julia/juliaup/julia-1.10.6+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
#1013 at ./client.jl:437
jfptr_YY.1013_82918.1 at /home/krkim/.julia/juliaup/julia-1.10.6+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
jl_f__call_latest at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/builtins.c:812
#invokelatest#2 at ./essentials.jl:892 [inlined]
invokelatest at ./essentials.jl:889 [inlined]
run_main_repl at ./client.jl:421
exec_options at ./client.jl:338
_start at ./client.jl:557
jfptr__start_82944.1 at /home/krkim/.julia/juliaup/julia-1.10.6+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
true_main at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/jlapi.c:582
jl_repl_entrypoint at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/jlapi.c:731
main at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/cli/loader_exe.c:58
unknown function (ip: 0x7f0edf07ce07)
__libc_start_main at /usr/lib/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 138892172 (Pool: 138519917; Big: 372255); GC: 197
zsh: segmentation fault (core dumped)  julia
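For context on the repro: the Distributions.rand specialization for MvLocationScale draws reparameterized samples z = location + diag(scale) .* ε with ε ~ N(0, I). The following is a minimal plain-Julia sketch of that computation (no Enzyme, Bijectors, or Optimisers involved), so it says nothing about the segfault itself. The helper name reparam_samples is hypothetical and not part of the repro or of AdvancedVI; the assertion only checks that the broadcasted form matches the dense affine map location + scale * ε.

```julia
using LinearAlgebra, Random

# Hypothetical helper (not from the repro): draws `num_samples` columns
# z = location + diag(scale) .* ε with ε ~ N(0, I), mirroring the
# specialized `rand` method for a Diagonal scale.
function reparam_samples(rng::AbstractRNG, location::AbstractVector,
                         scale::Diagonal, num_samples::Int)
    ε = randn(rng, length(location), num_samples)
    # Broadcast the diagonal entries instead of forming the product scale * ε.
    zs = diag(scale) .* ε .+ location
    return zs, ε
end

rng = Xoshiro(1)
location = [1.0, 2.0, 3.0]
scale = Diagonal([0.5, 1.0, 2.0])
zs, ε = reparam_samples(rng, location, scale, 4)

# The broadcasted draws agree with the dense affine map location + scale * ε.
@assert zs ≈ scale * ε .+ location
```

Each column of zs is one sample, which is why the repro iterates over eachcol(zs) when averaging the log density.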
@wsmoses
Member

wsmoses commented Nov 8, 2024

Hm, sadly this does not error on my computer.

@Red-Portal
Author

@wsmoses Hmm, let me check a few things on my system. As a last resort, would it help if I reproduced this in a Docker container?

@Red-Portal
Author

Starting from a fresh .julia directory gives the same result. I created a Docker container in which I can reproduce the segfault. You can access it with:

docker pull kyrkim/enzymeissue2074
docker run -it kyrkim/enzymeissue2074  bash

Then paste the snippet above into the preinstalled Julia REPL and call main().

@Red-Portal
Author

@wsmoses Would this be enough to reproduce on your end? This bug is breaking all the Enzyme tests in AdvancedVI, so it would be really great to have it fixed.
