Skip to content

Commit

Permalink
Make the MPI context immutable
Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jan 6, 2025
1 parent 06d94a5 commit e6060c2
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 27 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ authors = [
"Jake Bolewski <[email protected]>",
"Gabriele Bozzola <[email protected]>",
]
version = "0.6.4"
version = "0.6.5"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down
59 changes: 35 additions & 24 deletions ext/ClimaCommsMPIExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
module ClimaCommsMPIExt

import MPI
import ClimaComms
import ClimaComms: mpicomm

ClimaComms.MPICommsContext(device = ClimaComms.device()) =
ClimaComms.MPICommsContext(device, MPI.COMM_WORLD)
# Module-level storage for the MPI communicator. Populated by `set_mpicomm!`
# so that the (now immutable) `MPICommsContext` need not carry the
# communicator as a field; `mpicomm(ctx)` reads it back from here.
const CLIMA_COMM_WORLD = Ref{typeof(MPI.COMM_WORLD)}()

"""
    set_mpicomm!()

Capture the current `MPI.COMM_WORLD` into the module-level `CLIMA_COMM_WORLD`
`Ref` and return the stored communicator.
"""
function set_mpicomm!()
    # Assignment evaluates to the assigned value, so this returns the
    # freshly stored communicator — same result as re-reading the Ref.
    return CLIMA_COMM_WORLD[] = MPI.COMM_WORLD
end

"""
    ClimaComms.mpicomm(ctx::ClimaComms.MPICommsContext)

Return the MPI communicator for `ctx`, read from the module-level
`CLIMA_COMM_WORLD` storage (the immutable context no longer stores it).
"""
function ClimaComms.mpicomm(::ClimaComms.MPICommsContext)
    return CLIMA_COMM_WORLD[]
end

"""
    ClimaComms.MPICommsContext(device = ClimaComms.device())

Construct an MPI communications context for `device`, first capturing the
current `MPI.COMM_WORLD` via `set_mpicomm!` so that `mpicomm(ctx)` can
retrieve it later.
"""
function ClimaComms.MPICommsContext(device = ClimaComms.device())
set_mpicomm!()
# NOTE(review): this inner call presumably dispatches to the struct's
# default constructor `MPICommsContext(device::D)` (more specific than this
# untyped method), not back to this method — confirm no recursion for
# non-`AbstractDevice` arguments.
ClimaComms.MPICommsContext(device)
end

function ClimaComms.init(ctx::ClimaComms.MPICommsContext)
if !MPI.Initialized()
Expand All @@ -19,9 +30,9 @@ function ClimaComms.init(ctx::ClimaComms.MPICommsContext)
end
# assign GPUs based on local rank
local_comm = MPI.Comm_split_type(
ctx.mpicomm,
mpicomm(ctx),
MPI.COMM_TYPE_SHARED,
MPI.Comm_rank(ctx.mpicomm),
MPI.Comm_rank(mpicomm(ctx)),
)
ClimaComms._assign_device(ctx.device, MPI.Comm_rank(local_comm))
MPI.free(local_comm)
Expand All @@ -32,48 +43,48 @@ end
"""
    ClimaComms.device(ctx::ClimaComms.MPICommsContext)

Return the compute device associated with this MPI context.
"""
function ClimaComms.device(ctx::ClimaComms.MPICommsContext)
    return ctx.device
end

ClimaComms.mypid(ctx::ClimaComms.MPICommsContext) =
MPI.Comm_rank(ctx.mpicomm) + 1
MPI.Comm_rank(mpicomm(ctx)) + 1
ClimaComms.iamroot(ctx::ClimaComms.MPICommsContext) = ClimaComms.mypid(ctx) == 1
ClimaComms.nprocs(ctx::ClimaComms.MPICommsContext) = MPI.Comm_size(ctx.mpicomm)
ClimaComms.nprocs(ctx::ClimaComms.MPICommsContext) = MPI.Comm_size(mpicomm(ctx))

ClimaComms.barrier(ctx::ClimaComms.MPICommsContext) = MPI.Barrier(ctx.mpicomm)
ClimaComms.barrier(ctx::ClimaComms.MPICommsContext) = MPI.Barrier(mpicomm(ctx))

ClimaComms.reduce(ctx::ClimaComms.MPICommsContext, val, op) =
MPI.Reduce(val, op, 0, ctx.mpicomm)
MPI.Reduce(val, op, 0, mpicomm(ctx))

ClimaComms.reduce!(ctx::ClimaComms.MPICommsContext, sendbuf, recvbuf, op) =
MPI.Reduce!(sendbuf, recvbuf, op, ctx.mpicomm; root = 0)
MPI.Reduce!(sendbuf, recvbuf, op, mpicomm(ctx); root = 0)

ClimaComms.reduce!(ctx::ClimaComms.MPICommsContext, sendrecvbuf, op) =
MPI.Reduce!(sendrecvbuf, op, ctx.mpicomm; root = 0)
MPI.Reduce!(sendrecvbuf, op, mpicomm(ctx); root = 0)

ClimaComms.allreduce(ctx::ClimaComms.MPICommsContext, sendbuf, op) =
MPI.Allreduce(sendbuf, op, ctx.mpicomm)
MPI.Allreduce(sendbuf, op, mpicomm(ctx))

ClimaComms.allreduce!(ctx::ClimaComms.MPICommsContext, sendbuf, recvbuf, op) =
MPI.Allreduce!(sendbuf, recvbuf, op, ctx.mpicomm)
MPI.Allreduce!(sendbuf, recvbuf, op, mpicomm(ctx))

ClimaComms.allreduce!(ctx::ClimaComms.MPICommsContext, sendrecvbuf, op) =
MPI.Allreduce!(sendrecvbuf, op, ctx.mpicomm)
MPI.Allreduce!(sendrecvbuf, op, mpicomm(ctx))

ClimaComms.bcast(ctx::ClimaComms.MPICommsContext, object) =
MPI.bcast(object, ctx.mpicomm; root = 0)
MPI.bcast(object, mpicomm(ctx); root = 0)

function ClimaComms.gather(ctx::ClimaComms.MPICommsContext, array)
dims = size(array)
lengths = MPI.Gather(dims[end], 0, ctx.mpicomm)
lengths = MPI.Gather(dims[end], 0, mpicomm(ctx))
if ClimaComms.iamroot(ctx)
dimsout = (dims[1:(end - 1)]..., sum(lengths))
arrayout = similar(array, dimsout)
recvbuf = MPI.VBuffer(arrayout, lengths .* prod(dims[1:(end - 1)]))
else
recvbuf = nothing
end
MPI.Gatherv!(array, recvbuf, 0, ctx.mpicomm)
MPI.Gatherv!(array, recvbuf, 0, mpicomm(ctx))
end

ClimaComms.abort(ctx::ClimaComms.MPICommsContext, status::Int) =
MPI.Abort(ctx.mpicomm, status)
MPI.Abort(mpicomm(ctx), status)


# We could probably do something fancier here?
Expand Down Expand Up @@ -171,7 +182,7 @@ function graph_context(
for n in 1:length(recv_bufs)
MPI.Recv_init(
recv_bufs[n],
ctx.mpicomm,
mpicomm(ctx),
recv_reqs[n];
source = recv_ranks[n],
tag = tag,
Expand All @@ -181,7 +192,7 @@ function graph_context(
for n in 1:length(send_bufs)
MPI.Send_init(
send_bufs[n],
ctx.mpicomm,
mpicomm(ctx),
send_reqs[n];
dest = send_ranks[n],
tag = tag,
Expand Down Expand Up @@ -226,7 +237,7 @@ function ClimaComms.start(
ghost.recv_bufs[n],
ghost.recv_ranks[n],
ghost.tag,
ghost.ctx.mpicomm,
mpicomm(ghost.ctx),
ghost.recv_reqs[n],
)
end
Expand All @@ -236,7 +247,7 @@ function ClimaComms.start(
ghost.send_bufs[n],
ghost.send_ranks[n],
ghost.tag,
ghost.ctx.mpicomm,
mpicomm(ghost.ctx),
ghost.send_reqs[n],
)
end
Expand All @@ -254,9 +265,9 @@ function ClimaComms.progress(
ghost::Union{MPISendRecvGraphContext, MPIPersistentSendRecvGraphContext},
)
if isdefined(MPI, :MPI_ANY_SOURCE) # < v0.20
MPI.Iprobe(MPI.MPI_ANY_SOURCE, ghost.tag, ghost.ctx.mpicomm)
MPI.Iprobe(MPI.MPI_ANY_SOURCE, ghost.tag, mpicomm(ghost.ctx))
else # >= v0.20
MPI.Iprobe(MPI.ANY_SOURCE, ghost.tag, ghost.ctx.mpicomm)
MPI.Iprobe(MPI.ANY_SOURCE, ghost.tag, mpicomm(ghost.ctx))
end
end

Expand Down
5 changes: 3 additions & 2 deletions src/mpi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
A MPI communications context, used for distributed runs.
[`AbstractCPUDevice`](@ref) and [`CUDADevice`](@ref) device options are currently supported.
"""
struct MPICommsContext{D <: AbstractDevice, C} <: AbstractCommsContext
struct MPICommsContext{D <: AbstractDevice} <: AbstractCommsContext
device::D
mpicomm::C
end

"""
    MPICommsContext([device])

Constructor stub for the MPI communications context; the actual method is
provided by the `ClimaCommsMPIExt` package extension when `MPI.jl` is loaded.
"""
function MPICommsContext end

"""
    mpicomm(ctx)

Return the MPI communicator associated with `ctx`. Implemented in the
`ClimaCommsMPIExt` package extension.
"""
function mpicomm end

0 comments on commit e6060c2

Please sign in to comment.