From 92f8a871513e0b5ff8b338fac7a12c9648d3b534 Mon Sep 17 00:00:00 2001 From: nefrathenrici Date: Fri, 3 Jan 2025 09:01:50 -0800 Subject: [PATCH] Add `Base.summary(context)` function --- ext/ClimaCommsCUDAExt.jl | 6 ++++++ ext/ClimaCommsMPIExt.jl | 46 +++++++++++++++++++++++++++++++++++----- src/devices.jl | 9 ++++++++ src/mpi.jl | 14 ++++++++++++ src/singleton.jl | 5 +++++ test/summary.jl | 32 ++++++++++++++++++++++++++++ 6 files changed, 107 insertions(+), 5 deletions(-) create mode 100644 test/summary.jl diff --git a/ext/ClimaCommsCUDAExt.jl b/ext/ClimaCommsCUDAExt.jl index bc236a7..06c557f 100644 --- a/ext/ClimaCommsCUDAExt.jl +++ b/ext/ClimaCommsCUDAExt.jl @@ -11,6 +11,12 @@ function ClimaComms._assign_device(::CUDADevice, rank_number) return nothing end +function ClimaComms.device_summary(::CUDADevice) + dev = CUDA.device() + uuid = CUDA.uuid(dev) + return "$dev ($uuid)" +end + function ClimaComms.device_functional(::CUDADevice) return CUDA.functional() end diff --git a/ext/ClimaCommsMPIExt.jl b/ext/ClimaCommsMPIExt.jl index d04732b..fd7907e 100644 --- a/ext/ClimaCommsMPIExt.jl +++ b/ext/ClimaCommsMPIExt.jl @@ -6,6 +6,13 @@ import ClimaComms ClimaComms.MPICommsContext(device = ClimaComms.device()) = ClimaComms.MPICommsContext(device, MPI.COMM_WORLD) +ClimaComms.local_communicator(ctx::ClimaComms.MPICommsContext) = + MPI.Comm_split_type( + ctx.mpicomm, + MPI.COMM_TYPE_SHARED, + MPI.Comm_rank(ctx.mpicomm), + ) + function ClimaComms.init(ctx::ClimaComms.MPICommsContext) if !MPI.Initialized() MPI.Init() @@ -18,11 +25,7 @@ function ClimaComms.init(ctx::ClimaComms.MPICommsContext) ) end # assign GPUs based on local rank - local_comm = MPI.Comm_split_type( - ctx.mpicomm, - MPI.COMM_TYPE_SHARED, - MPI.Comm_rank(ctx.mpicomm), - ) + local_comm = ClimaComms.local_communicator(ctx) ClimaComms._assign_device(ctx.device, MPI.Comm_rank(local_comm)) MPI.free(local_comm) end @@ -271,4 +274,37 @@ function ClimaComms.finish( MPI.Waitall(ghost.send_reqs) end +function Base.summary(io::IO, ctx::ClimaComms.MPICommsContext) + if !MPI.Initialized() + ClimaComms.iamroot(ctx) && @warn "MPI is not initialized." + return nothing + end + ClimaComms.barrier(ctx) + + if ClimaComms.iamroot(ctx) + println(io, "Context: $(typeof(ctx).name.name)") + println(io, "Device: $(typeof(ctx.device))") + println(io, "Total Processes: $(ClimaComms.nprocs(ctx))") + end + + ClimaComms.barrier(ctx) + rank = MPI.Comm_rank(ctx.mpicomm) + node_name = MPI.Get_processor_name() + + if ctx.device isa ClimaComms.CUDADevice + local_comm = ClimaComms.local_communicator(ctx) + local_rank = MPI.Comm_rank(local_comm) + local_size = MPI.Comm_size(local_comm) + dev_summary = ClimaComms.device_summary(ctx.device) + println( + io, + "Rank: $rank, Local Rank: $local_rank, Node: $node_name, Device: $dev_summary", + ) + + MPI.free(local_comm) + else + println(io, "Rank: $rank, Node: $node_name") + end +end + end diff --git a/src/devices.jl b/src/devices.jl index 7938df5..37d7a40 100644 --- a/src/devices.jl +++ b/src/devices.jl @@ -85,6 +85,15 @@ function device() return DeviceConstructor() end +""" + device_summary(device) + +Return the device type. If using a CUDADevice, return device information and UUID. + +Internal function, used with `Base.summary(context)` +""" +device_summary(device) = device_type() + """ ClimaComms.array_type(::AbstractDevice) diff --git a/src/mpi.jl b/src/mpi.jl index 524cc2c..07f11e7 100644 --- a/src/mpi.jl +++ b/src/mpi.jl @@ -12,3 +12,17 @@ struct MPICommsContext{D <: AbstractDevice, C} <: AbstractCommsContext end function MPICommsContext end + +""" + local_communicator(ctx::MPICommsContext) + +Internal function to create a new MPI communicator for processes on the same physical node. + +The communicator must be freed by `MPI.free` +``` +local_comm = ClimaComms.local_communicator(ctx) +# use the communicator +MPI.free(local_comm) +``` +""" +function local_communicator end diff --git a/src/singleton.jl b/src/singleton.jl index 82e3f1d..5ff68c8 100644 --- a/src/singleton.jl +++ b/src/singleton.jl @@ -49,3 +49,8 @@ graph_context(ctx::SingletonCommsContext, args...) = SingletonGraphContext(ctx) start(gctx::SingletonGraphContext) = nothing progress(gctx::SingletonGraphContext) = nothing finish(gctx::SingletonGraphContext) = nothing + +function Base.summary(io::IO, ctx::SingletonCommsContext) + println(io, "Context: $(typeof(ctx).name.name)") + println(io, "Device: $(typeof(ctx.device))") +end diff --git a/test/summary.jl b/test/summary.jl new file mode 100644 index 0000000..fafd9f1 --- /dev/null +++ b/test/summary.jl @@ -0,0 +1,32 @@ +using Test +import ClimaComms +ClimaComms.@import_required_backends + +ctx = ClimaComms.context() +(pid, nprocs) = ClimaComms.init(ctx) + +io = IOBuffer() +summary(io, ctx) +summary_str = String(take!(io)) +print(summary_str) + +@testset "ClimaComms Summary Tests" begin + # Use .name.name to get the unparameterized context type + if ClimaComms.iamroot(ctx) + @test contains(summary_str, string(typeof(ctx).name.name)) + @test contains(summary_str, string(typeof(ctx.device).name.name)) + end + + if ctx isa ClimaComms.MPICommsContext + @testset "MPI Context Tests" begin + + if ctx.device isa ClimaComms.CUDADevice + @test contains(summary_str, "CUDA.CuDevice") + end + + ClimaComms.iamroot(ctx) && + @test contains(summary_str, "Total Processes: $nprocs") + @test contains(summary_str, "Rank: $(pid-1)") + end + end +end