Skip to content

Commit

Permalink
bugfix/reflection-parsing (#192)
Browse files Browse the repository at this point in the history
* updated function signature check
* Improved the extract_defaults() parsing strategy to be more robust and not depend on a certain ordering of CodeInfo objects from a function signature
  • Loading branch information
ndortega authored May 1, 2024
1 parent 299bcd9 commit ce05310
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 92 deletions.
198 changes: 108 additions & 90 deletions src/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ function getargvalue(arg)
return isa(arg, GlobalRef) ? getfield(arg.mod, arg.name) : arg
end


"""
Return all parameter name & types and keyword argument names from a function
"""
Expand Down Expand Up @@ -49,41 +48,6 @@ function getsignames(func_methods::Base.MethodList; start=2)
return arg_names, arg_types, kwarg_names
end


"""
This function extract default values from a list of expressions
The args parameter is a vector that follows this shape like This
[
function_name...,
UnionTypes...,
kwarg defaults...,
slots...,
positional defaults...
]
"""
function splitargs(args::Vector, func_name::Symbol)
param_defaults = Vector{Any}()
kwarg_defaults = Vector{Any}()
encountered_slot = false

for arg in args
# # check if this is a function name
# if arg isa Core.GlobalRef && startswith(String(arg.name), "$func_name#")
# continue
# elseif isa(arg, Core.SSAValue)
# continue
if isa(arg, Core.SlotNumber)
encountered_slot = true
elseif !encountered_slot
push!(kwarg_defaults, getargvalue(arg))
else
push!(param_defaults, getargvalue(arg))
end
end
return param_defaults, kwarg_defaults
end

function walkargs(predicate::Function, expr)
if isdefined(expr, :args)
for arg in expr.args
Expand Down Expand Up @@ -154,14 +118,6 @@ function reconstruct(info::Core.CodeInfo, func_name::Symbol)
end
end


# exit early if no sig is found
if isnothing(sig_index)
# if there is not function signature, then we filter out the values directly
return [arg for arg in info.code if !is_lowered(arg)]
end


# Recursively build an expression of the actual type of each argument in the function signature
evaled_sig = rebuild!(statements[sig_index])

Expand All @@ -188,71 +144,133 @@ function reconstruct(info::Core.CodeInfo, func_name::Symbol)
return default_values
end



"""
Return true if the given object is one of the types found in lowered IR code
Values were taken from here: https://docs.julialang.org/en/v1/devdocs/ast/#Lowered-form
I've purposely ignored the Core.GlobalRef type, because we can just lookup the underlying value at runtime
Returns true if the CodeInfo object has a function signature
Most funtion signatures follow this general pattern
- The second to last expression is used as the function signature
- The last argument is a Return node
Below are a couple different examples of this in pattern in action:
# Standard function signature
CodeInfo(
1 ─ %1 = (#self#)(req, a, path, qparams, 23)
└── return %1
)
# Extractor example (as a default value)
CodeInfo(
1 ─ #22 = %new(Main.RunTests.ExtractorTests.:(var"#22#37"))
│ %2 = #22
│ %3 = Main.RunTests.ExtractorTests.Header(Main.RunTests.ExtractorTests.Sample, %2)
│ %4 = (#self#)(req, %3)
└── return %4
)
# This kind of function signature happens when a keyword argument is defined without at default value
CodeInfo(
1 ─ %1 = "default"
│ c = %1
│ %3 = true
│ d = %3
│ %5 = Core.UndefKeywordError(:request)
│ %6 = Core.throw(%5)
│ request = %6
│ %8 = Core.getfield(#self#, Symbol("#8#9"))
│ %9 = c
│ %10 = d
│ %11 = request
│ %12 = (%8)(%9, %10, %11, #self#, a, b)
└── return %12
)
"""
function is_lowered(instance::Any) :: Bool
return instance isa Union{
Expr, Core.SlotNumber, Core.Argument, Core.CodeInfo,
Core.GotoNode, Core.GotoIfNot, Core.ReturnNode, Core.QuoteNode,
Core.SSAValue, Core.NewvarNode
}
end
function has_sig_expr(c::Core.CodeInfo) :: Bool

statements_length = length(c.code)

# prevent index out of bounds
if statements_length < 2
return false
end

# """
# Returns true if the CodeInfo block has an expression where the first arg is a SlotNumber with id 1
# """
# function has_sig_expr(info::Core.CodeInfo) :: Bool
# for expr in info.code
# # identify the function signature
# if isdefined(expr, :args) && expr.head == :call
# first_arg = first(expr.args)
# if first_arg isa Core.SlotNumber && first_arg.id == 1
# return true
# end
# end
# end
# return false
# end
# check for our pattern of a function signature followed by a return statement
last_expr = c.code[statements_length]
second_to_last_expr = c.code[statements_length - 1]

if last_expr isa Core.ReturnNode && second_to_last_expr isa Expr && second_to_last_expr.head == :call
# recursivley search expression to see if we have a SlotNumber(1) in the args
return walkargs(second_to_last_expr) do arg
return isa(arg, Core.SlotNumber) && arg.id == 1
end
end

function extract_defaults(info::Vector{Core.CodeInfo}, func_name::Symbol, param_names, kwarg_names; start=2)
return false
end

"""
Given a list of CodeInfo objects, extract any default values assigned to parameters & keyword arguments
"""
function extract_defaults(info::Vector{Core.CodeInfo}, func_name::Symbol, param_names::Vector{Symbol}, kwarg_names::Vector{Symbol}; start=2)

# These store the mapping between parameter names and their default values
param_defaults = Dict()
kwarg_defaults = Dict()

# Given the params, we can take an educated guess and map the slot number to the parameter name
slot_mapping = Dict(i + 1 => p for (i, p) in enumerate(vcat(param_names, kwarg_names)))

# skip parsing if no parameters or keyword arguments are found
if isempty(param_names) && isempty(kwarg_names)
return param_defaults, kwarg_defaults
end

for c in info

# skip parsing function bodys which normally start with newvarnodes
if first(c.code) isa Core.NewvarNode
# skip code info objects that don't have a function signature
if !has_sig_expr(c)
continue
end

# rebuild the function signature with the default values included
sig_args = reconstruct(c, func_name)
p_defaults, kw_defaults = splitargs(sig_args, func_name)

# store the default values for params
if !isempty(p_defaults)
# we reverse, because each subsequent expression should show more defaults,
# so we need to keep updating them as we see them.
for (name, value) in zip(reverse(param_names), reverse(p_defaults))
param_defaults[name] = value
end
end

# store the default values for kwargs
if !isempty(kw_defaults)
for (name, value) in zip(kwarg_names, kw_defaults)
kwarg_defaults[name] = value
sig_length = length(sig_args)
self_index = findfirst([isa(x, Core.SlotNumber) && x.id == 1 for x in sig_args])

for (index, arg) in enumerate(sig_args)

# for keyword arguments
if index < self_index

# derive the current slot name
slot_number = sig_length - abs(self_index - index) + 1
slot_name = slot_mapping[slot_number]

# don't store slot numbers when no default is given
value = getargvalue(arg)
if !isa(value, Core.SlotNumber)
kwarg_defaults[slot_name] = value
end

# for regular arguments
elseif index > self_index

# derive the current slot name
slot_number = abs(self_index - index) + 1
slot_name = slot_mapping[slot_number]

# don't store slot numbers when no default is given
value = getargvalue(arg)
if !isa(value, Core.SlotNumber)
param_defaults[slot_name] = value
end
end
end

end
end

return param_defaults, kwarg_defaults
Expand Down
75 changes: 74 additions & 1 deletion test/reflectiontests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ module ReflectionTests

using Test
using Base: @kwdef
using Oxygen: splitdef
using Oxygen: splitdef, Json
using Oxygen.Core.Reflection: getsignames, parsetype, kwarg_struct_builder


global message = Dict("message" => "Hello, World!")

struct Person
name::String
age::Int
Expand Down Expand Up @@ -223,4 +226,74 @@ end
end



@testset "splitdef extractor default value" begin
# Define a function for testing
f = function(a::Int, house = Json{Home}(house -> house.owner.age >= 25), msg = message; request, b = 3.0)
return a, house, msg
end

# Parse the function info
info = splitdef(f)

@testset "counts" begin
@test length(info.args) == 3
@test length(info.kwargs) == 2
@test length(info.sig) == 5
@test length(info.sig_map) == 5
end

@testset "Args" begin
@test info.args[1].name == :a
@test info.args[1].type == Int

@test info.args[2].name == :house
@test info.args[2].type == Json{Home}
@test info.args[2].default isa Json{Home}

@test info.args[3].name == :msg
@test info.args[3].type == Dict{String, String}
end

@testset "Kwargs" begin
@test info.kwargs[1].name == :request
@test info.kwargs[1].type == Any
@test info.kwargs[1].default isa Missing
@test info.kwargs[1].hasdefault == false

@test info.kwargs[2].name == :b
@test info.kwargs[2].type == Any
@test info.kwargs[2].default == 3.0
@test info.kwargs[2].hasdefault == true
end

@testset "Sig_map" begin
@test info.sig_map[:a].name == :a
@test info.sig_map[:a].type == Int
@test info.sig_map[:a].default isa Missing
@test info.sig_map[:a].hasdefault == false

@test info.sig_map[:house].name == :house
@test info.sig_map[:house].type == Json{Home}
@test info.sig_map[:house].default isa Json{Home}
@test info.sig_map[:house].hasdefault == true

@test info.sig_map[:msg].name == :msg
@test info.sig_map[:msg].type == Dict{String, String}
@test info.sig_map[:msg].default == Dict("message" => "Hello, World!")
@test info.sig_map[:msg].hasdefault == true

@test info.sig_map[:request].name == :request
@test info.sig_map[:request].type == Any
@test info.sig_map[:request].default isa Missing
@test info.sig_map[:request].hasdefault == false

@test info.sig_map[:b].name == :b
@test info.sig_map[:b].type == Any
@test info.sig_map[:b].default == 3.0
@test info.sig_map[:b].hasdefault == true
end
end


end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ include("constants.jl"); using .Constants
#### Extension Tests ####

include("extensions/templatingtests.jl")
include("extensions/protobuf/protobuftests.jl")
include("extensions/cairomakietests.jl")
include("extensions/wglmakietests.jl")
include("extensions/bonitotests.jl")
include("extensions/protobuf/protobuftests.jl")

#### Sepcial Handler Tests ####

Expand Down

0 comments on commit ce05310

Please sign in to comment.