Skip to content

Commit

Permalink
Add Base.summary(context) function
Browse files Browse the repository at this point in the history
  • Loading branch information
nefrathenrici committed Jan 10, 2025
1 parent 8616bc6 commit 92f8a87
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 5 deletions.
6 changes: 6 additions & 0 deletions ext/ClimaCommsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ function ClimaComms._assign_device(::CUDADevice, rank_number)
return nothing
end

function ClimaComms.device_summary(::CUDADevice)
dev = CUDA.device()
uuid = CUDA.uuid(dev)
return "$dev ($uuid)"
end

function ClimaComms.device_functional(::CUDADevice)
return CUDA.functional()
end
Expand Down
46 changes: 41 additions & 5 deletions ext/ClimaCommsMPIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ import ClimaComms
ClimaComms.MPICommsContext(device = ClimaComms.device()) =
ClimaComms.MPICommsContext(device, MPI.COMM_WORLD)

ClimaComms.local_communicator(ctx::ClimaComms.MPICommsContext) =
MPI.Comm_split_type(
ctx.mpicomm,
MPI.COMM_TYPE_SHARED,
MPI.Comm_rank(ctx.mpicomm),
)

function ClimaComms.init(ctx::ClimaComms.MPICommsContext)
if !MPI.Initialized()
MPI.Init()
Expand All @@ -18,11 +25,7 @@ function ClimaComms.init(ctx::ClimaComms.MPICommsContext)
)
end
# assign GPUs based on local rank
local_comm = MPI.Comm_split_type(
ctx.mpicomm,
MPI.COMM_TYPE_SHARED,
MPI.Comm_rank(ctx.mpicomm),
)
local_comm = ClimaComms.local_communicator(ctx)
ClimaComms._assign_device(ctx.device, MPI.Comm_rank(local_comm))
MPI.free(local_comm)
end
Expand Down Expand Up @@ -271,4 +274,37 @@ function ClimaComms.finish(
MPI.Waitall(ghost.send_reqs)
end

function Base.summary(io::IO, ctx::ClimaComms.MPICommsContext)
if !MPI.Initialized()
ClimaComms.iamroot(ctx) && @warn "MPI is not initialized."
return nothing
end
ClimaComms.barrier(ctx)

if ClimaComms.iamroot(ctx)
println(io, "Context: $(typeof(ctx).name.name)")
println(io, "Device: $(typeof(ctx.device))")
println(io, "Total Processes: $(ClimaComms.nprocs(ctx))")
end

ClimaComms.barrier(ctx)
rank = MPI.Comm_rank(ctx.mpicomm)
node_name = MPI.Get_processor_name()

if ctx.device isa ClimaComms.CUDADevice
local_comm = ClimaComms.local_communicator(ctx)
local_rank = MPI.Comm_rank(local_comm)
local_size = MPI.Comm_size(local_comm)
dev_summary = ClimaComms.device_summary(ctx.device)
println(
io,
"Rank: $rank, Local Rank: $local_rank, Node: $node_name, Device: $dev_summary",
)

MPI.free(local_comm)
else
println(io, "Rank: $rank, Node: $node_name")
end
end

end
9 changes: 9 additions & 0 deletions src/devices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ function device()
return DeviceConstructor()
end

"""
device_summary(device)
Return the device type. If using a CUDADevice, return device information and UUID.
Internal function, used with `Base.summary(context)`
"""
device_summary(device) = device_type()

"""
ClimaComms.array_type(::AbstractDevice)
Expand Down
14 changes: 14 additions & 0 deletions src/mpi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,17 @@ struct MPICommsContext{D <: AbstractDevice, C} <: AbstractCommsContext
end

function MPICommsContext end

"""
local_communicator(ctx::MPICommsContext)
Internal function to create a new MPI communicator for processes on the same physical node.
The communicator must be freed by `MPI.free`
```
local_comm = ClimaComms.local_communicator(ctx)
# use the communicator
MPI.free(local_comm)
```
"""
function local_communicator end
5 changes: 5 additions & 0 deletions src/singleton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,8 @@ graph_context(ctx::SingletonCommsContext, args...) = SingletonGraphContext(ctx)
start(gctx::SingletonGraphContext) = nothing
progress(gctx::SingletonGraphContext) = nothing
finish(gctx::SingletonGraphContext) = nothing

function Base.summary(io::IO, ctx::SingletonCommsContext)
println(io, "Context: $(typeof(ctx).name.name)")
println(io, "Device: $(typeof(ctx.device))")
end
32 changes: 32 additions & 0 deletions test/summary.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using Test
import ClimaComms
ClimaComms.@import_required_backends

ctx = ClimaComms.context()
(pid, nprocs) = ClimaComms.init(ctx)

io = IOBuffer()
summary(io, ctx)
summary_str = String(take!(io))
print(summary_str)

@testset "ClimaComms Summary Tests" begin
# Use .name.name to get the unparameterized context type
if ClimaComms.iamroot(ctx)
@test contains(summary_str, string(typeof(ctx).name.name))
@test contains(summary_str, string(typeof(ctx.device).name.name))
end

if ctx isa ClimaComms.MPICommsContext
@testset "MPI Context Tests" begin

if ctx.device isa ClimaComms.CUDADevice
@test contains(summary_str, "CUDA.CuDevice")
end

ClimaComms.iamroot(ctx) &&
@test contains(summary_str, "Total Processes: $nprocs")
@test contains(summary_str, "Rank: $(pid-1)")
end
end
end

0 comments on commit 92f8a87

Please sign in to comment.