From 8c7b8f19744bb0d8aec7a7490a647cd9f759767a Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 17 Aug 2024 11:34:12 -0400 Subject: [PATCH 1/5] WIP: add Enzyme support for fastpow Straightforward since fastpow is simply ^. Still needs: - [ ] Tests - [ ] Generalize to batchduplicated --- ext/DiffEqBaseEnzymeExt.jl | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/ext/DiffEqBaseEnzymeExt.jl b/ext/DiffEqBaseEnzymeExt.jl index 2b2b7c001..fc9d2041b 100644 --- a/ext/DiffEqBaseEnzymeExt.jl +++ b/ext/DiffEqBaseEnzymeExt.jl @@ -1,7 +1,7 @@ module DiffEqBaseEnzymeExt using DiffEqBase -import DiffEqBase: value +import DiffEqBase: value, fastpow using Enzyme import Enzyme: Const using ChainRulesCore @@ -53,4 +53,38 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1}, return ntuple(_ -> nothing, Val(length(args) + 4)) end +function EnzymeRules.forward(func::Const{typeof(fastpow)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated, + BatchDuplicated,BatchDuplicatedNoNeed}}, + _x::Annotation, _y::Annotation) + x = _x.val + y = _y.val + ret = func.val(x.val, y.val) + dxval = x.dval * y * (fastpow(x,y - 1)) + dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x) + return Duplicated(ret, dxval + dyval) +end + +function EnzymeRules.augmented_primal(config::ConfigWidth{1}, + func::Const{typeof(fastpow)}, + ::Type{<:Active}, + x::Active, x::Active) + if EnzymeRules.needs_primal(config) + primal = func.val(x.val, y.val) + else + primal = nothing + end + return EnzymeRules.AugmentedReturn(primal, nothing, nothing) +end + +function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1}, + func::Const{DiffEqBase.fastpow}, dret, tape::Nothing, + _x, _y) + x = _x.val + y = _y.val + dxval = x.dval * y * (fastpow(x,y - 1)) + dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x) + return (dxval, dyval) +end + end From 7c67680d91f5941ec25c20ef4abe0f2e2b3155cd Mon Sep 17 00:00:00 2001 From: Matt Bossart Date: Sat, 24 Aug 2024 16:15:54 -0600 Subject: [PATCH 2/5] tests for enzyme fastpow rule --- test/fastpow.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test/fastpow.jl b/test/fastpow.jl index 96b588aa1..fa8918ae0 100644 --- a/test/fastpow.jl +++ b/test/fastpow.jl @@ -1,4 +1,5 @@ using DiffEqBase: fastlog2, _exp2, fastpow +using Enzyme, EnzymeTestUtils using Test @testset "Fast log2" begin @@ -19,3 +20,23 @@ end errors = [abs(^(x, y) - fastpow(x, y)) for x in 0.001:0.001:1, y in 0.08:0.001:0.5] @test maximum(errors) < 1e-4 end + +@testset "Fast pow - Enzyme forward rule" begin + @testset for RT in (Duplicated, DuplicatedNoNeed), + Tx in (Const, Duplicated), + Ty in (Const, Duplicated) + x = 3.0 + y = 2.0 + test_forward(fastpow, RT, (x, Tx), (y, Ty), atol=0.1, rtol=0.1) + end +end + +@testset "Fast pow - Enzyme reverse rule" begin + @testset for RT in (Active,), + Tx in (Active,), + Ty in (Active,) + x = 2.0 + y = 3.0 + test_reverse(fastpow, RT, (x, Tx), (y, Ty)) + end +end \ No newline at end of file From 530d113f69e6f8b69865607e2951f798e8c908eb Mon Sep 17 00:00:00 2001 From: Matt Bossart Date: Sat, 24 Aug 2024 16:16:48 -0600 Subject: [PATCH 3/5] wip: rules are hit --- ext/DiffEqBaseEnzymeExt.jl | 44 ++++++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/ext/DiffEqBaseEnzymeExt.jl b/ext/DiffEqBaseEnzymeExt.jl index fc9d2041b..43e6030c3 100644 --- a/ext/DiffEqBaseEnzymeExt.jl +++ b/ext/DiffEqBaseEnzymeExt.jl @@ -53,22 +53,31 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1}, return ntuple(_ -> nothing, Val(length(args) + 4)) end -function EnzymeRules.forward(func::Const{typeof(fastpow)}, - RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated, - BatchDuplicated,BatchDuplicatedNoNeed}}, - _x::Annotation, _y::Annotation) +function Enzyme.EnzymeRules.forward(func::Const{typeof(DiffEqBase.fastpow)}, + RT::Type{<:Union{Duplicated, DuplicatedNoNeed}}, + _x::Union{Const, Duplicated}, _y::Union{Const, Duplicated}) x = _x.val y = _y.val - ret = func.val(x.val, y.val) - dxval = x.dval * y * (fastpow(x,y - 1)) - dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x) - return Duplicated(ret, dxval + dyval) + ret = func.val(x, y) + if !(_x isa Const) + dxval = _x.dval * y * (fastpow(x,y - 1)) + else + dxval = make_zero(_x.val) + end + if !(_y isa Const) + dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : _y.dval*(fastpow(x,y))*log(x) + else + dyval = make_zero(_y.val) + end + if RT <: DuplicatedNoNeed + return Float32(dxval + dyval) + else + return Duplicated(ret, Float32(dxval + dyval)) + end end -function EnzymeRules.augmented_primal(config::ConfigWidth{1}, - func::Const{typeof(fastpow)}, - ::Type{<:Active}, - x::Active, x::Active) +function EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.ConfigWidth{1}, + func::Const{typeof(fastpow)}, ::Type{<:Active}, x::Active, y::Active) if EnzymeRules.needs_primal(config) primal = func.val(x.val, y.val) else @@ -77,14 +86,13 @@ function EnzymeRules.augmented_primal(config::ConfigWidth{1}, return EnzymeRules.AugmentedReturn(primal, nothing, nothing) end -function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1}, - func::Const{DiffEqBase.fastpow}, dret, tape::Nothing, - _x, _y) +function EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1}, + func::Const{typeof(DiffEqBase.fastpow)}, dret::Active, tape, _x::Active, _y::Active) x = _x.val y = _y.val - dxval = x.dval * y * (fastpow(x,y - 1)) - dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x) + dxval = y * (fastpow(x,y - 1)) + dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : (fastpow(x,y))*log(x) return (dxval, dyval) end -end +end \ No newline at end of file From 03478e29b9982e18fe9075bf65fde986bf06621e Mon Sep 17 00:00:00 2001 From: Matt Bossart Date: Sun, 25 Aug 2024 18:23:43 -0600 Subject: [PATCH 4/5] multiply by dret --- ext/DiffEqBaseEnzymeExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/DiffEqBaseEnzymeExt.jl b/ext/DiffEqBaseEnzymeExt.jl index 43e6030c3..50ea4f9f4 100644 --- a/ext/DiffEqBaseEnzymeExt.jl +++ b/ext/DiffEqBaseEnzymeExt.jl @@ -90,8 +90,8 @@ function EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.fastpow)}, dret::Active, tape, _x::Active, _y::Active) x = _x.val y = _y.val - dxval = y * (fastpow(x,y - 1)) - dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : (fastpow(x,y))*log(x) + dxval = dret.val * y * (fastpow(x,y - 1)) + dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : dret.val * (fastpow(x,y))*log(x) return (dxval, dyval) end From 61fc1b109198f6308d3d862a99882a2f65225f16 Mon Sep 17 00:00:00 2001 From: Matt Bossart Date: Sun, 25 Aug 2024 18:24:01 -0600 Subject: [PATCH 5/5] tighten tolerances (test passes) --- test/fastpow.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/fastpow.jl b/test/fastpow.jl index fa8918ae0..360fc89b0 100644 --- a/test/fastpow.jl +++ b/test/fastpow.jl @@ -27,7 +27,7 @@ end Ty in (Const, Duplicated) x = 3.0 y = 2.0 - test_forward(fastpow, RT, (x, Tx), (y, Ty), atol=0.1, rtol=0.1) + test_forward(fastpow, RT, (x, Tx), (y, Ty), atol=0.005, rtol=0.005) end end @@ -37,6 +37,6 @@ end Ty in (Active,) x = 2.0 y = 3.0 - test_reverse(fastpow, RT, (x, Tx), (y, Ty)) + test_reverse(fastpow, RT, (x, Tx), (y, Ty), atol=0.001, rtol=0.001) end end \ No newline at end of file