Skip to content
This repository has been archived by the owner on Sep 27, 2021. It is now read-only.

Commit

Permalink
update to [email protected]
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Aug 12, 2021
1 parent 74059c3 commit 81a259a
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 389 deletions.
3 changes: 1 addition & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ version = "0.1.4"

[deps]
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
JET = "0.4.2"
JET = "0.5.0"
MacroTools = "0.5.6"
julia = "1.6"

Expand Down
2 changes: 0 additions & 2 deletions docs/src/toolset/dispatch.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ These macros/functions are the entries of dispatch analysis:
```@docs
@report_dispatch
report_dispatch
@analyze_dispatch
analyze_dispatch
@test_nodispatch
test_nodispatch
```
Expand Down
16 changes: 0 additions & 16 deletions src/JETTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ using JET.JETInterfaces
import JET:
get_cache_key

import Test:
record

# usings
# ======

Expand All @@ -33,7 +30,6 @@ using JET:
State,
@invoke,
@isexpr,
gen_call_with_extracted_types_and_kwargs,
get_reports,
print_reports

Expand All @@ -54,16 +50,6 @@ using Core:
CodeInfo,
MethodInstance

using Test:
Test,
Pass, Fail, Broken, Error,
Threw,
get_testset,
TESTSET_PRINT_ENABLE,
AbstractTestSet,
DefaultTestSet,
FallbackTestSetException

# filters
# =======

Expand All @@ -79,8 +65,6 @@ end
include("dispatch.jl")

export
analyze_dispatch,
@analyze_dispatch,
report_dispatch,
@report_dispatch,
test_nodispatch,
Expand Down
220 changes: 21 additions & 199 deletions src/dispatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ JETInterfaces.get_msg(::Type{OptimizationFailureReport}, args...) =
return "failed to optimize" #: signature of this MethodInstance

function (::DispatchAnalysisPass)(::Type{OptimizationFailureReport}, analyzer::DispatchAnalyzer, frame::InferenceState)
report!(OptimizationFailureReport, analyzer, frame.linfo)
add_new_report!(OptimizationFailureReport(analyzer, frame.linfo), analyzer)
end

@reportdef struct RuntimeDispatchReport <: InferenceErrorReport end
Expand All @@ -190,7 +190,7 @@ function (::DispatchAnalysisPass)(::Type{RuntimeDispatchReport}, analyzer::Dispa
ft = widenconst(argextype(first(x.args), opt.src, sptypes, slottypes))
ft <: Builtin && continue # ignore `:call`s of language intrinsics
if analyzer.function_filter(ft)
report!(RuntimeDispatchReport, analyzer, (opt, pc))
add_new_report!(RuntimeDispatchReport(analyzer, (opt, pc)), analyzer)
end
end
end
Expand Down Expand Up @@ -234,7 +234,7 @@ function CC.finish(frame::InferenceState, analyzer::DispatchAnalyzer)
if isa(frame.result.src, OptimizationState)
push!(opts, true)
else
report_pass!(OptimizationFailureReport, analyzer, frame)
ReportPass(analyzer)(OptimizationFailureReport, analyzer, frame)
push!(opts, false)
end
end
Expand All @@ -250,10 +250,10 @@ function CC.finish!(analyzer::DispatchAnalyzer, caller::InferenceResult)
if popfirst!(analyzer.opts) # optimization happened
if isa(ret, Const) # the optimization was very successful, nothing to report
elseif isa(ret, OptimizationState) # compiler cached the optimized IR, just analyze it
report_pass!(RuntimeDispatchReport, analyzer, ret)
ReportPass(analyzer)(RuntimeDispatchReport, analyzer, ret)
elseif isa(ret, CodeInfo) # compiler didn't cache the optimized IR, but `finish!(::AbstractInterpreter, ::InferenceResult)` transformed it to `opt.src`, so we can analyze it
@assert isa(opt, OptimizationState) && opt.src === ret
report_pass!(RuntimeDispatchReport, analyzer, opt)
ReportPass(analyzer)(RuntimeDispatchReport, analyzer, opt)
else
Core.eval(@__MODULE__, :(ret = $ret))
throw("unexpected state happened, inspect $(@__MODULE__).ret") # this pass should never happen
Expand All @@ -273,52 +273,21 @@ end # @static if isdefined(CC, :finish!)
# =======

"""
analyze_dispatch(f, types = Tuple{}; jetconfigs...) -> (analyzer::DispatchAnalyzer, frame::Union{InferenceFrame,Nothing})
analyze_dispatch(tt::Type{<:Tuple}; jetconfigs...) -> (analyzer::DispatchAnalyzer, frame::Union{InferenceFrame,Nothing})
report_dispatch(f, types = Tuple{}; jetconfigs...) -> JETCallResult
report_dispatch(tt::Type{<:Tuple}; jetconfigs...) -> JETCallResult
Analyzes the generic function call with the given type signature, and returns:
- `analyzer::DispatchAnalyzer`: contains analyzed optimization failures and runtime dispatch points
- `frame::Union{InferenceFrame,Nothing}`: the final state of the abstract interpretation,
or `nothing` if `f` is a generator and the code generation has been failed
"""
function analyze_dispatch(@nospecialize(args...);
analyzer = DispatchAnalyzer,
jetconfigs...)
@assert analyzer === DispatchAnalyzer "analyzer is fixed to $DispatchAnalyzer"
return analyze_call(args...; analyzer, jetconfigs...)
end

"""
report_dispatch(f, types = Tuple{}; jetconfigs...) -> result_type::Any
report_dispatch(tt::Type{<:Tuple}; jetconfigs...) -> result_type::Any
Analyzes the generic function call with the given type signature, and then prints detected
optimization failures and runtime dispatch points to `stdout`, and finally returns the result
type of the call.
Analyzes the generic function call with the given type signature with `DispatchAnalyzer`,
which collects optimization failures and runtime dispatches involved within the call stack.
"""
function report_dispatch(@nospecialize(args...);
analyzer = DispatchAnalyzer,
jetconfigs...)
@assert analyzer === DispatchAnalyzer "analyzer is fixed to $DispatchAnalyzer"
if !(analyzer === DispatchAnalyzer)
throw(ArgumentError("`analyzer` is fixed to $DispatchAnalyzer"))
end
return report_call(args...; analyzer, jetconfigs...)
end

"""
@analyze_dispatch [jetconfigs...] f(args...)
Evaluates the arguments to the function call, determines its types, and then calls
[`analyze_dispatch`](@ref) on the resulting expression.
As with `@code_typed` and its family, any of [JET configurations](https://aviatesk.github.io/JET.jl/dev/config/)
or [dispatch analysis specific configurations](@ref dispatch-analysis-configurations) can be given as the optional arguments like this:
```julia
# reports `rand(::Type{Bool})` with `unoptimize_throw_blocks` configuration turned on
julia> @analyze_dispatch unoptimize_throw_blocks=true rand(Bool)
```
"""
macro analyze_dispatch(ex0...)
return gen_call_with_extracted_types_and_kwargs(__module__, :analyze_dispatch, ex0)
end

"""
@report_dispatch [jetconfigs...] f(args...)
Expand All @@ -328,11 +297,11 @@ As with `@code_typed` and its family, any of [JET configurations](https://aviate
or [dispatch analysis specific configurations](@ref dispatch-analysis-configurations) can be given as the optional arguments like this:
```julia
# reports `rand(::Type{Bool})` with `unoptimize_throw_blocks` configuration turned on
julia> @report_call unoptimize_throw_blocks=true rand(Bool)
julia> @report_dispatch unoptimize_throw_blocks=true rand(Bool)
```
"""
macro report_dispatch(ex0...)
return gen_call_with_extracted_types_and_kwargs(__module__, :report_dispatch, ex0)
return var"@report_call"(__source__, __module__, :(analyzer=$DispatchAnalyzer), ex0...)
end

# Test integration
Expand All @@ -353,7 +322,7 @@ Test Passed
Expression: #= none:1 =# JETTest.@test_nodispatch sincos(10)
```
As with [`@report_dispatch`](@ref) or [`@analyze_dispatch`](@ref), any of [JET configurations](https://aviatesk.github.io/JET.jl/dev/config/)
As with [`@report_dispatch`](@ref), any of [JET configurations](https://aviatesk.github.io/JET.jl/dev/config/)
or [dispatch analysis specific configurations](@ref dispatch-analysis-configurations) can be given as the optional arguments like this:
```julia
julia> function f(n)
Expand Down Expand Up @@ -399,84 +368,7 @@ ERROR: Some tests did not pass: 1 passed, 1 failed, 0 errored, 1 broken.
```
"""
macro test_nodispatch(ex0...)
ex0 = collect(ex0)

local broken = nothing
local skip = nothing
idx = Int[]
for (i,x) in enumerate(ex0)
if iskwarg(x)
key, val = x.args
if key === :broken
if !isnothing(broken)
error("invalid test macro call: cannot set `broken` keyword multiple times")
end
broken = esc(val)
push!(idx, i)
elseif key === :skip
if !isnothing(skip)
error("invalid test macro call: cannot set `skip` keyword multiple times")
end
skip = esc(val)
push!(idx, i)
end
end
end
if !isnothing(broken) && !isnothing(skip)
error("invalid test macro call: cannot set both `skip` and `broken` keywords")
end
deleteat!(ex0, idx)

testres, orig_expr = test_dispatch_exs(ex0, __module__, __source__)

return quote
if $(!isnothing(skip) && skip)
$record($get_testset(), $Broken(:skipped, $orig_expr))
else
testres = $testres
if $(!isnothing(broken) && broken)
if isa(testres, $DispatchTestFailure)
testres = $Broken(:test_nodispatch, $orig_expr)
elseif isa(testres, $Pass)
testres = $Error(:test_unbroken, $orig_expr, nothing, nothing, $(QuoteNode(__source__)))
end
else
isa(testres, $Pass) || ccall(:jl_breakpoint, $Cvoid, ($Any,), testres)
end
$record($get_testset(), testres)
end
end
end

iskwarg(@nospecialize(x)) = @isexpr(x, :(=))

get_exceptions() = @static if isdefined(Base, :current_exceptions)
Base.current_exceptions()
else
Base.catch_stack()
end
@static if !hasfield(Pass, :source)
Pass(test_type::Symbol, orig_expr, data, thrown, source) = Pass(test_type, orig_expr, data, thrown)
end

function test_dispatch_exs(ex0, m, source)
analyzer_call = gen_call_with_extracted_types_and_kwargs(m, :analyze_dispatch, ex0)
orig_expr = QuoteNode(
Expr(:macrocall, GlobalRef(JETTest, Symbol("@test_nodispatch")), source, ex0...))
source = QuoteNode(source)
testres = :(try
analyzer, frame = $analyzer_call
reports = $get_reports(analyzer)
if $length(reports) == 0
$Pass(:test_nodispatch, $orig_expr, nothing, nothing, $source)
else
$DispatchTestFailure($orig_expr, $source, reports)
end
catch err
isa(err, $InterruptException) && rethrow()
$Error(:test_error, $orig_expr, err, $get_exceptions(), $source)
end) |> Base.remove_linenums!
return testres, orig_expr
return var"@test_call"(__source__, __module__, :(analyzer=$DispatchAnalyzer), ex0...)
end

"""
Expand All @@ -488,80 +380,10 @@ Except that it takes a type signature rather than a call expression, this functi
in the same way as [`@test_nodispatch`](@ref).
"""
function test_nodispatch(@nospecialize(args...);
broken::Bool = false, skip::Bool = false,
jetconfigs...)
source = LineNumberNode(@__LINE__, @__FILE__)
kwargs = map(((k,v),)->Expr(:kw, k, v), collect(jetconfigs))
orig_expr = :(JETTest.test_nodispatch($(args...); $(kwargs...)))

if skip
record(get_testset(), Broken(:skipped, orig_expr))
else
testres = try
analyzer, frame = analyze_dispatch(args...; jetconfigs...)
reports = get_reports(analyzer)
if length(reports) == 0
Pass(:test_nodispatch, orig_expr, nothing, nothing, source)
else
DispatchTestFailure(orig_expr, source, reports)
end
catch err
isa(err, InterruptException) && rethrow()
Error(:test_error, orig_expr, err, get_exceptions(), source)
end

if broken
if isa(testres, DispatchTestFailure)
testres = Broken(:test_nodispatch, orig_expr)
elseif isa(testres, Pass)
testres = Error(:test_unbroken, orig_expr, nothing, nothing, source)
end
else
isa(testres, Pass) || ccall(:jl_breakpoint, Cvoid, (Any,), testres)
end
record(get_testset(), testres)
end
end

# NOTE we will just show abstract call strack, and won't show backtrace of actual test executions

struct DispatchTestFailure <: Test.Result
orig_expr::Expr
source::LineNumberNode
reports::Vector{InferenceErrorReport}
end

const TEST_INDENTS = " "

function Base.show(io::IO, t::DispatchTestFailure)
printstyled(io, "Dispatch Test Failed"; bold=true, color=Base.error_color())
print(io, " at ")
printstyled(io, something(t.source.file, :none), ":", t.source.line, "\n"; bold=true, color=:default)
println(io, TEST_INDENTS, "Expression: ", t.orig_expr)
# print abstract call stack, with appropriate indents
_, ctx = Base.unwrapcontext(io)
buf = IOBuffer()
ioctx = IOContext(buf, ctx)
print_reports(ioctx, t.reports) # TODO kwargs support
lines = replace(String(take!(buf)), '\n'=>string('\n',TEST_INDENTS))
print(io, TEST_INDENTS, lines)
end

Base.show(io::IO, ::MIME"application/prs.juno.inline", t::DispatchTestFailure) =
return t

function Test.record(::Test.FallbackTestSet, t::DispatchTestFailure)
println(t)
throw(FallbackTestSetException("There was an error during testing"))
end

function Test.record(ts::Test.DefaultTestSet, t::DispatchTestFailure)
if Test.TESTSET_PRINT_ENABLE[]
printstyled(ts.description, ": ", color=:white)
print(t)
println()
analyzer = DispatchAnalyzer,
kwargs...)
if !(analyzer === DispatchAnalyzer)
throw(ArgumentError("`analyzer` is fixed to $DispatchAnalyzer"))
end
# HACK convert to `Fail` so that test summarization works correctly
push!(ts.results, Fail(:test_nodispatch, t.orig_expr, nothing, nothing, t.source))
return t
return test_call(args...; analyzer, kwargs...)
end
2 changes: 1 addition & 1 deletion src/legacy/dispatch
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function CC.optimize(analyzer::AbstractAnalyzer, opt::OptimizationState, params:

concrete_frame = analyzer.concrete_frame
if (isnothing(concrete_frame) || concrete_frame::Bool) && analyzer.frame_filter(opt)
report_pass!(RuntimeDispatchReport, analyzer, opt)
ReportPass(analyzer)(RuntimeDispatchReport, analyzer, opt)
end

return ret
Expand Down
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using JETTest, Test
return ft !== typeof(Core.Compiler.widenconst) # `widenconst` is very untyped, ignore
end

@test_nodispatch frame_filter=frame_filter function_filter=function_filter skip_nonconcrete_calls=false analyze_dispatch(sin, (Int,))
@test_nodispatch frame_filter=frame_filter function_filter=function_filter skip_nonconcrete_calls=false JETTest.JET.analyze_gf_by_type!(
JETTest.DispatchAnalyzer(), Tuple{typeof(sin),Int})
end
end
Loading

0 comments on commit 81a259a

Please sign in to comment.