Skip to content

Commit

Permalink
add mpi mode
Browse files Browse the repository at this point in the history
  • Loading branch information
AStupidBear committed Jan 19, 2020
1 parent 96955a3 commit d854b8c
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 20 deletions.
14 changes: 8 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GCMAES"
uuid = "4aa9d100-eb0f-11e8-15f1-25748831eb3b"
authors = ["Yao Lu <[email protected]>"]
version = "0.1.0"
version = "0.1.1"

[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Expand All @@ -10,16 +10,18 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[compat]
BSON = "0.2.4"
ForwardDiff = "0.10"
Requires = "0.5"
julia = "1"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "ForwardDiff"]
test = ["Test", "ForwardDiff"]
33 changes: 20 additions & 13 deletions src/GCMAES.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@ __precompile__(true)

module GCMAES

using BSON, Printf, Distributed, LinearAlgebra, Dates, Random, Statistics
using Printf, Distributed, LinearAlgebra
using Dates, Random, Statistics
using Requires, BSON

include("util.jl")
include("constraint.jl")

function __init__()
@require MPI="da04e1cc-30fd-572f-bb4f-1f8673147195" include("mpi.jl")
end

mutable struct CMAESOpt{T, F, G, S}
# fixed hyper-parameters
f::F
Expand Down Expand Up @@ -85,7 +91,7 @@ function CMAESOpt(f, g, x0, σ0, lo = -fill(1, size(x0)), hi = fill(1, size(x0))
# init a few things
arx, ary, arz = zeros(N, λ), zeros(N, λ), zeros(N, λ)
arfitness, arpenalty, arindex = zeros(λ), zeros(λ), ones(λ)
@printf("%i-%i CMA-ES\n", λ, μ)
@master @printf("%i-%i CMA-ES\n", λ, μ)
# gradient
T, F, G, S = eltype(x0), typeof(f), typeof(g), typeof(constr)
return CMAESOpt{T, F, G, S}(
Expand Down Expand Up @@ -124,7 +130,7 @@ function linesearch(f, x0::Array{T}, Δ) where T
nrm = norm(Δ)
nrm == 0 && return x0, typemax(T), zero(T)
rmul!(Δ, 1 / nrm)
αs = T[0.0; 2.0.^(2 - nworkers():0)]
αs = T[0.0; 2.0.^(2 - worldsize():0)]
xs = [x0 .+ α .* Δ for α in αs]
fs = pmap(f, xs)
fx, i = findmin(fs)
Expand Down Expand Up @@ -179,7 +185,7 @@ function terminate(opt::CMAESOpt)
# FlatFitness: warn if 70% candidates' fitnesses are identical
if opt.arfitness[1] == opt.arfitness[ceil(Int, 0.7opt.λ)]
opt.σ *= exp(0.2 + opt./ opt.dσ)
println("warning: flat fitness, consider reformulating the objective")
@master println("warning: flat fitness, consider reformulating the objective")
end
# Stop conditions:
# MaxIter: the maximal number of iterations in each run of CMA-ES
Expand Down Expand Up @@ -226,14 +232,14 @@ function terminate(opt::CMAESOpt)
# Benchmarking a BI-Population CMA-ES on the BBOB-2009 Function Testbed
termination = false
for (k, v) in condition
v && printstyled("Termination Condition Satisfied: ", k, '\n', color = :red)
@master v && printstyled("Termination Condition Satisfied: ", k, '\n', color = :red)
termination = termination | v
end
return termination
end

function restart(opt::CMAESOpt)
@printf("restarting...\n")
@master @printf("restarting...\n")
optnew = CMAESOpt(opt.f, sample(opt.lo, opt.hi), opt.σ0, opt.lo, opt.hi; => 2opt.λ)
optnew.xmin, optnew.fmin = opt.xmin, opt.fmin
return optnew
Expand All @@ -242,11 +248,11 @@ end
function trace_state(opt::CMAESOpt, iter, fcount)
elapsed_time = time() - opt.last_report_time
# display some information every iteration
@printf("time:%s iter:%d elapsed-time:%.2f ", Time(now()), iter, elapsed_time)
@printf("pmap-time:%.2f grad-time:%.2f ls-time:%.2f ls-dec:%2.2e\n",
@master @printf("time:%s iter:%d elapsed-time:%.2f ", Time(now()), iter, elapsed_time)
@master @printf("pmap-time:%.2f grad-time:%.2f ls-time:%.2f ls-dec:%2.2e\n",
opt.pmap_time, opt.grad_time, opt.ls_time, opt.ls_dec)
@printf("fcount:%d fval:%2.2e fmin:%2.2e ", fcount, opt.arfitness[1], opt.fmin)
@printf("norm:%2.2e penalty:%2.2e axis-ratio:%2.2e free-mem:%.2fGB\n",
@master @printf("fcount:%d fval:%2.2e fmin:%2.2e ", fcount, opt.arfitness[1], opt.fmin)
@master @printf("norm:%2.2e penalty:%2.2e axis-ratio:%2.2e free-mem:%.2fGB\n",
norm(opt.arx[:, opt.arindex[1]]), opt.arpenalty[opt.arindex[1]],
maximum(opt.D) / minimum(opt.D), Sys.free_memory() / 1024^3)
opt.last_report_time = time()
Expand All @@ -273,10 +279,11 @@ function save(opt::CMAESOpt)
BSON.bson(opt.file, data)
end

function minimize(fg, x0, args...; maxfevals = 0, gcitr = false,
maxiter = 0, resume = "false", cb = [], kwargs...)
function minimize(fg, x0, a...; maxfevals = 0, gcitr = false, maxiter = 0,
resume = "false", cb = [], seed = 1234, ka...)
Random.seed!(seed)
f, g = fg isa Tuple ? fg : (fg, zero)
opt = CMAESOpt(f, g, x0, args...; kwargs...)
opt = CMAESOpt(f, g, x0, a...; ka...)
cb = runall([throttle(x -> save(opt), 60); cb])
maxfevals = (maxfevals == 0) ? 1e3 * length(x0)^2 : maxfevals
maxfevals = maxiter != 0 ? maxiter * opt.λ : maxfevals
Expand Down
28 changes: 28 additions & 0 deletions src/mpi.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
function part(x::AbstractArray{T, N}, dim = -2) where {T, N}
!MPI.Initialized() && return x
dim = clamp(dim > 0 ? dim : N + dim + 1, 1, N)
dsize, rank, wsize = size(x, dim), myrank(), worldsize()
@assert wsize <= dsize
q, r = divrem(dsize, wsize)
splits = cumsum([i <= r ? q + 1 : q for i in 1:wsize])
pushfirst!(splits, 0)
is = (splits[rank + 1] + 1):splits[rank + 2]
view(x, ntuple(x -> x == dim ? is : (:), N)...)
end

function allgather(x, dim = 1)
if MPI.Initialized()
x = isa(x, Number) ? [x] : x
counts = MPI.Allgather(Cint(length(x)), MPI.COMM_WORLD)
recvbuf = MPI.Allgatherv(vec(x), counts, MPI.COMM_WORLD)
ranges = zip(cumsum([1; counts[1:end - 1]]), cumsum(counts))
shape = ntuple(i -> i == dim ? (:) : size(x, i), ndims(x))
xs = [reshape(view(recvbuf, i:j), shape) for (i, j) in ranges]
@assert sum(length, xs) == sum(counts)
return cat(xs..., dims = dim)
else
return x
end
end

myrank() = MPI.Initialized() ? MPI.Comm_rank(MPI.COMM_WORLD) : 0
18 changes: 17 additions & 1 deletion src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,20 @@ function throttle(f, timeout; leading = true)
end
return result
end
end
end

function pmap(f, xs)
if @isdefined(MPI) && nworkers() == 1
allgather(map(f, part(xs)))
else
Distributed.pmap(f, xs)
end
end

macro master(ex)
:(if !@isdefined(MPI) || myrank() == 0
$(esc(ex))
end)
end

worldsize() = @isdefined(MPI) && nworkers() == 1 ? MPI.Comm_size(MPI.COMM_WORLD) : nworkers()
23 changes: 23 additions & 0 deletions test/mpi.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using GCMAES
using ForwardDiff
using MPI
using Test

MPI.Init()

rastrigin(x) = 10length(x) + sum(x.^2 .- 10 .* cos.(2π .* x))
∇rastrigin(x) = ForwardDiff.gradient(rastrigin, x)

D = 2000
x0 = fill(0.3, D)
σ0 = 0.2
lo = fill(-5.12, D)
hi = fill(5.12, D)

GCMAES.minimize(rastrigin, x0, σ0, lo, hi, maxiter = 200)

GCMAES.minimize((rastrigin, ∇rastrigin), x0, σ0, lo, hi, maxiter = 200)

rm("CMAES.bson", force = true)

MPI.Finalize()

2 comments on commit d854b8c

@AStupidBear
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/8149

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.1 -m "<description of version>" d854b8c912d19dcce33cee949ee4c75f517641d9
git push origin v0.1.1

Please sign in to comment.