diff --git a/Project.toml b/Project.toml
index 71f0ca822..49207f55b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -35,6 +35,7 @@ BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
 IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
@@ -48,6 +49,7 @@ LinearSolveBandedMatricesExt = "BandedMatrices"
 LinearSolveBlockDiagonalsExt = "BlockDiagonals"
 LinearSolveCUDAExt = "CUDA"
 LinearSolveEnzymeExt = "Enzyme"
+LinearSolveForwardDiff = "ForwardDiff"
 LinearSolveHYPREExt = "HYPRE"
 LinearSolveIterativeSolversExt = "IterativeSolvers"
 LinearSolveKernelAbstractionsExt = "KernelAbstractions"
@@ -66,6 +68,7 @@ DocStringExtensions = "0.9"
 EnumX = "1"
 EnzymeCore = "0.6"
 FastLapackInterface = "2"
+ForwardDiff = "0.10"
 GPUArraysCore = "0.1"
 HYPRE = "1.4.0"
 InteractiveUtils = "1.6"
diff --git a/ext/LinearSolveForwardDiff.jl b/ext/LinearSolveForwardDiff.jl
new file mode 100644
index 000000000..4b386889c
--- /dev/null
+++ b/ext/LinearSolveForwardDiff.jl
@@ -0,0 +1,94 @@
+module LinearSolveForwardDiff
+
+using LinearSolve
+using InteractiveUtils
+isdefined(Base, :get_extension) ?
+    (import ForwardDiff; using ForwardDiff: Dual) :
+    (import ..ForwardDiff; using ..ForwardDiff: Dual)
+
+# Split an array of `Dual`s carrying `P` partials into `P` arrays of the same shape;
+# the `i`-th array holds the `i`-th partial of every entry.
+_partial_arrays(X, P) = [ForwardDiff.partials.(X, i) for i in 1:P]
+
+# Solve `A x = b` on the primal values, then solve `A dx = db - dA * x` once per
+# partial direction and reassemble the solution as `Dual{T}` numbers.
+#
+# `dAs`/`dbs` hold one array per partial direction, `A`/`b` are the primal values,
+# and `T` is the tag of the incoming `Dual`s. `kwargs` are forwarded to the inner
+# `LinearSolve.solve!` calls.
+function _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
+    # `eltype` returns a type, so subtype checks (`<:`) are required here;
+    # `isa Dual` applied to a type is always false.
+    @assert !(eltype(first(dAs)) <: Dual)
+    @assert !(eltype(first(dbs)) <: Dual)
+    @assert !(eltype(A) <: Dual)
+    @assert !(eltype(b) <: Dual)
+    reltol = cache.reltol isa Dual ? ForwardDiff.value(cache.reltol) : cache.reltol
+    abstol = cache.abstol isa Dual ? ForwardDiff.value(cache.abstol) : cache.abstol
+    u = eltype(cache.u) <: Dual ? ForwardDiff.value.(cache.u) : cache.u
+    # NOTE(review): assumes the cached factorization has LU-style
+    # `factors`/`ipiv`/`info` fields -- confirm for non-LU algorithms.
+    cacheval = cache.cacheval isa Tuple ? cache.cacheval[1] : cache.cacheval
+    cacheval = eltype(cacheval.factors) <: Dual ? begin
+        LinearSolve.LinearAlgebra.LU(ForwardDiff.value.(cacheval.factors), cacheval.ipiv, cacheval.info)
+    end : cacheval
+    cacheval = cache.cacheval isa Tuple ? (cacheval, cache.cacheval[2]) : cacheval
+
+    cache2 = remake(cache; A, b, u, reltol, abstol, cacheval)
+    # Keyword arguments are splatted after `;` so they stay keywords instead of
+    # becoming positional `Pair`s. The result is copied before `cache2` is reused.
+    res = LinearSolve.solve!(cache2, alg; kwargs...) |> deepcopy
+    dresus = reduce(hcat, map(dAs, dbs) do dA, db
+        cache2.b = db - dA * res.u
+        dres = LinearSolve.solve!(cache2, alg; kwargs...)
+        deepcopy(dres.u)
+    end)
+    d = Dual{T}.(res.u, Tuple.(eachrow(dresus)))
+    LinearSolve.SciMLBase.build_linear_solution(alg, d, nothing, cache; retcode=res.retcode, iters=res.iters, stats=res.stats)
+end
+
+
+# Generate the `Dual`-aware `solve!` methods for every subtype that `subtypes`
+# reports for `AbstractFactorization`.
+for ALG in subtypes(LinearSolve, LinearSolve.AbstractFactorization)
+    @eval begin
+        # Only `A` carries `Dual`s; `b` contributes zero partials.
+        function LinearSolve.solve!(
+            cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}}, B},
+            alg::$ALG;
+            kwargs...
+        ) where {T, V, P, B}
+            dAs = _partial_arrays(cache.A, P)
+            dbs = [zero(cache.b) for _=1:P]
+            A = ForwardDiff.value.(cache.A)
+            b = cache.b
+            _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
+        end
+        # Only `b` carries `Dual`s; `A` contributes zero partials.
+        function LinearSolve.solve!(
+            cache::LinearSolve.LinearCache{A_,<:AbstractArray{<:Dual{T,V,P}}},
+            alg::$ALG;
+            kwargs...
+        ) where {T, V, P, A_}
+            dAs = [zero(cache.A) for _=1:P]
+            dbs = _partial_arrays(cache.b, P)
+            A = cache.A
+            b = ForwardDiff.value.(cache.b)
+            _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
+        end
+        # Both `A` and `b` carry `Dual`s with the same tag and partial count.
+        function LinearSolve.solve!(
+            cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:Dual{T,V,P}}},
+            alg::$ALG;
+            kwargs...
+        ) where {T, V, P}
+            dAs = _partial_arrays(cache.A, P)
+            dbs = _partial_arrays(cache.b, P)
+            A = ForwardDiff.value.(cache.A)
+            b = ForwardDiff.value.(cache.b)
+            _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
+        end
+    end
+end
+
+end # module
\ No newline at end of file
diff --git a/src/common.jl b/src/common.jl
index 791ab91c8..dc8b748b3 100644
--- a/src/common.jl
+++ b/src/common.jl
@@ -82,6 +82,18 @@ mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
     assumptions::OperatorAssumptions{issq}
 end
 
+# Build a new `LinearCache` from `cache`, replacing any field passed as a keyword.
+# The type parameters are taken from the (possibly replaced) values, so fields such
+# as `A`, `b`, `u` and `cacheval` may change type; `abstol`/`reltol` must share one.
+function SciMLBase.remake(cache::LinearCache;
+    A::TA=cache.A, b::TB=cache.b, u::TU=cache.u, p::TP=cache.p, alg::Talg=cache.alg,
+    cacheval::Tc=cache.cacheval, isfresh::Bool=cache.isfresh, Pl::Tl=cache.Pl, Pr::Tr=cache.Pr,
+    abstol::Ttol=cache.abstol, reltol::Ttol=cache.reltol, maxiters::Int=cache.maxiters,
+    verbose::Bool=cache.verbose, assumptions::OperatorAssumptions{issq}=cache.assumptions) where {TA, TB, TU, TP, Talg, Tc, Tl, Tr, Ttol, issq}
+    LinearCache{TA, TB, TU, TP, Talg, Tc, Tl, Tr, Ttol, issq}(A,b,u,p,alg,cacheval,isfresh,Pl,Pr,abstol,reltol,
+        maxiters,verbose,assumptions)
+end
+
 function Base.setproperty!(cache::LinearCache, name::Symbol, x)
     if name === :A
         setfield!(cache, :isfresh, true)
diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl
new file mode 100644
index 000000000..61568d262
--- /dev/null
+++ b/test/forwarddiff.jl
@@ -0,0 +1,74 @@
+using Test
+using ForwardDiff
+using LinearSolve
+using FiniteDiff
+using Enzyme
+using Random
+Random.seed!(1234)
+
+n = 4
+A = rand(n, n);
+dA = zeros(n, n);
+b1 = rand(n);
+for alg in (
+    LUFactorization(),
+    RFLUFactorization(),
+    # KrylovJL_GMRES(), dispatch fails
+    )
+    alg_str = string(alg)
+    @show alg_str
+    function fb(b)
+        prob = LinearProblem(A, b)
+
+        sol1 = solve(prob, alg)
+
+        sum(sol1.u)
+    end
+    fb(b1)
+
+    fid_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
+    @show fid_jac
+
+    fod_jac = ForwardDiff.gradient(fb, b1) |> vec
+    @show fod_jac
+
+    @test fod_jac ≈ fid_jac rtol=1e-6
+
+    function fA(A)
+        prob = LinearProblem(A, b1)
+
+        sol1 = solve(prob, alg)
+
+        sum(sol1.u)
+    end
+    fA(A)
+
+    fid_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
+    @show fid_jac
+
+    fod_jac = ForwardDiff.gradient(fA, A) |> vec
+    @show fod_jac
+
+    @test fod_jac ≈ fid_jac rtol=1e-6
+
+
+    function fAb(Ab)
+        A = Ab[:, 1:n]
+        b1 = Ab[:, n+1]
+        prob = LinearProblem(A, b1)
+
+        sol1 = solve(prob, alg)
+
+        sum(sol1.u)
+    end
+    fAb(hcat(A, b1))
+
+    fid_jac = FiniteDiff.finite_difference_jacobian(fAb, hcat(A, b1)) |> vec
+    @show fid_jac
+
+    fod_jac = ForwardDiff.gradient(fAb, hcat(A, b1)) |> vec
+    @show fod_jac
+
+    @test fod_jac ≈ fid_jac rtol=1e-6
+
+end
\ No newline at end of file