Skip to content

Commit

Permalink
Refactor macros, bump patch version
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed May 30, 2024
1 parent a22e4ac commit 0a8edda
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 96 deletions.
2 changes: 1 addition & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
agents:
queue: new-central
slurm_mem: 8G
modules: climacommon/2024_03_18
modules: climacommon/2024_05_27

env:
OPENBLAS_NUM_THREADS: 1
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
ClimaComms.jl Release Notes
========================

v0.6.1
- Macros have been refactored which fix some issues with code loading.

v0.6.0
-------

Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ authors = [
"Jake Bolewski <[email protected]>",
"Gabriele Bozzola <[email protected]>",
]
version = "0.6.0"
version = "0.6.1"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down
2 changes: 2 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ ClimaComms

```@docs
ClimaComms.import_required_backends
ClimaComms.cuda_is_required
ClimaComms.mpi_is_required
```

## Devices
Expand Down
5 changes: 5 additions & 0 deletions ext/ClimaCommsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,9 @@ end

ClimaComms.array_type(::ClimaComms.CUDADevice) = CUDA.CuArray

# Extending ClimaComms methods that operate on expressions (cannot use dispatch here)
ClimaComms.cuda_sync(expr) = CUDA.@sync expr
ClimaComms.cuda_time(expr) = CUDA.@time expr
ClimaComms.cuda_elasped(expr) = CUDA.@elapsed expr

end
5 changes: 0 additions & 5 deletions src/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ Behavior can be overridden by setting the `CLIMACOMMS_CONTEXT` environment varia
to either `MPI` or `SINGLETON`.
"""
function context(device = device(); target_context = context_type())
if target_context == :MPICommsContext && mpi_ext_is_not_loaded()
error(
"Loading MPI.jl is required to use MPICommsContext. You might want to call ClimaComms.@import_required_backends",
)
end
ContextConstructor = getproperty(ClimaComms, target_context)
return ContextConstructor(device)
end
Expand Down
117 changes: 43 additions & 74 deletions src/devices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,6 @@ The default is `CPU`.
"""
function device()
target_device = device_type()
if target_device == :CUDADevice && cuda_ext_is_not_loaded()
error(
"Loading CUDA.jl is required to use CUDADevice. You might want to call ClimaComms.@import_required_backends",
)
end
DeviceConstructor = getproperty(ClimaComms, target_device)
return DeviceConstructor()
end
Expand Down Expand Up @@ -129,6 +124,8 @@ macro threaded(device, loop)
end
end

function cuda_time end

"""
@time device expr
Expand All @@ -145,24 +142,19 @@ CUDA.@time expr
for CUDA devices.
"""
macro time(device, expr)
return esc(
quote
if $device isa $CUDADevice
@static if isnothing(
$Base.get_extension($ClimaComms, :ClimaCommsCUDAExt),
)
error("CUDA not loaded")
else
$Base.get_extension($ClimaComms, :ClimaCommsCUDAExt).CUDA.@time $expr
end
else
@assert $device isa $AbstractDevice
$Base.@time $(expr)
end
end,
)
__CC__ = ClimaComms
return esc(quote
if $device isa $CUDADevice
$(__CC__).cuda_time($expr)
else
@assert $device isa $AbstractDevice
$Base.@time $(expr)
end
end)
end

function cuda_elasped end

"""
@elapsed device expr
Expand All @@ -179,24 +171,19 @@ CUDA.@elapsed expr
for CUDA devices.
"""
macro elapsed(device, expr)
return esc(
quote
if $device isa $CUDADevice
@static if isnothing(
$Base.get_extension($ClimaComms, :ClimaCommsCUDAExt),
)
error("CUDA not loaded")
else
$Base.get_extension($ClimaComms, :ClimaCommsCUDAExt).CUDA.@elapsed $expr
end
else
@assert $device isa $AbstractDevice
$Base.@elapsed $(expr)
end
end,
)
__CC__ = ClimaComms
return esc(quote
if $device isa $CUDADevice
$(__CC__).cuda_elasped($expr)
else
@assert $device isa $AbstractDevice
$Base.@elapsed $(expr)
end
end)
end

function cuda_sync end

"""
@sync device expr
Expand Down Expand Up @@ -233,26 +220,17 @@ to synchronize), then you may want to simply use [`@cuda_sync`](@ref).
"""
macro sync(device, expr)
# https://github.com/JuliaLang/julia/issues/28979#issuecomment-1756145207
return esc(
quote
if $device isa $CUDADevice
@static if isnothing(
$Base.get_extension($ClimaComms, :ClimaCommsCUDAExt),
)
error("CUDA not loaded")
else
$Base.get_extension($ClimaComms, :ClimaCommsCUDAExt).CUDA.@sync begin
$(expr)
end
end
else
@assert $device isa $AbstractDevice
$Base.@sync begin
$(expr)
end
__CC__ = ClimaComms
return esc(quote
if $device isa $CUDADevice
$(__CC__).cuda_sync($expr)
else
@assert $device isa $AbstractDevice
$Base.@sync begin
$(expr)
end
end,
)
end
end)
end

"""
Expand All @@ -272,22 +250,13 @@ for CUDA devices.
"""
macro cuda_sync(device, expr)
# https://github.com/JuliaLang/julia/issues/28979#issuecomment-1756145207
return esc(
quote
if $device isa $CUDADevice
@static if isnothing(
$Base.get_extension($ClimaComms, :ClimaCommsCUDAExt),
)
error("CUDA not loaded")
else
$Base.get_extension($ClimaComms, :ClimaCommsCUDAExt).CUDA.@sync begin
$(expr)
end
end
else
@assert $device isa $AbstractDevice
$(expr)
end
end,
)
__CC__ = ClimaComms
return esc(quote
if $device isa $CUDADevice
$(__CC__).cuda_sync($expr)
else
@assert $device isa $AbstractDevice
$(expr)
end
end)
end
40 changes: 25 additions & 15 deletions src/loading.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,32 @@
import ..ClimaComms

export import_required_backends
export @import_required_backends

function mpi_is_required()
return context_type() == :MPICommsContext
end
"""
mpi_is_required()
function mpi_ext_is_not_loaded()
return isnothing(Base.get_extension(ClimaComms, :ClimaCommsMPIExt))
end
Returns a Bool indicating if MPI should be loaded, based on the
`ENV["CLIMACOMMS_CONTEXT"]`. See [`ClimaComms.context`](@ref) for
more information.
function cuda_is_required()
return device_type() == :CUDADevice
end
```julia
mpi_is_required() && using MPI
```
"""
mpi_is_required() = context_type() == :MPICommsContext

function cuda_ext_is_not_loaded()
return isnothing(Base.get_extension(ClimaComms, :ClimaCommsCUDAExt))
end
"""
cuda_is_required()
Returns a Bool indicating if CUDA should be loaded, based on the
`ENV["CLIMACOMMS_DEVICE"]`. See [`ClimaComms.device`](@ref) for
more information.
```julia
cuda_is_required() && using CUDA
```
"""
cuda_is_required() = device_type() == :CUDADevice

"""
ClimaComms.@import_required_backends
Expand All @@ -26,11 +36,11 @@ If the desired device is CUDA (as determined by `ClimaComms.device()`), try load
"""
macro import_required_backends()
return quote
@static if $mpi_is_required() && $mpi_ext_is_not_loaded()
@static if $mpi_is_required()
import MPI
@info "Loaded MPI.jl"
end
@static if $cuda_is_required() && $cuda_ext_is_not_loaded()
@static if $cuda_is_required()
import CUDA
@info "Loaded CUDA.jl"
end
Expand Down

0 comments on commit 0a8edda

Please sign in to comment.