-
-
Notifications
You must be signed in to change notification settings - Fork 116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: add Enzyme support for fastpow #1073
Conversation
Straightforward since fastpow is simply ^. Still needs: - [ ] Tests - [ ] Generalize to batchduplicated
@@ -37,6 +37,6 @@ end | |||
Ty in (Active,) | |||
x = 2.0 | |||
y = 3.0 | |||
test_reverse(fastpow, RT, (x, Tx), (y, Ty)) | |||
test_reverse(fastpow, RT, (x, Tx), (y, Ty), atol=0.001, rtol=0.001) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it should be like, atol=1e-10
test_forward(fastpow, RT, (x, Tx), (y, Ty), atol=0.005, rtol=0.005) | ||
end | ||
end | ||
|
||
@testset "Fast pow - Enzyme reverse rule" begin | ||
@testset for RT in (Active,), | ||
Tx in (Active,), | ||
Ty in (Active,) | ||
x = 2.0 | ||
y = 3.0 | ||
test_reverse(fastpow, RT, (x, Tx), (y, Ty), atol=0.001, rtol=0.001) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test_forward(fastpow, RT, (x, Tx), (y, Ty), atol=0.005, rtol=0.005) | |
end | |
end | |
@testset "Fast pow - Enzyme reverse rule" begin | |
@testset for RT in (Active,), | |
Tx in (Active,), | |
Ty in (Active,) | |
x = 2.0 | |
y = 3.0 | |
test_reverse(fastpow, RT, (x, Tx), (y, Ty), atol=0.001, rtol=0.001) | |
test_forward(fastpow, RT, (x, Tx), (y, Ty), atol=1e-10) | |
end | |
end | |
@testset "Fast pow - Enzyme reverse rule" begin | |
@testset for RT in (Active,), | |
Tx in (Active,), | |
Ty in (Active,) | |
x = 2.0 | |
y = 3.0 | |
test_reverse(fastpow, RT, (x, Tx), (y, Ty), atol=1e-10) |
if this is just pow you can probably do somethng like https://github.com/EnzymeAD/Enzyme.jl/blob/44febc52cbc7b154900cc5afd846e658d483e931/ext/EnzymeLogExpFunctionsExt.jl#L7 e.g.
whihc will auto register all the rules and other optimizations by marking it pow-like |
So this would replace all of the custom rules? Or in addition? |
It would replace all of these rules
…On Mon, Aug 26, 2024 at 9:45 AM Matt Bossart ***@***.***> wrote:
whihc will auto register all the rules and other optimizations by marking
it pow-like
So this would replace all of the custom rules? Or in addition?
—
Reply to this email directly, view it on GitHub
<#1073 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAJTUXHIWGAQCMEBUPYEY7DZTM5RPAVCNFSM6AAAAABNB6GCPWVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDGMJQGM4TIMZVGM>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
Oh I didn't know that feature existed. Yes, it's just a different precision pow, so if this just replaced pow with fastpow everywhere to construct the rule then that is exactly what we need |
When I try this Julia is crashing when I run autodiff. MWE:
|
How does it crash? |
Assertion failed: lhs_ty == rhs_ty, file /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp, line 5045 [20916] signal (22): SIGABRT |
Does fastpow take different types inputs?
Also to be clear this registration really uses enzyme internals so there be
dragons if applying
…On Mon, Aug 26, 2024 at 10:37 AM Matt Bossart ***@***.***> wrote:
Assertion failed: lhs_ty == rhs_ty, file
/workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp, line 5045
[20916] signal (22): SIGABRT
in expression starting at C:\Users\Matt Bossart\OneDrive -
UCB-O365\Desktop\DiffEqBase.jl\test\mwe.jl:7
crt_sig_handler at C:/workdir/src\signals-win.c:95
raise at C:\WINDOWS\System32\msvcrt.dll (unknown line)
abort at C:\WINDOWS\System32\msvcrt.dll (unknown line)
assert at C:\WINDOWS\System32\msvcrt.dll (unknown line)
recursiveFAdd at
/workspace/srcdir/Enzyme/enzyme/Enzyme\GradientUtils.cpp:5045
handleAdjointForIntrinsic at
/workspace/srcdir/Enzyme/build/Enzyme\IntrinsicDerivatives.inc:3240
handleKnownCallDerivatives at
/workspace/srcdir/Enzyme/enzyme/Enzyme\CallDerivatives.cpp:2866
visitCallInst at
/workspace/srcdir/Enzyme/enzyme/Enzyme\AdjointGenerator.h:6336
visit at
/opt/x86_64-w64-mingw32/x86_64-w64-mingw32/sys-root/usr/local/include/llvm/IR\InstVisitor.h:111
[inlined]
CreateForwardDiff at
/workspace/srcdir/Enzyme/enzyme/Enzyme\EnzymeLogic.cpp:4941
EnzymeCreateForwardDiff at
/workspace/srcdir/Enzyme/enzyme/Enzyme\CApi.cpp:602
EnzymeCreateForwardDiff at C:\Users\Matt
Bossart.julia\packages\Enzyme\XGb4o\src\api.jl:177
unknown function (ip: 0000022e1e842669)
enzyme! at C:\Users\Matt
Bossart.julia\packages\Enzyme\XGb4o\src\compiler.jl:4064
unknown function (ip: 0000022e1e840e7b)
#codegen#18952 at C:\Users\Matt
Bossart.julia\packages\Enzyme\XGb4o\src\compiler.jl:6302
codegen at C:\Users\Matt
Bossart.julia\packages\Enzyme\XGb4o\src\compiler.jl:5493 [inlined]
_thunk at C:\Users\Matt
Bossart.julia\packages\Enzyme\XGb4o\src\compiler.jl:7103
_thunk at C:\Users\Matt
Bossart.julia\packages\Enzyme\XGb4o\src\compiler.jl:7103 [inlined]
cached_compilation at C:\Users\Matt
Bossart.julia\packages\Enzyme\XGb4o\src\compiler.jl:7144 [inlined]
thunkbase at C:\Users\Matt
Bossart.julia\packages\Enzyme\XGb4o\src\compiler.jl:7217
unknown function (ip: 0000022e1e85c10f)
#s2048#18999 at C:\Users\Matt
Bossart.julia\packages\Enzyme\XGb4o\src\compiler.jl:7269 [inlined]
#s2048#18999 at .\none:0
GeneratedFunctionStub at .\boot.jl:602
jl_call_staged at C:/workdir/src\method.c:540
ijl_code_for_staged at C:/workdir/src\method.c:593
get_staged at .\compiler\utilities.jl:123
retrieve_code_info at .\compiler\utilities.jl:135 [inlined]
InferenceState at .\compiler\inferencestate.jl:430
typeinf_edge at .\compiler\typeinfer.jl:920
abstract_call_method at .\compiler\abstractinterpretation.jl:629
abstract_call_gf_by_type at .\compiler\abstractinterpretation.jl:95
abstract_call_known at .\compiler\abstractinterpretation.jl:2087
abstract_call at .\compiler\abstractinterpretation.jl:2169
abstract_call at .\compiler\abstractinterpretation.jl:2162
abstract_call at .\compiler\abstractinterpretation.jl:2354
abstract_eval_call at .\compiler\abstractinterpretation.jl:2370
abstract_eval_statement_expr at .\compiler\abstractinterpretation.jl:2380
abstract_eval_statement at .\compiler\abstractinterpretation.jl:2624
abstract_eval_basic_statement at .\compiler\abstractinterpretation.jl:2889
typeinf_local at .\compiler\abstractinterpretation.jl:3098
typeinf_nocycle at .\compiler\abstractinterpretation.jl:3186
_typeinf at .\compiler\typeinfer.jl:247
typeinf at .\compiler\typeinfer.jl:216
typeinf_edge at .\compiler\typeinfer.jl:930
abstract_call_method at .\compiler\abstractinterpretation.jl:629
abstract_call_gf_by_type at .\compiler\abstractinterpretation.jl:95
abstract_call_known at .\compiler\abstractinterpretation.jl:2087
abstract_call at .\compiler\abstractinterpretation.jl:2169
abstract_apply at .\compiler\abstractinterpretation.jl:1612
abstract_call_known at .\compiler\abstractinterpretation.jl:2004
abstract_call at .\compiler\abstractinterpretation.jl:2169
abstract_call at .\compiler\abstractinterpretation.jl:2162
abstract_call at .\compiler\abstractinterpretation.jl:2354
abstract_eval_call at .\compiler\abstractinterpretation.jl:2370
abstract_eval_statement_expr at .\compiler\abstractinterpretation.jl:2380
abstract_eval_statement at .\compiler\abstractinterpretation.jl:2624
abstract_eval_basic_statement at .\compiler\abstractinterpretation.jl:2913
typeinf_local at .\compiler\abstractinterpretation.jl:3098
typeinf_nocycle at .\compiler\abstractinterpretation.jl:3186
_typeinf at .\compiler\typeinfer.jl:247
typeinf at .\compiler\typeinfer.jl:216
typeinf_edge at .\compiler\typeinfer.jl:930
abstract_call_method at .\compiler\abstractinterpretation.jl:629
abstract_call_gf_by_type at .\compiler\abstractinterpretation.jl:95
abstract_call_known at .\compiler\abstractinterpretation.jl:2087
abstract_call at .\compiler\abstractinterpretation.jl:2169
abstract_apply at .\compiler\abstractinterpretation.jl:1612
abstract_call_known at .\compiler\abstractinterpretation.jl:2004
abstract_call at .\compiler\abstractinterpretation.jl:2169
abstract_call at .\compiler\abstractinterpretation.jl:2162
abstract_call at .\compiler\abstractinterpretation.jl:2354
abstract_eval_call at .\compiler\abstractinterpretation.jl:2370
abstract_eval_statement_expr at .\compiler\abstractinterpretation.jl:2380
abstract_eval_statement at .\compiler\abstractinterpretation.jl:2624
abstract_eval_basic_statement at .\compiler\abstractinterpretation.jl:2913
typeinf_local at .\compiler\abstractinterpretation.jl:3098
typeinf_nocycle at .\compiler\abstractinterpretation.jl:3186
_typeinf at .\compiler\typeinfer.jl:247
typeinf at .\compiler\typeinfer.jl:216
typeinf_ext at .\compiler\typeinfer.jl:1051
typeinf_ext_toplevel at .\compiler\typeinfer.jl:1082
typeinf_ext_toplevel at .\compiler\typeinfer.jl:1078
jfptr_typeinf_ext_toplevel_38981.1 at C:\Users\Matt
Bossart.julia\juliaup\julia-1.10.4+0.x64.w64.mingw32\lib\julia\sys.dll
(unknown line)
_jl_invoke at C:/workdir/src\gf.c:2895 [inlined]
ijl_apply_generic at C:/workdir/src\gf.c:3077 [inlined]
jl_apply at C:/workdir/src\julia.h:1982 [inlined]
jl_type_infer at C:/workdir/src\gf.c:394
jl_generate_fptr_impl at C:/workdir/src\jitlayers.cpp:504
jl_compile_method_internal at C:/workdir/src\gf.c:2481
jl_compile_method_internal at C:/workdir/src\gf.c:2372 [inlined]
_jl_invoke at C:/workdir/src\gf.c:2887 [inlined]
ijl_apply_generic at C:/workdir/src\gf.c:3077
jl_apply at C:/workdir/src\julia.h:1982 [inlined]
do_call at C:/workdir/src\interpreter.c:126
eval_value at C:/workdir/src\interpreter.c:223
eval_stmt_value at C:/workdir/src\interpreter.c:174 [inlined]
eval_body at C:/workdir/src\interpreter.c:635
jl_interpret_toplevel_thunk at C:/workdir/src\interpreter.c:775
jl_toplevel_eval_flex at C:/workdir/src\toplevel.c:934
jl_toplevel_eval_flex at C:/workdir/src\toplevel.c:877
ijl_toplevel_eval at C:/workdir/src\toplevel.c:943 [inlined]
ijl_toplevel_eval_in at C:/workdir/src\toplevel.c:985
eval at .\boot.jl:385 [inlined]
include_string at .\loading.jl:2076
_include at .\loading.jl:2136
include at .\client.jl:489
unknown function (ip: 0000022e1e773d3b)
jl_apply at C:/workdir/src\julia.h:1982 [inlined]
do_call at C:/workdir/src\interpreter.c:126
eval_value at C:/workdir/src\interpreter.c:223
eval_stmt_value at C:/workdir/src\interpreter.c:174 [inlined]
eval_body at C:/workdir/src\interpreter.c:635
jl_interpret_toplevel_thunk at C:/workdir/src\interpreter.c:775
jl_toplevel_eval_flex at C:/workdir/src\toplevel.c:934
jl_toplevel_eval_flex at C:/workdir/src\toplevel.c:877
ijl_toplevel_eval at C:/workdir/src\toplevel.c:943 [inlined]
ijl_toplevel_eval_in at C:/workdir/src\toplevel.c:985
eval at .\boot.jl:385 [inlined]
eval_user_input at
C:\workdir\usr\share\julia\stdlib\v1.10\REPL\src\REPL.jl:150
repl_backend_loop at
C:\workdir\usr\share\julia\stdlib\v1.10\REPL\src\REPL.jl:246
#start_repl_backend#46 at
C:\workdir\usr\share\julia\stdlib\v1.10\REPL\src\REPL.jl:231
start_repl_backend at
C:\workdir\usr\share\julia\stdlib\v1.10\REPL\src\REPL.jl:228
#run_repl#59 at
C:\workdir\usr\share\julia\stdlib\v1.10\REPL\src\REPL.jl:389
run_repl at C:\workdir\usr\share\julia\stdlib\v1.10\REPL\src\REPL.jl:375
jfptr_run_repl_95791.1 at C:\Users\Matt
Bossart.julia\juliaup\julia-1.10.4+0.x64.w64.mingw32\lib\julia\sys.dll
(unknown line)
#1013 <#1013> at
.\client.jl:432
jfptr_YY.1013_86566.1 at C:\Users\Matt
Bossart.julia\juliaup\julia-1.10.4+0.x64.w64.mingw32\lib\julia\sys.dll
(unknown line)
jl_apply at C:/workdir/src\julia.h:1982 [inlined]
jl_f__call_latest at C:/workdir/src\builtins.c:812
#invokelatest#2 at .\essentials.jl:892 [inlined]
invokelatest at .\essentials.jl:889 [inlined]
run_main_repl at .\client.jl:416
exec_options at .\client.jl:333
_start at .\client.jl:552
jfptr__start_86591.1 at C:\Users\Matt
Bossart.julia\juliaup\julia-1.10.4+0.x64.w64.mingw32\lib\julia\sys.dll
(unknown line)
jl_apply at C:/workdir/src\julia.h:1982 [inlined]
true_main at C:/workdir/src\jlapi.c:582
jl_repl_entrypoint at C:/workdir/src\jlapi.c:731
mainCRTStartup at C:/workdir/cli\loader_exe.c:58
BaseThreadInitThunk at C:\WINDOWS\System32\KERNEL32.DLL (unknown line)
RtlUserThreadStart at C:\WINDOWS\SYSTEM32\ntdll.dll (unknown line)
Allocations: 18575851 (Pool: 18539513; Big: 36338); GC: 25
—
Reply to this email directly, view it on GitHub
<#1073 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAJTUXCISI2YLN2QYMMWY5LZTNDVBAVCNFSM6AAAAABNB6GCPWVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDGMJQGUYDIOBSGA>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
fastpow takes real inputs and returns Float32: Line 100 in 13ac2da
|
Is there a way to add a "if it's in an enzyme context, then convert to |
Completed in #1072 |
Still needs: