From a9d63758b7ce9c3c94424a174cc3fd280de39f86 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Mon, 25 Nov 2024 14:37:39 +1300 Subject: [PATCH] Fix return type of `Bool` for ccall methods (#447) --- src/C_wrapper.jl | 270 ++++++++++++++++------------------------------ test/C_wrapper.jl | 119 ++++++++++++++++++++ 2 files changed, 212 insertions(+), 177 deletions(-) diff --git a/src/C_wrapper.jl b/src/C_wrapper.jl index 79aed6b..a59e05a 100644 --- a/src/C_wrapper.jl +++ b/src/C_wrapper.jl @@ -23,6 +23,8 @@ mutable struct IpoptProblem intermediate::Union{Function,Nothing} end +Base.unsafe_convert(::Type{Ptr{Cvoid}}, p::IpoptProblem) = p.ipopt_problem + function _Eval_F_CB( n::Cint, x_ptr::Ptr{Float64}, @@ -226,40 +228,22 @@ function CreateIpoptProblem( Ptr{Cvoid}, ), ) - ipopt_problem = ccall( - (:CreateIpoptProblem, libipopt), - Ptr{Cvoid}, - ( - Cint, - Ptr{Float64}, - Ptr{Float64}, - Cint, - Ptr{Float64}, - Ptr{Float64}, - Cint, - Cint, - Cint, - Ptr{Cvoid}, - Ptr{Cvoid}, - Ptr{Cvoid}, - Ptr{Cvoid}, - Ptr{Cvoid}, - ), - n, - x_L, - x_U, - m, - g_L, - g_U, - nele_jac, - nele_hess, - 1, # 1 = Fortran style indexing - eval_f_cb, - eval_g_cb, - eval_grad_f_cb, - eval_jac_g_cb, - eval_h_cb, - ) + ipopt_problem = @ccall libipopt.CreateIpoptProblem( + n::Cint, + x_L::Ptr{Cdouble}, + x_U::Ptr{Cdouble}, + m::Cint, + g_L::Ptr{Cdouble}, + g_U::Ptr{Cdouble}, + nele_jac::Cint, + nele_hess::Cint, + 1::Cint, # 1 = Fortran style indexing + eval_f_cb::Ptr{Cvoid}, + eval_g_cb::Ptr{Cvoid}, + eval_grad_f_cb::Ptr{Cvoid}, + eval_jac_g_cb::Ptr{Cvoid}, + eval_h_cb::Ptr{Cvoid}, + )::Ptr{Cvoid} if ipopt_problem == C_NULL if n == 0 error( @@ -294,12 +278,7 @@ function CreateIpoptProblem( end function FreeIpoptProblem(prob::IpoptProblem) - ccall( - (:FreeIpoptProblem, libipopt), - Cvoid, - (Ptr{Cvoid},), - prob.ipopt_problem, - ) + @ccall libipopt.FreeIpoptProblem(prob::Ptr{Cvoid})::Cvoid return end @@ -307,15 +286,12 @@ function AddIpoptStrOption(prob::IpoptProblem, keyword::String, value::String) if !(isascii(keyword) && isascii(value)) error("IPOPT: Non ASCII parameters not supported") end - ret = ccall( - (:AddIpoptStrOption, libipopt), - Cint, - (Ptr{Cvoid}, Ptr{UInt8}, Ptr{UInt8}), - prob.ipopt_problem, - keyword, - value, - ) - if ret == 0 + ret = @ccall libipopt.AddIpoptStrOption( + prob::Ptr{Cvoid}, + keyword::Ptr{UInt8}, + value::Ptr{UInt8}, + )::Bool + if !ret error("IPOPT: Couldn't set option '$keyword' to value '$value'.") end return @@ -325,15 +301,12 @@ function AddIpoptNumOption(prob::IpoptProblem, keyword::String, value::Float64) if !isascii(keyword) error("IPOPT: Non ASCII parameters not supported") end - ret = ccall( - (:AddIpoptNumOption, libipopt), - Cint, - (Ptr{Cvoid}, Ptr{UInt8}, Float64), - prob.ipopt_problem, - keyword, - value, - ) - if ret == 0 + ret = @ccall libipopt.AddIpoptNumOption( + prob::Ptr{Cvoid}, + keyword::Ptr{UInt8}, + value::Cdouble, + )::Bool + if !ret error("IPOPT: Couldn't set option '$keyword' to value '$value'.") end return @@ -343,15 +316,12 @@ function AddIpoptIntOption(prob::IpoptProblem, keyword::String, value::Integer) if !isascii(keyword) error("IPOPT: Non ASCII parameters not supported") end - ret = ccall( - (:AddIpoptIntOption, libipopt), - Cint, - (Ptr{Cvoid}, Ptr{UInt8}, Cint), - prob.ipopt_problem, - keyword, - value, - ) - if ret == 0 + ret = @ccall libipopt.AddIpoptIntOption( + prob::Ptr{Cvoid}, + keyword::Ptr{UInt8}, + value::Cint, + )::Bool + if !ret error( "IPOPT: Couldn't set option '$keyword' to value '$value'::Int32. " * "Note that `Num` options need to be explictly passed as " * @@ -369,15 +339,12 @@ function OpenIpoptOutputFile( if !isascii(file_name) error("IPOPT: Non ASCII parameters not supported") end - ret = ccall( - (:OpenIpoptOutputFile, libipopt), - Cint, - (Ptr{Cvoid}, Ptr{UInt8}, Cint), - prob.ipopt_problem, - file_name, - print_level, - ) - if ret == 0 + ret = @ccall libipopt.OpenIpoptOutputFile( + prob::Ptr{Cvoid}, + file_name::Ptr{UInt8}, + print_level::Cint, + )::Bool + if !ret error("IPOPT: Couldn't open output file.") end return @@ -389,18 +356,13 @@ function SetIpoptProblemScaling( x_scaling::Union{Ptr{Cvoid},Vector{Float64}}, g_scaling::Union{Ptr{Cvoid},Vector{Float64}}, ) - ret = ccall( - (:SetIpoptProblemScaling, libipopt), - Cint, - (Ptr{Cvoid}, Float64, Ptr{Float64}, Ptr{Float64}), - prob.ipopt_problem, - obj_scaling, - x_scaling, - g_scaling, - ) - if ret == 0 - error("IPOPT: Error setting problem scaling.") - end + ret = @ccall libipopt.SetIpoptProblemScaling( + prob::Ptr{Cvoid}, + obj_scaling::Cdouble, + x_scaling::Ptr{Cdouble}, + g_scaling::Ptr{Cdouble}, + )::Bool + @assert ret # The C++ code has `return true` return end @@ -423,46 +385,28 @@ function SetIntermediateCallback(prob::IpoptProblem, intermediate::Function) Ptr{Cvoid}, ), ) - ret = ccall( - (:SetIntermediateCallback, libipopt), - Cint, - (Ptr{Cvoid}, Ptr{Cvoid}), - prob.ipopt_problem, - intermediate_cb, - ) - if ret == 0 - error("IPOPT: Something went wrong setting the intermediate callback.") - end + ret = @ccall libipopt.SetIntermediateCallback( + prob::Ptr{Cvoid}, + intermediate_cb::Ptr{Cvoid}, + )::Bool + @assert ret # The C++ code has `return true` prob.intermediate = intermediate return end function IpoptSolve(prob::IpoptProblem) - final_objval = Ref{Cdouble}(0.0) - ret = ccall( - (:IpoptSolve, libipopt), - Cint, - ( - Ptr{Cvoid}, - Ptr{Float64}, - Ptr{Float64}, - Ptr{Float64}, - Ptr{Float64}, - Ptr{Float64}, - Ptr{Float64}, - Ptr{Cvoid}, - ), - prob.ipopt_problem, - prob.x, - prob.g, - final_objval, - prob.mult_g, - prob.mult_x_L, - prob.mult_x_U, - pointer_from_objref(prob), - ) - prob.obj_val = final_objval[] - prob.status = ret + p_objval = Ref{Cdouble}(0.0) + prob.status = @ccall libipopt.IpoptSolve( + prob::Ptr{Cvoid}, + prob.x::Ptr{Cdouble}, + prob.g::Ptr{Cdouble}, + p_objval::Ptr{Cdouble}, + prob.mult_g::Ptr{Cdouble}, + prob.mult_x_L::Ptr{Cdouble}, + prob.mult_x_U::Ptr{Cdouble}, + pointer_from_objref(prob)::Ptr{Cvoid}, + )::Cint + prob.obj_val = p_objval[] return prob.status end @@ -477,31 +421,18 @@ function GetIpoptCurrentIterate( g::Union{Ptr{Cvoid},Vector{Float64}}, lambda::Union{Ptr{Cvoid},Vector{Float64}}, ) - ret = ccall( - (:GetIpoptCurrentIterate, libipopt), - Cint, - ( - Ptr{Cvoid}, - Cint, - Cint, - Ptr{Float64}, - Ptr{Float64}, - Ptr{Float64}, - Cint, - Ptr{Float64}, - Ptr{Float64}, - ), - prob.ipopt_problem, - scaled, - n, - x, - z_L, - z_U, - m, - g, - lambda, - ) - if ret == 0 + ret = @ccall libipopt.GetIpoptCurrentIterate( + prob::Ptr{Cvoid}, + scaled::Bool, + n::Cint, + x::Ptr{Cdouble}, + z_L::Ptr{Cdouble}, + z_U::Ptr{Cdouble}, + m::Cint, + g::Ptr{Cdouble}, + lambda::Ptr{Cdouble}, + )::Bool + if !ret error("IPOPT: Something went wrong getting the current iterate.") end return @@ -520,35 +451,20 @@ function GetIpoptCurrentViolations( nlp_constraint_violation::Union{Ptr{Cvoid},Vector{Float64}}, compl_g::Union{Ptr{Cvoid},Vector{Float64}}, ) - ret = ccall( - (:GetIpoptCurrentViolations, libipopt), - Cint, - ( - Ptr{Cvoid}, - Cint, - Cint, - Ptr{Float64}, - Ptr{Float64}, - Ptr{Float64}, - Ptr{Float64}, - Ptr{Float64}, - Cint, - Ptr{Float64}, - Ptr{Float64}, - ), - prob.ipopt_problem, - scaled, - n, - x_L_violation, - x_U_violation, - compl_x_L, - compl_x_U, - grad_lag_x, - m, - nlp_constraint_violation, - compl_g, - ) - if ret == 0 + ret = @ccall libipopt.GetIpoptCurrentViolations( + prob::Ptr{Cvoid}, + scaled::Bool, + n::Cint, + x_L_violation::Ptr{Cdouble}, + x_U_violation::Ptr{Cdouble}, + compl_x_L::Ptr{Cdouble}, + compl_x_U::Ptr{Cdouble}, + grad_lag_x::Ptr{Cdouble}, + m::Cint, + nlp_constraint_violation::Ptr{Cdouble}, + compl_g::Ptr{Cdouble}, + )::Bool + if !ret error("IPOPT: Something went wrong getting the current violations.") end return diff --git a/test/C_wrapper.jl b/test/C_wrapper.jl index 6ca7f51..f9d1b3a 100644 --- a/test/C_wrapper.jl +++ b/test/C_wrapper.jl @@ -8,6 +8,8 @@ module TestCWrapper using Ipopt using Test +import Ipopt_jll + function test_hs071() # hs071 # min x1 * x4 * (x1 + x2 + x3) + x3 @@ -338,6 +340,123 @@ function test_SetIpoptProblemScaling() return end +function test_OpenIpoptOutputFile() + prob = Ipopt.CreateIpoptProblem( + 1, # n, + [0.0], # x_L, + [1.0], # x_U, + 0, # m, + Float64[], # g_L, + Float64[], # g_U, + 0, # nele_jac, + 0, # nele_hess + x -> x[1], # eval_f, + (x, g) -> nothing, # eval_g, + (x, g) -> (g[1] = 1.0), # eval_grad_f, + (args...) -> nothing, # eval_jac_g, + nothing, + ) + @test_throws( + ErrorException("IPOPT: Couldn't open output file."), + Ipopt.OpenIpoptOutputFile(prob, "/illegal/bar.txt", 1), + ) + return +end + +function _ipopt_version() + io = IOBuffer() + run(pipeline(`$(Ipopt_jll.amplexe()) -v`; stdout = io)) + seekstart(io) + version = read(io, String) + m = match(r"Ipopt ([0-9]+.[0-9]+.[0-9]+)", version) + if m === nothing + return v"0.0.0" # Something went wrong + end + return VersionNumber(m[1]) +end + +function test_GetIpoptCurrentIterate() + if _ipopt_version() < v"3.14.12" + return # Bug fixed in 3.14.12 + end + prob = Ipopt.CreateIpoptProblem( + 1, # n, + [0.0], # x_L, + [1.0], # x_U, + 0, # m, + Float64[], # g_L, + Float64[], # g_U, + 0, # nele_jac, + 0, # nele_hess + x -> x[1], # eval_f, + (x, g) -> nothing, # eval_g, + (x, g) -> (g[1] = 1.0), # eval_grad_f, + (args...) -> nothing, # eval_jac_g, + nothing, + ) + x, z_L, z_U = zeros(1), zeros(1), zeros(1) + @test_throws( + ErrorException( + "IPOPT: Something went wrong getting the current iterate.", + ), + Ipopt.GetIpoptCurrentIterate( + prob, + false, + 1, + x, + z_L, + z_U, + 0, + Float64[], + Float64[], + ), + ) + return +end + +function test_GetIpoptCurrentViolations() + if _ipopt_version() < v"3.14.12" + return # Bug fixed in 3.14.12 + end + prob = Ipopt.CreateIpoptProblem( + 1, # n, + [0.0], # x_L, + [1.0], # x_U, + 0, # m, + Float64[], # g_L, + Float64[], # g_U, + 0, # nele_jac, + 0, # nele_hess + x -> x[1], # eval_f, + (x, g) -> nothing, # eval_g, + (x, g) -> (g[1] = 1.0), # eval_grad_f, + (args...) -> nothing, # eval_jac_g, + nothing, + ) + x_L, x_U = zeros(1), zeros(1) + comp_x_L, comp_x_U = zeros(1), zeros(1) + grad_lag_x = zeros(1) + @test_throws( + ErrorException( + "IPOPT: Something went wrong getting the current violations.", + ), + Ipopt.GetIpoptCurrentViolations( + prob, + false, + 1, + x_L, + x_U, + comp_x_L, + comp_x_U, + grad_lag_x, + 0, + Float64[], + Float64[], + ), + ) + return +end + end # TestCWrapper runtests(TestCWrapper)