Skip to content

Commit

Permalink
Detect Base method overrides (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Feb 1, 2024
1 parent 91d4502 commit 935c8fe
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- "emb3large" refers to the large version of the new embedding model (dim=3072), which is only 30% more expensive than Ada
- Improved AgentTools: added more information and specific methods to `aicode_feedback` and `error_feedback` to pass more targeted feedback/tips to the AIAgent
- Improved detection of which lines were the source of error during `AICode` evaluation + forcing the error details to be printed in `AICode(...).stdout` for downstream analysis.
- Improved detection of Base/Main method overrides in `AICode` evaluation. Detected overrides only trigger a warning (no error is thrown); use `detect_base_main_overrides(code)` if you need custom handling

### Fixed
- Fixed typos in the documentation
Expand Down
5 changes: 5 additions & 0 deletions src/code_eval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,11 @@ function eval!(cb::AbstractCodeBlock;
(cb.error = ErrorException("Safety Error: Failed package import. Missing packages: $(join(string.(missing_packages),", ")). Please add them or disable the safety check (`safe_eval=false`)"))
return cb
end
detected, overrides = detect_base_main_overrides(code)
if detected
## DO NOT THROW ERROR
@warn "Safety Warning: Base / Main overrides detected (functions: $(join(overrides,",")))! Please verify the safety of the code or disable the safety check (`safe_eval=false`)"
end
end
## Catch bad code extraction
if isempty(code)
Expand Down
78 changes: 72 additions & 6 deletions src/code_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@ function detect_pkg_operation(input::AbstractString)
return !isnothing(m)
end
# Utility to detect dependencies in a string (for `safe` code evaluation / understand when we don't have a necessary package)
function extract_julia_imports(input::AbstractString)
"""
extract_julia_imports(input::AbstractString; base_or_main::Bool = false)
Detects any `using` or `import` statements in a given string and returns the package names as a vector of symbols.
`base_or_main` is a boolean that determines whether to isolate only `Base` and `Main` OR whether to exclude them in the returned vector.
"""
function extract_julia_imports(input::AbstractString; base_or_main::Bool = false)
package_names = Symbol[]
for line in split(input, "\n")
if occursin(r"(^using |^import )"m, line)
Expand All @@ -29,9 +36,16 @@ function extract_julia_imports(input::AbstractString)
subparts = map(x -> contains(x, ':') ? split(x, ':')[1] : x,
split(subparts, ","))
subparts = replace(join(subparts, ' '), ',' => ' ')
packages = filter(x -> !isempty(x) && !startswith(x, "Base") &&
!startswith(x, "Main"),
split(subparts, " "))
packages = filter(x -> !isempty(x), split(subparts, " "))
if base_or_main
## keep only them
packages = filter(x -> startswith(x, "Base") ||
startswith(x, "Main"), packages)
else
## exclude them
packages = filter(x -> !startswith(x, "Base") &&
!startswith(x, "Main"), packages)
end
append!(package_names, Symbol.(packages))
end
end
Expand Down Expand Up @@ -336,6 +350,8 @@ Extract the name of a function from a given Julia code block. The function searc
If a function name is found, it is returned as a string. If no function name is found, the function returns `nothing`.
To capture all function names in the block, use `extract_function_names`.
# Arguments
- `code_block::String`: A string containing Julia code.
Expand All @@ -355,9 +371,9 @@ extract_function_name(code)
"""
function extract_function_name(code_block::AbstractString)
# Regular expression for the explicit function declaration
pattern_explicit = r"function\s+(\w+)\("
pattern_explicit = r"^\s*function\s+([\w\.\_]+)\("m
# Regular expression for the concise function declaration
pattern_concise = r"^(\w+)\(.*\)\s*="
pattern_concise = r"^\s*([\w\.\_]+)\(.*\)\s*="m

# Searching for the explicit function declaration
match_explicit = match(pattern_explicit, code_block)
Expand All @@ -375,6 +391,56 @@ function extract_function_name(code_block::AbstractString)
return nothing
end

"""
extract_function_names(code_block::AbstractString)
Extract one or more names of functions defined in a given Julia code block. The function searches for two patterns:
- The explicit function declaration pattern: `function name(...) ... end`
- The concise function declaration pattern: `name(...) = ...`
It always returns a vector of strings: a single match yields a one-element vector, and no match yields an empty vector.
For only one function name match, use `extract_function_name`.
"""
function extract_function_names(code_block::AbstractString)
    # Both patterns are anchored at the start of a line (leading whitespace allowed) and
    # capture a possibly module-qualified name such as `Base.show`.
    # `!` is accepted after the first character so that mutating functions
    # (`push!`, `Base.empty!`) are captured as well.
    # Regular expression for the explicit function declaration: `function name(...)`
    pattern_explicit = r"^\s*function\s+([\w.][\w.!]*)\("m
    # Regular expression for the concise function declaration: `name(...) = ...`
    # The lookahead `(?![=>])` rejects `==`, `===` and `=>`, which are expressions
    # applied to a call result, not definitions.
    pattern_concise = r"^\s*([\w.][\w.!]*)\(.*\)\s*=(?![=>])"m

    names = String[]

    # All explicit declarations are collected first, then all concise ones
    # (the result is therefore not strictly in source order; duplicates are kept).
    for pattern in (pattern_explicit, pattern_concise)
        for m in eachmatch(pattern, code_block)
            push!(names, m.captures[1])
        end
    end

    return names
end

"""
detect_base_main_overrides(code_block::AbstractString)
Detects if a given code block overrides any Base or Main methods.
Returns a tuple of a boolean and a vector of the overridden methods.
"""
function detect_base_main_overrides(code_block::AbstractString)
    ## Reports functions defined in `code_block` whose name is qualified with
    ## `Base.` / `Main.` or was brought into scope by a Base/Main import statement.
    ## Returns `(detected::Bool, overridden_methods)`.
    funcs = extract_function_names(code_block)
    ## last component of each Base/Main import, e.g. `import Base.show` -> "show"
    imported = String[last(split(string(pkg), "."))
                      for pkg in extract_julia_imports(code_block; base_or_main = true)]
    ## names listed after the colon, e.g. `import Base: show, print` -> "show", "print".
    ## NOTE(review): the import extractor appears to keep only the part before `:`,
    ## so these names are collected here; only `import` is scanned because `using`
    ## does not allow extending a function without qualifying it.
    for m in eachmatch(r"^\s*import\s+(?:Base|Main)[\w.]*\s*:([^\n#]*)"m, code_block)
        for item in split(m.captures[1], ',')
            name = strip(item)
            isempty(name) || push!(imported, name)
        end
    end
    ## check Base/Main method overrides
    ## prefix match (not substring) so that e.g. `MyBase.foo` is not reported
    overridden_methods = filter(funcs) do f
        startswith(f, "Base.") || startswith(f, "Main.") || f in imported
    end
    detected = !isempty(overridden_methods)
    return detected, overridden_methods
end

function extract_testset_name(testset_str::AbstractString)
# Define a regex pattern to match the function name
pattern = r"^\s*@testset\s*\"([^\"]+)\"\s* begin"ms
Expand Down
9 changes: 9 additions & 0 deletions test/code_eval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ end
@test cb.error isa Exception
@test occursin("Safety Error", cb.error.msg)
@test occursin("Use of package manager ", cb.error.msg)
## Base / Main overrides
cb = AICode(; code = """
import Base.splitx
splitx(aaa) = 2
""")
@test_logs (:warn,
r"Safety Warning: Base / Main overrides detected \(functions: splitx\)") match_mode=:any eval!(cb;
safe_eval = true)

# Evaluate inside a gensym'd module
cb = AICode(; code = "a=1") |> eval!
Expand Down
68 changes: 66 additions & 2 deletions test/code_parsing.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using PromptingTools: extract_julia_imports
using PromptingTools: detect_pkg_operation,
detect_missing_packages, extract_function_name, remove_unsafe_lines
detect_missing_packages, extract_function_name, extract_function_names,
remove_unsafe_lines, detect_base_main_overrides
using PromptingTools: has_julia_prompt,
remove_julia_prompt, extract_code_blocks, extract_code_blocks_fallback, eval!
using PromptingTools: escape_interpolation, find_subsequence_positions
Expand Down Expand Up @@ -42,6 +43,8 @@ end
Symbol.(["PackageA.PackageB", "PackageC"])
@test extract_julia_imports("using Base.Threads\nusing Main.MyPkg") ==
Symbol[]
@test extract_julia_imports("using Base.Threads\nusing Main.MyPkg";
base_or_main = true) == Symbol[Symbol("Base.Threads"), Symbol("Main.MyPkg")]
end

@testset "detect_missing_packages" begin
Expand Down Expand Up @@ -295,9 +298,9 @@ end
@testset "extract_function_name" begin
# Test 1: Test an explicit function declaration
@test extract_function_name("function testFunction1()\nend") == "testFunction1"

# Test 2: Test a concise function declaration
@test extract_function_name("testFunction2() = 42") == "testFunction2"
@test extract_function_name(" test_Function_2() = 42") == "test_Function_2"

# Test 3: Test a code block with no function
@test extract_function_name("let a = 10\nb = 20\nend") === nothing
Expand All @@ -321,6 +324,67 @@ end
""") == "firstFunction"
end

@testset "extract_function_names" begin
    # One explicit and one concise declaration in the same block
    mixed_defs = """
    function add(x, y)
    return x + y
    end
    subtract(x, y) = x - y
    """
    @test extract_function_names(mixed_defs) == ["add", "subtract"]

    # Module-qualified and unqualified concise declarations are both returned as written
    qualified_defs = """
    import Base.splitx
    Base.splitx()=1
    splitx(aaa) = 2
    """
    @test extract_function_names(qualified_defs) == ["Base.splitx", "splitx"]

    # Empty input yields an empty vector of strings
    @test extract_function_names("") == String[]
end

@testset "detect_base_main_overrides" begin
    # Unqualified user function without any Base/Main import: nothing is reported
    plain_code = """
    function foo()
    println("Hello, World!")
    end
    """
    @test detect_base_main_overrides(plain_code) == (false, [])

    # Names qualified with `Base.` / `Main.` are reported as written
    qualified_code = """
    function Base.bar()
    println("Override Base.bar()")
    end
    function Main.baz()
    println("Override Main.baz()")
    end
    """
    @test detect_base_main_overrides(qualified_code) == (true, ["Base.bar", "Main.baz"])

    # A `using Base: ...` line is present, but only the qualified definition is reported
    using_code = """
    using Base: sin
    function Main.qux()
    println("Override Main.qux()")
    end
    """
    @test detect_base_main_overrides(using_code) == (true, ["Main.qux"])

    # `import Base.name` followed by an unqualified definition of `name` is reported
    imported_code = """
    import Base.splitx
    splitx(aaa) = 2
    """
    @test detect_base_main_overrides(imported_code) == (true, ["splitx"])
end

@testset "extract_testset_name" begin
@test extract_testset_name("@testset \"TestSet1\" begin") == "TestSet1"
testset_str = """
Expand Down

0 comments on commit 935c8fe

Please sign in to comment.