Skip to content

Commit

Permalink
AbstractInterpreter: add a hook to customize bestguess calculation
Browse files Browse the repository at this point in the history
Currently, the code that updates `bestguess` using `ReturnNode`
information includes hardcodes that relate to `Conditional` and
`LimitedAccuracy`. These behaviors are actually lattice-dependent and
therefore should be overloadable by `AbstractInterpreter`.

Additionally, particularly in Diffractor, a clever strategy is required
to update return types in a way that it takes into account information
from both the original method and its rule method
(xref: JuliaDiff/Diffractor.jl#202). This also requires such an overload
to exist.
In response to these needs, this commit introduces an implementation of
a hook named `update_bestguess!`.
  • Loading branch information
aviatesk committed Jul 31, 2023
1 parent 441fcb1 commit 8d57b1b
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 48 deletions.
67 changes: 37 additions & 30 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2892,17 +2892,49 @@ function init_vartable!(vartable::VarTable, frame::InferenceState)
return vartable
end

function update_bestguess!(interp::AbstractInterpreter, frame::InferenceState,
currstate::VarTable, @nospecialize(rt))
bestguess = frame.bestguess
nargs = narguments(frame, #=include_va=#false)
slottypes = frame.slottypes
rt = widenreturn(rt, BestguessInfo(interp, bestguess, nargs, slottypes, currstate))
# narrow representation of bestguess slightly to prepare for tmerge with rt
if rt isa InterConditional && bestguess isa Const
slot_id = rt.slot
old_id_type = slottypes[slot_id]
if bestguess.val === true && rt.elsetype !== Bottom
bestguess = InterConditional(slot_id, old_id_type, Bottom)
elseif bestguess.val === false && rt.thentype !== Bottom
bestguess = InterConditional(slot_id, Bottom, old_id_type)
end
end
# copy limitations to return value
if !isempty(frame.pclimitations)
union!(frame.limitations, frame.pclimitations)
empty!(frame.pclimitations)
end
if !isempty(frame.limitations)
rt = LimitedAccuracy(rt, copy(frame.limitations))
end
𝕃ₚ = ipo_lattice(interp)
if !(𝕃ₚ, rt, bestguess)
# TODO: if bestguess isa InterConditional && !interesting(bestguess); bestguess = widenconditional(bestguess); end
frame.bestguess = tmerge(𝕃ₚ, bestguess, rt) # new (wider) return type for frame
return true
else
return false
end
end

# make as much progress on `frame` as possible (without handling cycles)
function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
@assert !is_inferred(frame)
frame.dont_work_on_me = true # mark that this function is currently on the stack
W = frame.ip
nargs = narguments(frame, #=include_va=#false)
slottypes = frame.slottypes
ssavaluetypes = frame.ssavaluetypes
bbs = frame.cfg.blocks
nbbs = length(bbs)
𝕃ₚ, 𝕃ᵢ = ipo_lattice(interp), typeinf_lattice(interp)
𝕃ᵢ = typeinf_lattice(interp)

currbb = frame.currbb
if currbb != 1
Expand Down Expand Up @@ -3003,35 +3035,10 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end
end
elseif isa(stmt, ReturnNode)
bestguess = frame.bestguess
rt = abstract_eval_value(interp, stmt.val, currstate, frame)
rt = widenreturn(rt, BestguessInfo(interp, bestguess, nargs, slottypes, currstate))
# narrow representation of bestguess slightly to prepare for tmerge with rt
if rt isa InterConditional && bestguess isa Const
let slot_id = rt.slot
old_id_type = slottypes[slot_id]
if bestguess.val === true && rt.elsetype !== Bottom
bestguess = InterConditional(slot_id, old_id_type, Bottom)
elseif bestguess.val === false && rt.thentype !== Bottom
bestguess = InterConditional(slot_id, Bottom, old_id_type)
end
end
end
# copy limitations to return value
if !isempty(frame.pclimitations)
union!(frame.limitations, frame.pclimitations)
empty!(frame.pclimitations)
end
if !isempty(frame.limitations)
rt = LimitedAccuracy(rt, copy(frame.limitations))
end
if !(𝕃ₚ, rt, bestguess)
# new (wider) return type for frame
bestguess = tmerge(𝕃ₚ, bestguess, rt)
# TODO: if bestguess isa InterConditional && !interesting(bestguess); bestguess = widenconditional(bestguess); end
frame.bestguess = bestguess
if update_bestguess!(interp, frame, currstate, rt)
for (caller, caller_pc) in frame.cycle_backedges
if !(caller.ssavaluetypes[caller_pc] === Any)
if caller.ssavaluetypes[caller_pc] !== Any
# no reason to revisit if that call-site doesn't affect the final result
push!(caller.ip, block_for_inst(caller.cfg, caller_pc))
end
Expand Down
39 changes: 21 additions & 18 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -870,26 +870,10 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
# since the inliner will request to use it later
cache = :local
else
rt = cached_return_type(code)
effects = ipo_effects(code)
update_valid_age!(caller, WorldRange(min_world(code), max_world(code)))
rettype = code.rettype
if isdefined(code, :rettype_const)
rettype_const = code.rettype_const
# the second subtyping/egal conditions are necessary to distinguish usual cases
# from rare cases when `Const` wrapped those extended lattice type objects
if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
rettype = PartialStruct(rettype, rettype_const)
elseif isa(rettype_const, PartialOpaque) && rettype <: Core.OpaqueClosure
rettype = rettype_const
elseif isa(rettype_const, InterConditional) && rettype !== InterConditional
rettype = rettype_const
elseif isa(rettype_const, InterMustAlias) && rettype !== InterMustAlias
rettype = rettype_const
else
rettype = Const(rettype_const)
end
end
return EdgeCallResult(rettype, mi, effects)
return EdgeCallResult(rt, mi, effects)
end
else
cache = :global # cache edge targets by default
Expand Down Expand Up @@ -933,6 +917,25 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
return EdgeCallResult(frame.bestguess, nothing, adjust_effects(frame))
end

function cached_return_type(code::CodeInstance)
rettype = code.rettype
isdefined(code, :rettype_const) || return rettype
rettype_const = code.rettype_const
# the second subtyping/egal conditions are necessary to distinguish usual cases
# from rare cases when `Const` wrapped those extended lattice type objects
if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
return PartialStruct(rettype, rettype_const)
elseif isa(rettype_const, PartialOpaque) && rettype <: Core.OpaqueClosure
return rettype_const
elseif isa(rettype_const, InterConditional) && rettype !== InterConditional
return rettype_const
elseif isa(rettype_const, InterMustAlias) && rettype !== InterMustAlias
return rettype_const
else
return Const(rettype_const)
end
end

#### entry points for inferring a MethodInstance given a type signature ####

# compute an inferred AST and return type
Expand Down

0 comments on commit 8d57b1b

Please sign in to comment.