Skip to content

Commit

Permalink
drop require lock when not needed during loading to allow parallel precompile loading (#56291)
Browse files Browse the repository at this point in the history

Fixes `_require_search_from_serialized` to first acquire all
`start_loading` locks (using a deadlock-free batch-locking algorithm)
before doing the stale checks and the rest of the work. That way, all
the global computations happen behind the `require_lock`, the rest can
happen behind module-specific locks, and (as before) extensions can
eventually be loaded in parallel after `require` returns.
  • Loading branch information
vtjnash authored Oct 25, 2024
1 parent 49e3b87 commit db3d816
Showing 1 changed file with 157 additions and 113 deletions.
270 changes: 157 additions & 113 deletions base/loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1261,47 +1261,52 @@ function _include_from_serialized(pkg::PkgId, path::String, ocachepath::Union{No
assert_havelock(require_lock)
timing_imports = TIMING_IMPORTS[] > 0
try
if timing_imports
t_before = time_ns()
cumulative_compile_timing(true)
t_comp_before = cumulative_compile_time_ns()
end
if timing_imports
t_before = time_ns()
cumulative_compile_timing(true)
t_comp_before = cumulative_compile_time_ns()
end

for i in eachindex(depmods)
dep = depmods[i]
dep isa Module && continue
_, depkey, depbuild_id = dep::Tuple{String, PkgId, UInt128}
dep = something(maybe_loaded_precompile(depkey, depbuild_id))
@assert PkgId(dep) == depkey && module_build_id(dep) === depbuild_id
depmods[i] = dep
end
for i in eachindex(depmods)
dep = depmods[i]
dep isa Module && continue
_, depkey, depbuild_id = dep::Tuple{String, PkgId, UInt128}
dep = something(maybe_loaded_precompile(depkey, depbuild_id))
@assert PkgId(dep) == depkey && module_build_id(dep) === depbuild_id
depmods[i] = dep
end

if ocachepath !== nothing
@debug "Loading object cache file $ocachepath for $(repr("text/plain", pkg))"
sv = ccall(:jl_restore_package_image_from_file, Any, (Cstring, Any, Cint, Cstring, Cint), ocachepath, depmods, false, pkg.name, ignore_native)
else
@debug "Loading cache file $path for $(repr("text/plain", pkg))"
sv = ccall(:jl_restore_incremental, Any, (Cstring, Any, Cint, Cstring), path, depmods, false, pkg.name)
end
if isa(sv, Exception)
return sv
end
unlock(require_lock) # temporarily _unlock_ during these operations
sv = try
if ocachepath !== nothing
@debug "Loading object cache file $ocachepath for $(repr("text/plain", pkg))"
ccall(:jl_restore_package_image_from_file, Any, (Cstring, Any, Cint, Cstring, Cint), ocachepath, depmods, false, pkg.name, ignore_native)
else
@debug "Loading cache file $path for $(repr("text/plain", pkg))"
ccall(:jl_restore_incremental, Any, (Cstring, Any, Cint, Cstring), path, depmods, false, pkg.name)
end
finally
lock(require_lock)
end
if isa(sv, Exception)
return sv
end

restored = register_restored_modules(sv, pkg, path)
restored = register_restored_modules(sv, pkg, path)

for M in restored
M = M::Module
if parentmodule(M) === M && PkgId(M) == pkg
register && register_root_module(M)
if timing_imports
elapsed_time = time_ns() - t_before
comp_time, recomp_time = cumulative_compile_time_ns() .- t_comp_before
print_time_imports_report(M, elapsed_time, comp_time, recomp_time)
for M in restored
M = M::Module
if parentmodule(M) === M && PkgId(M) == pkg
register && register_root_module(M)
if timing_imports
elapsed_time = time_ns() - t_before
comp_time, recomp_time = cumulative_compile_time_ns() .- t_comp_before
print_time_imports_report(M, elapsed_time, comp_time, recomp_time)
end
return M
end
return M
end
end
return ErrorException("Required dependency $(repr("text/plain", pkg)) failed to load from a cache file.")
return ErrorException("Required dependency $(repr("text/plain", pkg)) failed to load from a cache file.")

finally
timing_imports && cumulative_compile_timing(false)
Expand Down Expand Up @@ -2020,13 +2025,46 @@ end
if staledeps === true
continue
end
try
staledeps, ocachefile, newbuild_id = staledeps::Tuple{Vector{Any}, Union{Nothing, String}, UInt128}
# finish checking staledeps module graph
for i in eachindex(staledeps)
staledeps, ocachefile, newbuild_id = staledeps::Tuple{Vector{Any}, Union{Nothing, String}, UInt128}
startedloading = length(staledeps) + 1
try # any exit from here (goto, break, continue, return) will end_loading
# finish checking staledeps module graph, while acquiring all start_loading locks
# so that concurrent require calls won't make any different decisions that might conflict with the decisions here
# note that start_loading will drop the loading lock if necessary
let i = 0
# start_loading here has a deadlock problem if we try to load `A,B,C` and `B,A,D` at the same time:
# it will claim A,B have a cycle, but really they just have an ambiguous order and need to be batch-acquired rather than singly
# solve that by making sure we can start_loading everything before allocating each of those and doing all the stale checks
while i < length(staledeps)
i += 1
dep = staledeps[i]
dep isa Module && continue
_, modkey, modbuild_id = dep::Tuple{String, PkgId, UInt128}
dep = canstart_loading(modkey, modbuild_id, stalecheck)
if dep isa Module
if PkgId(dep) == modkey && module_build_id(dep) === modbuild_id
staledeps[i] = dep
continue
else
@debug "Rejecting cache file $path_to_try because module $modkey got loaded at a different version than expected."
@goto check_next_path
end
continue
elseif dep === nothing
continue
end
wait(dep) # releases require_lock, so requires restarting this loop
i = 0
end
end
for i in reverse(eachindex(staledeps))
dep = staledeps[i]
dep isa Module && continue
modpath, modkey, modbuild_id = dep::Tuple{String, PkgId, UInt128}
# inline a call to start_loading here
@assert canstart_loading(modkey, modbuild_id, stalecheck) === nothing
package_locks[modkey] = current_task() => Threads.Condition(require_lock)
startedloading = i
modpaths = find_all_in_cache_path(modkey, DEPOT_PATH)
for modpath_to_try in modpaths
modstaledeps = stale_cachefile(modkey, modbuild_id, modpath, modpath_to_try; stalecheck)
Expand Down Expand Up @@ -2054,37 +2092,22 @@ end
end
end
# finish loading module graph into staledeps
# TODO: call all start_loading calls (in reverse order) before calling any _include_from_serialized, since start_loading will drop the loading lock
# n.b. this runs __init__ methods too early, so it is very unwise to have those, as they may see inconsistent loading state, causing them to fail unpredictably here
for i in eachindex(staledeps)
dep = staledeps[i]
dep isa Module && continue
modpath, modkey, modbuild_id, modcachepath, modstaledeps, modocachepath = dep::Tuple{String, PkgId, UInt128, String, Vector{Any}, Union{Nothing, String}}
dep = start_loading(modkey, modbuild_id, stalecheck)
while true
if dep isa Module
if PkgId(dep) == modkey && module_build_id(dep) === modbuild_id
break
else
@debug "Rejecting cache file $path_to_try because module $modkey got loaded at a different version than expected."
@goto check_next_path
end
end
if dep === nothing
try
set_pkgorigin_version_path(modkey, modpath)
dep = _include_from_serialized(modkey, modcachepath, modocachepath, modstaledeps; register = stalecheck)
finally
end_loading(modkey, dep)
end
if !isa(dep, Module)
@debug "Rejecting cache file $path_to_try because required dependency $modkey failed to load from cache file for $modcachepath." exception=dep
@goto check_next_path
else
push!(newdeps, modkey)
end
end
set_pkgorigin_version_path(modkey, modpath)
dep = _include_from_serialized(modkey, modcachepath, modocachepath, modstaledeps; register = stalecheck)
if !isa(dep, Module)
@debug "Rejecting cache file $path_to_try because required dependency $modkey failed to load from cache file for $modcachepath." exception=dep
@goto check_next_path
else
startedloading = i + 1
end_loading(modkey, dep)
staledeps[i] = dep
push!(newdeps, modkey)
end
staledeps[i] = dep
end
restored = maybe_loaded_precompile(pkg, newbuild_id)
if !isa(restored, Module)
Expand All @@ -2094,11 +2117,21 @@ end
@debug "Deserialization checks failed while attempting to load cache from $path_to_try" exception=restored
@label check_next_path
finally
# cancel all start_loading locks that were taken but not fulfilled before failing
for i in startedloading:length(staledeps)
dep = staledeps[i]
dep isa Module && continue
if dep isa Tuple{String, PkgId, UInt128}
_, modkey, _ = dep
else
_, modkey, _ = dep::Tuple{String, PkgId, UInt128, String, Vector{Any}, Union{Nothing, String}}
end
end_loading(modkey, nothing)
end
for modkey in newdeps
insert_extension_triggers(modkey)
stalecheck && run_package_callbacks(modkey)
end
empty!(newdeps)
end
end
end
Expand All @@ -2111,66 +2144,76 @@ const package_locks = Dict{PkgId,Pair{Task,Threads.Condition}}()
debug_loading_deadlocks::Bool = true # Enable a slightly more expensive, but more complete algorithm that can handle simultaneous tasks.
# This only triggers if you have multiple tasks trying to load the same package at the same time,
# so it is unlikely to make a performance difference normally.
function start_loading(modkey::PkgId, build_id::UInt128, stalecheck::Bool)
# handle recursive and concurrent calls to require

# Check whether the current task may begin loading `modkey`, without claiming it.
#
# Must be called with `require_lock` held exactly once (non-recursively).
#
# Returns one of:
#   - a `Module`: the package is already available, either as a root module
#     (only consulted when `stalecheck` is true) or as a loaded precompile with
#     the requested `build_id` (only consulted when `build_id` is nonzero);
#   - `nothing`: no load of `modkey` is in progress; this function does NOT add
#     an entry to `package_locks`, the caller must do so to claim the load;
#   - the `Threads.Condition` from the `package_locks` entry of the load that is
#     already in progress, which the caller may `wait` on.
#
# Throws `ConcurrencyViolationError` if `require_lock` is held recursively, or if
# waiting on the in-progress load would deadlock because the chain of loading and
# waiting tasks leads back to the current task.
function canstart_loading(modkey::PkgId, build_id::UInt128, stalecheck::Bool)
    assert_havelock(require_lock)
    require_lock.reentrancy_cnt == 1 || throw(ConcurrencyViolationError("recursive call to start_loading"))
    loaded = stalecheck ? maybe_root_module(modkey) : nothing
    loaded isa Module && return loaded
    if build_id != UInt128(0)
        loaded = maybe_loaded_precompile(modkey, build_id)
        loaded isa Module && return loaded
    end
    loading = get(package_locks, modkey, nothing)
    loading === nothing && return nothing
    # load already in progress for this module on the task
    task, cond = loading
    deps = String[modkey.name]
    pkgid = modkey
    assert_havelock(cond.lock)
    if debug_loading_deadlocks && current_task() !== task
        waiters = Dict{Task,Pair{Task,PkgId}}() # invert to track waiting tasks => loading tasks
        for each in package_locks
            cond2 = each[2][2]
            assert_havelock(cond2.lock)
            for waiting in cond2.waitq
                push!(waiters, waiting => (each[2][1] => each[1]))
            end
        end
        # follow the chain: the loading task may itself be waiting on another load, and so on;
        # stop when the chain ends or when it comes back around to the current task
        while true
            running = get(waiters, task, nothing)
            running === nothing && break
            task, pkgid = running
            push!(deps, pkgid.name)
            task === current_task() && break
        end
    end
    if current_task() === task
        # the chain ends at the current task, so waiting here could never be satisfied
        others = String[modkey.name] # repeat this to emphasize the cycle here
        for each in package_locks # list the rest of the packages being loaded too
            if each[2][1] === task
                other = each[1].name
                other == modkey.name || other == pkgid.name || push!(others, other)
            end
        end
        msg = sprint(deps, others) do io, deps, others
            print(io, "deadlock detected in loading ")
            join(io, deps, " -> ")
            print(io, " -> ")
            join(io, others, " && ")
        end
        throw(ConcurrencyViolationError(msg))
    end
    return cond
end

# Claim the right to load `modkey` on the current task, waiting out any load of
# the same package that is already in progress.
#
# Returns `nothing` after registering the current task in `package_locks` as the
# loader of `modkey`, or returns the `Module` if the package is already loaded
# or becomes loaded while waiting.
# Errors thrown by `canstart_loading` propagate to the caller.
function start_loading(modkey::PkgId, build_id::UInt128, stalecheck::Bool)
    # handle recursive and concurrent calls to require
    while true
        loaded = canstart_loading(modkey, build_id, stalecheck)
        if loaded === nothing
            # nobody is loading this package: register ourselves as the loader
            package_locks[modkey] = current_task() => Threads.Condition(require_lock)
            return nothing
        elseif loaded isa Module
            return loaded
        end
        # otherwise `loaded` is the condition of the in-progress load; wait on it.
        # If the value delivered by the wake-up is not a Module, loop and re-check.
        loaded = wait(loaded)
        loaded isa Module && return loaded
    end
end

function end_loading(modkey::PkgId, @nospecialize loaded)
assert_havelock(require_lock)
loading = pop!(package_locks, modkey)
notify(loading[2], loaded, all=true)
nothing
Expand Down Expand Up @@ -2650,6 +2693,7 @@ function _require(pkg::PkgId, env=nothing)
end

# load a serialized file directly, including dependencies (without checking staleness except for immediate conflicts)
# this does not call start_loading / end_loading, so can lead to some odd behaviors
function _require_from_serialized(uuidkey::PkgId, path::String, ocachepath::Union{String, Nothing}, sourcepath::String)
@lock require_lock begin
set_pkgorigin_version_path(uuidkey, sourcepath)
Expand Down

0 comments on commit db3d816

Please sign in to comment.