Skip to content

Commit

Permalink
Define adapt for contexts and devices
Browse files Browse the repository at this point in the history
Update envs
  • Loading branch information
charleskawczynski committed Jan 9, 2025
1 parent 0b580b7 commit 1c9ef7d
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 12 deletions.
16 changes: 16 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
ClimaComms.jl Release Notes
========================

main
-------

v0.6.5
-------

- Support for adapt was added, so that users can convert between CPU and GPU
devices, and contexts containing cpu and gpu devices [PR 103](https://github.com/CliMA/ClimaComms.jl/pull/103).

- New MPI logging tools were added, `MPIFileLogger` and `MPILogger` [PR 100](https://github.com/CliMA/ClimaComms.jl/pull/100).

v0.6.4
-------

- Add device-flexible `@assert` was added [PR 86](https://github.com/CliMA/ClimaComms.jl/pull/86).

v0.6.3
-------

Expand Down
13 changes: 4 additions & 9 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
name = "ClimaComms"
uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
authors = [
"Kiran Pamnany <[email protected]>",
"Simon Byrne <[email protected]>",
"Charles Kawczynski <[email protected]>",
"Sriharsha Kandala <[email protected]>",
"Jake Bolewski <[email protected]>",
"Gabriele Bozzola <[email protected]>",
]
version = "0.6.4"
authors = ["Kiran Pamnany <[email protected]>", "Simon Byrne <[email protected]>", "Charles Kawczynski <[email protected]>", "Sriharsha Kandala <[email protected]>", "Jake Bolewski <[email protected]>", "Gabriele Bozzola <[email protected]>"]
version = "0.6.5"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"

Expand All @@ -24,6 +18,7 @@ ClimaCommsMPIExt = "MPI"

[compat]
CUDA = "3, 4, 5"
Adapt = "3, 4"
Logging = "1.9.4"
LoggingExtras = "1.1.0"
MPI = "0.20.18"
Expand Down
33 changes: 30 additions & 3 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.11.0"
julia_version = "1.11.2"
manifest_format = "2.0"
project_hash = "c5b9e727593a1bc35ccae9b71e346465d8a7803c"

Expand All @@ -14,6 +14,18 @@ git-tree-sha1 = "2d9c9a55f9c93e8887ad391fbae72f8ef55e1177"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.4.5"

[[deps.Adapt]]
deps = ["LinearAlgebra", "Requires"]
git-tree-sha1 = "50c3c56a52972d78e8be9fd135bfb91c9574c140"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "4.1.1"

[deps.Adapt.extensions]
AdaptStaticArraysExt = "StaticArrays"

[deps.Adapt.weakdeps]
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
version = "1.1.2"
Expand All @@ -27,10 +39,10 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
version = "1.11.0"

[[deps.ClimaComms]]
deps = ["Logging", "LoggingExtras"]
deps = ["Adapt", "Logging", "LoggingExtras"]
path = ".."
uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.6.4"
version = "0.6.5"

[deps.ClimaComms.extensions]
ClimaCommsCUDAExt = "CUDA"
Expand Down Expand Up @@ -174,6 +186,11 @@ git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175"
uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531"
version = "1.17.0+0"

[[deps.LinearAlgebra]]
deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
version = "1.11.0"

[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
version = "1.11.0"
Expand Down Expand Up @@ -250,6 +267,11 @@ version = "2023.12.12"
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
version = "1.2.0"

[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
version = "0.3.27+1"

[[deps.OpenMPI_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"]
git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762"
Expand Down Expand Up @@ -381,6 +403,11 @@ deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.13+1"

[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.11.0+0"

[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
Expand Down
2 changes: 2 additions & 0 deletions docs/src/apis.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ ClimaComms.@elapsed
ClimaComms.@assert
ClimaComms.@sync
ClimaComms.@cuda_sync
Adapt.adapt_structure(::Type{<:AbstractArray}, ::ClimaComms.AbstractDevice)
```

## Contexts
Expand All @@ -45,6 +46,7 @@ ClimaComms.MPICommsContext
ClimaComms.AbstractGraphContext
ClimaComms.context
ClimaComms.graph_context
Adapt.adapt_structure(::Type{<:AbstractArray}, ::ClimaComms.AbstractCommsContext)
```

## Context operations
Expand Down
13 changes: 13 additions & 0 deletions ext/ClimaCommsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module ClimaCommsCUDAExt

import CUDA

import Adapt
import ClimaComms
import ClimaComms: CUDADevice

Expand All @@ -14,6 +15,18 @@ function ClimaComms.device_functional(::CUDADevice)
return CUDA.functional()
end

function Adapt.adapt_structure(
to::Type{<:CUDA.CuArray},
context::ClimaComms.AbstractCommsContext,
)
return context(adapt(to, ClimaComms.device(context)))
end

Adapt.adapt_structure(
::Type{<:CUDA.CuArray},
device::ClimaComms.AbstractCPUDevice,
) = ClimaComms.CUDADevice()

ClimaComms.array_type(::CUDADevice) = CUDA.CuArray
ClimaComms.allowscalar(f, ::CUDADevice, args...; kwargs...) =
CUDA.@allowscalar f(args...; kwargs...)
Expand Down
1 change: 1 addition & 0 deletions src/ClimaComms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ include("context.jl")
include("singleton.jl")
include("mpi.jl")
include("loading.jl")
include("adapt.jl")
include("logging.jl")

end # module
37 changes: 37 additions & 0 deletions src/adapt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import Adapt

"""
Adapt.adapt_structure(::Type{<:AbstractArray}, context::AbstractCommsContext)
Adapt a given context to a context with a device associated with the given array type.
# Example
```julia
Adapt.adapt_structure(Array, ClimaComms.context(ClimaComms.CUDADevice())) -> ClimaComms.CPUSingleThreaded()
```
!!! note
By default, adapting to `Array` creates a `CPUSingleThreaded` device, and
there is currently no way to conver to a CPUMultiThreaded device.
"""
Adapt.adapt_structure(to::Type{<:AbstractArray}, ctx::AbstractCommsContext) =
context(Adapt.adapt(to, device(ctx)))

"""
Adapt.adapt_structure(::Type{<:AbstractArray}, device::AbstractDevice)
Adapt a given device to a device associated with the given array type.
# Example
```julia
Adapt.adapt_structure(Array, ClimaComms.CUDADevice()) -> ClimaComms.CPUSingleThreaded()
```
!!! note
By default, adapting to `Array` creates a `CPUSingleThreaded` device, and
there is currently no way to conver to a CPUMultiThreaded device.
"""
Adapt.adapt_structure(::Type{<:AbstractArray}, device::AbstractDevice) =
CPUSingleThreaded()
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Expand Down
29 changes: 29 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,35 @@ end
@test x == Array(a)[1]
end

import Adapt
@testset "Adapt" begin
@test Adapt.adapt(Array, ClimaComms.CUDADevice()) ==
ClimaComms.CPUSingleThreaded()
@static if ClimaComms.device() isa ClimaComms.CUDADevice
@test Adapt.adapt(Array, ClimaComms.CUDADevice()) ==
ClimaComms.CPUSingleThreaded()
@test Adapt.adapt(CUDA.CuArray, ClimaComms.CUDADevice()) ==
ClimaComms.CUDADevice()
@test Adapt.adapt(CUDA.CuArray, ClimaComms.CPUSingleThreaded()) ==
ClimaComms.CUDADevice()
end

@test Adapt.adapt(Array, ClimaComms.context(ClimaComms.CUDADevice())) ==
ClimaComms.context(ClimaComms.CPUSingleThreaded())
@static if ClimaComms.device() isa ClimaComms.CUDADevice
@test Adapt.adapt(Array, ClimaComms.context(ClimaComms.CUDADevice())) ==
ClimaComms.context(ClimaComms.CPUSingleThreaded())
@test Adapt.adapt(
CUDA.CuArray,
ClimaComms.context(ClimaComms.CUDADevice()),
) == ClimaComms.context(ClimaComms.CUDADevice())
@test Adapt.adapt(
CUDA.CuArray,
ClimaComms.context(ClimaComms.CPUSingleThreaded()),
) == ClimaComms.context(ClimaComms.CUDADevice())
end
end

@testset "logging" begin
include("logging.jl")
end

0 comments on commit 1c9ef7d

Please sign in to comment.