Skip to content

Commit

Permalink
MutableArithmetics for IPM/HSD
Browse files Browse the repository at this point in the history
Use the `BigFloat` dot product from MutableArithmetics in HSD code.

Helps with the performance of the `BigFloat` arithmetic. The change
shouldn't affect other arithmetics, and it's coded so it'd be easy to
extend it to another mutable arithmetic apart from just `BigFloat`, if
necessary, and if such a type will support MutableArithmetics.

Apart from improving performance, this change could possibly also
benefit LP problems with numerical issues (when using `BigFloat`),
because the MA dot product uses a summation algorithm that's more
accurate than naive summation.

A performance experiment is presented in the commit message of the
following commit. The conclusion is that this commit improves
performance only by a tiny bit, likewise with allocation.
  • Loading branch information
nsajko committed Apr 20, 2023
1 parent 1021c8a commit 38a78c1
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 32 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
QPSReader = "10f199a5-22af-520b-b891-7ce84a7b1bd0"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand All @@ -27,6 +28,7 @@ Krylov = "0.8, 0.9"
LDLFactorizations = "0.8, 0.9, 0.10"
LinearOperators = "2.0"
MathOptInterface = "1"
MutableArithmetics = "1.2"
QPSReader = "0.2"
TimerOutputs = "0.5.6"
julia = "1.6"
55 changes: 43 additions & 12 deletions src/IPM/HSD/HSD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ mutable struct HSD{T, Tv, Tb, Ta, Tk} <: AbstractIPMOptimizer{T}

end

include("dot_for_mutable.jl")

include("step.jl")


Expand Down Expand Up @@ -101,13 +103,22 @@ function compute_residuals!(hsd::HSD{T}
mul!(res.rd, transpose(dat.A), pt.y, -one(T), one(T))
@. res.rd += pt.zu .* dat.uflag - pt.zl .* dat.lflag

dot_buf = buffer_for_dot_weighted_sum(T)

# Gap residual
# rg = c'x - (b'y + l'zl - u'zu) + k
res.rg = pt.κ + (dot(dat.c, pt.x) - (
dot(dat.b, pt.y)
+ dot(dat.l .* dat.lflag, pt.zl)
- dot(dat.u .* dat.uflag, pt.zu)
))
res.rg = pt.κ + buffered_dot_weighted_sum!!(
dot_buf,
(
(dat.c, pt.x),
(dat.b, pt.y),
(dat.l .* dat.lflag, pt.zl),
(dat.u .* dat.uflag, pt.zu),
),
(
1, -1, -1, 1,
),
)

# Residuals norm
res.rp_nrm = norm(res.rp, Inf)
Expand All @@ -117,11 +128,17 @@ function compute_residuals!(hsd::HSD{T}
res.rg_nrm = norm(res.rg, Inf)

# Compute primal and dual bounds
hsd.primal_objective = dot(dat.c, pt.x) / pt.τ + dat.c0
hsd.dual_objective = (
dot(dat.b, pt.y)
+ dot(dat.l .* dat.lflag, pt.zl)
- dot(dat.u .* dat.uflag, pt.zu)
hsd.primal_objective = buffered_dot_product!!(dot_buf.dot, dat.c, pt.x) / pt.τ + dat.c0
hsd.dual_objective = buffered_dot_weighted_sum!!(
dot_buf,
(
(dat.b, pt.y),
(dat.l .* dat.lflag, pt.zl),
(dat.u .* dat.uflag, pt.zu),
),
(
1, 1, -1,
),
) / pt.τ + dat.c0

return nothing
Expand Down Expand Up @@ -168,12 +185,15 @@ function update_solver_status!(hsd::HSD{T}, ϵp::T, ϵd::T, ϵg::T, ϵi::T) wher
return nothing
end

dot_buf = buffer_for_dot_weighted_sum(T)

# Check for infeasibility certificates
if max(
norm(dat.A * pt.x, Inf),
norm((pt.x .- pt.xl) .* dat.lflag, Inf),
norm((pt.x .+ pt.xu) .* dat.uflag, Inf)
) * (norm(dat.c, Inf) / max(1, norm(dat.b, Inf))) < - ϵi * dot(dat.c, pt.x)
) * (norm(dat.c, Inf) / max(1, norm(dat.b, Inf))) <
-ϵi * buffered_dot_product!!(dot_buf.dot, dat.c, pt.x)
# Dual infeasible, i.e., primal unbounded
hsd.primal_status = Sln_InfeasibilityCertificate
hsd.solver_status = Trm_DualInfeasible
Expand All @@ -185,7 +205,18 @@ function update_solver_status!(hsd::HSD{T}, ϵp::T, ϵd::T, ϵg::T, ϵi::T) wher
norm(dat.l .* dat.lflag, Inf),
norm(dat.u .* dat.uflag, Inf),
norm(dat.b, Inf)
) / (max(one(T), norm(dat.c, Inf))) < (dot(dat.b, pt.y) + dot(dat.l .* dat.lflag, pt.zl)- dot(dat.u .* dat.uflag, pt.zu)) * ϵi
) / (max(one(T), norm(dat.c, Inf))) < buffered_dot_weighted_sum!!(
dot_buf,
(
(dat.b, pt.y),
(dat.l .* dat.lflag, pt.zl),
(dat.u .* dat.uflag, pt.zu),
),
(
1, 1, -1,
),
) * ϵi

# Primal infeasible
hsd.dual_status = Sln_InfeasibilityCertificate
hsd.solver_status = Trm_PrimalInfeasible
Expand Down
102 changes: 102 additions & 0 deletions src/IPM/HSD/dot_for_mutable.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Right now this is just `BigFloat`, but in principle it could be expanded to a whitelist
# that would include other mutable types.
const SupportedMutableArithmetics = BigFloat

buffer_for_dot_product(::Type{V}) where {V <: AbstractVector{<:Real}} =
buffer_for(LinearAlgebra.dot, V, V)

buffer_for_dot_product(::Type{F}) where {F <: Real} =
buffer_for_dot_product(Vector{F})

buffered_dot_product_to!(
buf::B,
result::F,
x::V,
y::V,
) where {B <: Any, F <: SupportedMutableArithmetics, V <: AbstractVector{F}} =
buffered_operate_to!(buf, result, LinearAlgebra.dot, x, y)

function buffered_dot_product!!(
buf::B,
x::V,
y::V,
) where {B <: Any, F <: SupportedMutableArithmetics, V <: AbstractVector{F}}
ret = zero(F)
ret = buffered_dot_product_to!(buf, ret, x, y)
return ret
end

buffered_dot_product!!(::Nothing, x::V, y::V) where {F <: Real, V <: AbstractVector{F}} =
dot(x, y)

struct DotWeightedSumBuffer{F <: Real, DotBuffer <: Any}
tmp::F
dot::DotBuffer

function DotWeightedSumBuffer{F}() where {F <: Real}
dot_buffer = buffer_for_dot_product(F)
return new{F, typeof(dot_buffer)}(zero(F), dot_buffer)
end
end

struct DotWeightedSumBufferDummy
dot::Nothing

DotWeightedSumBufferDummy() = new(nothing)
end

buffer_for_dot_weighted_sum(::Type{F}) where {F <: SupportedMutableArithmetics} =
DotWeightedSumBuffer{F}()

buffer_for_dot_weighted_sum(::Type{F}) where {F <: Real} =
DotWeightedSumBufferDummy()

function buffered_dot_weighted_sum_to_inner!(
buf::DotWeightedSumBuffer{F},
sum::F,
vecs::NTuple{n, NTuple{2, <:AbstractVector{F}}},
weights::NTuple{n, <:Real},
) where {n, F <: SupportedMutableArithmetics}
sum = zero!!(sum)

for i in 1:n
weight = weights[i]
(x, y) = vecs[i]

buffered_dot_product_to!(buf.dot, buf.tmp, x, y)
mul!!(buf.tmp, weight)

sum = add!!(sum, buf.tmp)
end

return sum
end

buffered_dot_weighted_sum_to!(
buf::DotWeightedSumBuffer{F},
sum::F,
vecs::NTuple{n, NTuple{2, <:AbstractVector{F}}},
weights::NTuple{n, Int}) where {n, F <: SupportedMutableArithmetics} =
# It seems like the specialization
# *(x::BigFloat, c::Int8)
# could be more efficient than
# *(x::BigFloat, c::Int)
# MPFR has separate functions for those, and Julia uses them,
# there must be a good (performance) reason for that.
buffered_dot_weighted_sum_to_inner!(buf, sum, vecs, map(Int8, weights))

function buffered_dot_weighted_sum!!(
buf::DotWeightedSumBuffer{F},
vecs::NTuple{n, NTuple{2, <:AbstractVector{F}}},
weights::NTuple{n, Int},
) where {n, F <: SupportedMutableArithmetics}
ret = zero(F)
ret = buffered_dot_weighted_sum_to!(buf, ret, vecs, weights)
return ret
end

buffered_dot_weighted_sum!!(
buf::DotWeightedSumBufferDummy,
vecs::NTuple{n, NTuple{2, <:AbstractVector{F}}},
weights::NTuple{n, Int}) where {n, F <: Real} =
mapreduce((vec2, weight) -> weight*dot(vec2...), +, vecs, weights, init = zero(F))
66 changes: 46 additions & 20 deletions src/IPM/HSD/step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,23 @@ function compute_step!(hsd::HSD{T, Tv}, params::IPMOptions{T}) where{T, Tv<:Abst
ξ_ = @. (dat.c - ((pt.zl / pt.xl) * dat.l) * dat.lflag - ((pt.zu / pt.xu) * dat.u) * dat.uflag)
KKT.solve!(hx, hy, hsd.kkt, dat.b, ξ_)

dot_buf = buffer_for_dot_weighted_sum(T)

# Recover h0 = ρg + κ / τ - c'hx + b'hy - u'hz
# Some of the summands may take large values,
# so care must be taken for numerical stability
h0 = (
dot(dat.l .* dat.lflag, (dat.l .* θl) .* dat.lflag)
+ dot(dat.u .* dat.uflag, (dat.u .* θu) .* dat.uflag)
- dot((@. (c + (θl * dat.l) * dat.lflag + (θu * dat.u) * dat.uflag)), hx)
+ dot(b, hy)
+ pt.κ / pt.τ
+ hsd.regG
)
h0 = buffered_dot_weighted_sum!!(
dot_buf,
(
(dat.l .* dat.lflag, (dat.l .* θl) .* dat.lflag),
(dat.u .* dat.uflag, (dat.u .* θu) .* dat.uflag),
((@. (c + (θl * dat.l) * dat.lflag + (θu * dat.u) * dat.uflag)), hx),
(b, hy),
),
(
1, 1, -1, 1,
),
) + pt.κ / pt.τ + hsd.regG

# Affine-scaling direction
@timeit hsd.timer "Newton" solve_newton_system!(Δ, hsd, hx, hy, h0,
Expand Down Expand Up @@ -211,22 +217,42 @@ function solve_newton_system!(Δ::Point{T, Tv},
end
@timeit hsd.timer "KKT" KKT.solve!.x, Δ.y, hsd.kkt, ξp, ξd_)

dot_buf = buffer_for_dot_weighted_sum(T)

# II. Recover Δτ, Δx, Δy
# Compute Δτ
@timeit hsd.timer "ξg_" ξg_ = (ξg + ξtk / pt.τ
- dot((ξxzl ./ pt.xl) .* dat.lflag, dat.l .* dat.lflag) # l'(Xl)^-1 * ξxzl
+ dot((ξxzu ./ pt.xu) .* dat.uflag, dat.u .* dat.uflag)
- dot(((pt.zl ./ pt.xl) .* ξl) .* dat.lflag, dat.l .* dat.lflag)
- dot(((pt.zu ./ pt.xu) .* ξu) .* dat.uflag, dat.u .* dat.uflag) #
)
@timeit hsd.timer "ξg_" ξg_ = ξg + ξtk / pt.τ +
buffered_dot_weighted_sum!!(
dot_buf,
(
((ξxzl ./ pt.xl) .* dat.lflag, dat.l .* dat.lflag), # l'(Xl)^-1 * ξxzl
((ξxzu ./ pt.xu) .* dat.uflag, dat.u .* dat.uflag),
(((pt.zl ./ pt.xl) .* ξl) .* dat.lflag, dat.l .* dat.lflag),
(((pt.zu ./ pt.xu) .* ξu) .* dat.uflag, dat.u .* dat.uflag),
),
(
-1, 1, -1, -1,
),
)

@timeit hsd.timer "Δτ" Δ.τ = (
ξg_
+ dot((@. (dat.c
+ ((pt.zl / pt.xl) * dat.l) * dat.lflag
+ ((pt.zu / pt.xu) * dat.u) * dat.uflag))
, Δ.x)
- dot(dat.b, Δ.y)
ξg_ +
buffered_dot_weighted_sum!!(
dot_buf,
(
(
(@. (
dat.c +
((pt.zl / pt.xl) * dat.l) * dat.lflag +
((pt.zu / pt.xu) * dat.u) * dat.uflag)),
Δ.x,
),
(dat.b, Δ.y),
),
(
1, -1,
),
)
) / h0


Expand Down
1 change: 1 addition & 0 deletions src/Tulip.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Tulip

using LinearAlgebra
using Logging
using MutableArithmetics
using Printf
using SparseArrays
using TOML
Expand Down

0 comments on commit 38a78c1

Please sign in to comment.