From 1c9ef7d797c2a684e4585a07fcea3cc2328937b3 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Thu, 9 Jan 2025 15:28:49 -0500 Subject: [PATCH] Define adapt for contexts and devices Update envs --- NEWS.md | 16 ++++++++++++++++ Project.toml | 13 ++++--------- docs/Manifest.toml | 33 ++++++++++++++++++++++++++++++--- docs/src/apis.md | 2 ++ ext/ClimaCommsCUDAExt.jl | 13 +++++++++++++ src/ClimaComms.jl | 1 + src/adapt.jl | 37 +++++++++++++++++++++++++++++++++++++ test/Project.toml | 1 + test/runtests.jl | 29 +++++++++++++++++++++++++++++ 9 files changed, 133 insertions(+), 12 deletions(-) create mode 100644 src/adapt.jl diff --git a/NEWS.md b/NEWS.md index 7347121..08c8ca1 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,22 @@ ClimaComms.jl Release Notes ======================== +main +------- + +v0.6.5 +------- + +- Support for adapt was added, so that users can convert between CPU and GPU + devices, and contexts containing cpu and gpu devices [PR 103](https://github.com/CliMA/ClimaComms.jl/pull/103). + +- New MPI logging tools were added, `MPIFileLogger` and `MPILogger` [PR 100](https://github.com/CliMA/ClimaComms.jl/pull/100). + +v0.6.4 +------- + +- Add device-flexible `@assert` was added [PR 86](https://github.com/CliMA/ClimaComms.jl/pull/86). + v0.6.3 ------- diff --git a/Project.toml b/Project.toml index 6acebbe..db393a0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,16 +1,10 @@ name = "ClimaComms" uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" -authors = [ - "Kiran Pamnany ", - "Simon Byrne ", - "Charles Kawczynski ", - "Sriharsha Kandala ", - "Jake Bolewski ", - "Gabriele Bozzola ", -] -version = "0.6.4" +authors = ["Kiran Pamnany ", "Simon Byrne ", "Charles Kawczynski ", "Sriharsha Kandala ", "Jake Bolewski ", "Gabriele Bozzola "] +version = "0.6.5" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" @@ -24,6 +18,7 @@ ClimaCommsMPIExt = "MPI" [compat] CUDA = "3, 4, 5" +Adapt = "3, 4" Logging = "1.9.4" LoggingExtras = "1.1.0" MPI = "0.20.18" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index ecab8f7..ab8296c 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.11.0" +julia_version = "1.11.2" manifest_format = "2.0" project_hash = "c5b9e727593a1bc35ccae9b71e346465d8a7803c" @@ -14,6 +14,18 @@ git-tree-sha1 = "2d9c9a55f9c93e8887ad391fbae72f8ef55e1177" uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" version = "0.4.5" +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "50c3c56a52972d78e8be9fd135bfb91c9574c140" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "4.1.1" + + [deps.Adapt.extensions] + AdaptStaticArraysExt = "StaticArrays" + + [deps.Adapt.weakdeps] + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" version = "1.1.2" @@ -27,10 +39,10 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" version = "1.11.0" [[deps.ClimaComms]] -deps = ["Logging", "LoggingExtras"] +deps = ["Adapt", "Logging", "LoggingExtras"] path = ".." uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" -version = "0.6.4" +version = "0.6.5" [deps.ClimaComms.extensions] ClimaCommsCUDAExt = "CUDA" @@ -174,6 +186,11 @@ git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" version = "1.17.0+0" +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +version = "1.11.0" + [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" version = "1.11.0" @@ -250,6 +267,11 @@ version = "2023.12.12" uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" version = "1.2.0" +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.27+1" + [[deps.OpenMPI_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762" @@ -381,6 +403,11 @@ deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" version = "1.2.13+1" +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.11.0+0" + [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" diff --git a/docs/src/apis.md b/docs/src/apis.md index 85bfa1c..cdc324d 100644 --- a/docs/src/apis.md +++ b/docs/src/apis.md @@ -34,6 +34,7 @@ ClimaComms.@elapsed ClimaComms.@assert ClimaComms.@sync ClimaComms.@cuda_sync +Adapt.adapt_structure(::Type{<:AbstractArray}, ::ClimaComms.AbstractDevice) ``` ## Contexts @@ -45,6 +46,7 @@ ClimaComms.MPICommsContext ClimaComms.AbstractGraphContext ClimaComms.context ClimaComms.graph_context +Adapt.adapt_structure(::Type{<:AbstractArray}, ::ClimaComms.AbstractCommsContext) ``` ## Context operations diff --git a/ext/ClimaCommsCUDAExt.jl b/ext/ClimaCommsCUDAExt.jl index 2d6fbfa..67355e1 100644 --- a/ext/ClimaCommsCUDAExt.jl +++ b/ext/ClimaCommsCUDAExt.jl @@ -2,6 +2,7 @@ module ClimaCommsCUDAExt import CUDA +import Adapt import ClimaComms import ClimaComms: CUDADevice @@ -14,6 +15,18 @@ function ClimaComms.device_functional(::CUDADevice) return CUDA.functional() end +function Adapt.adapt_structure( + to::Type{<:CUDA.CuArray}, + context::ClimaComms.AbstractCommsContext, +) + return context(adapt(to, ClimaComms.device(context))) +end + +Adapt.adapt_structure( + ::Type{<:CUDA.CuArray}, + device::ClimaComms.AbstractCPUDevice, +) = ClimaComms.CUDADevice() + ClimaComms.array_type(::CUDADevice) = CUDA.CuArray ClimaComms.allowscalar(f, ::CUDADevice, args...; kwargs...) = CUDA.@allowscalar f(args...; kwargs...) diff --git a/src/ClimaComms.jl b/src/ClimaComms.jl index 86916a4..f33ae18 100644 --- a/src/ClimaComms.jl +++ b/src/ClimaComms.jl @@ -16,6 +16,7 @@ include("context.jl") include("singleton.jl") include("mpi.jl") include("loading.jl") +include("adapt.jl") include("logging.jl") end # module diff --git a/src/adapt.jl b/src/adapt.jl new file mode 100644 index 0000000..b2ca2fb --- /dev/null +++ b/src/adapt.jl @@ -0,0 +1,37 @@ +import Adapt + +""" + Adapt.adapt_structure(::Type{<:AbstractArray}, context::AbstractCommsContext) + +Adapt a given context to a context with a device associated with the given array type. + +# Example + +```julia +Adapt.adapt_structure(Array, ClimaComms.context(ClimaComms.CUDADevice())) -> ClimaComms.CPUSingleThreaded() +``` + +!!! note + By default, adapting to `Array` creates a `CPUSingleThreaded` device, and + there is currently no way to conver to a CPUMultiThreaded device. +""" +Adapt.adapt_structure(to::Type{<:AbstractArray}, ctx::AbstractCommsContext) = + context(Adapt.adapt(to, device(ctx))) + +""" + Adapt.adapt_structure(::Type{<:AbstractArray}, device::AbstractDevice) + +Adapt a given device to a device associated with the given array type. + +# Example + +```julia +Adapt.adapt_structure(Array, ClimaComms.CUDADevice()) -> ClimaComms.CPUSingleThreaded() +``` + +!!! note + By default, adapting to `Array` creates a `CPUSingleThreaded` device, and + there is currently no way to conver to a CPUMultiThreaded device. +""" +Adapt.adapt_structure(::Type{<:AbstractArray}, device::AbstractDevice) = + CPUSingleThreaded() diff --git a/test/Project.toml b/test/Project.toml index 93c1bde..3e79521 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" diff --git a/test/runtests.jl b/test/runtests.jl index c5cf829..d6835fd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -228,6 +228,35 @@ end @test x == Array(a)[1] end +import Adapt +@testset "Adapt" begin + @test Adapt.adapt(Array, ClimaComms.CUDADevice()) == + ClimaComms.CPUSingleThreaded() + @static if ClimaComms.device() isa ClimaComms.CUDADevice + @test Adapt.adapt(Array, ClimaComms.CUDADevice()) == + ClimaComms.CPUSingleThreaded() + @test Adapt.adapt(CUDA.CuArray, ClimaComms.CUDADevice()) == + ClimaComms.CUDADevice() + @test Adapt.adapt(CUDA.CuArray, ClimaComms.CPUSingleThreaded()) == + ClimaComms.CUDADevice() + end + + @test Adapt.adapt(Array, ClimaComms.context(ClimaComms.CUDADevice())) == + ClimaComms.context(ClimaComms.CPUSingleThreaded()) + @static if ClimaComms.device() isa ClimaComms.CUDADevice + @test Adapt.adapt(Array, ClimaComms.context(ClimaComms.CUDADevice())) == + ClimaComms.context(ClimaComms.CPUSingleThreaded()) + @test Adapt.adapt( + CUDA.CuArray, + ClimaComms.context(ClimaComms.CUDADevice()), + ) == ClimaComms.context(ClimaComms.CUDADevice()) + @test Adapt.adapt( + CUDA.CuArray, + ClimaComms.context(ClimaComms.CPUSingleThreaded()), + ) == ClimaComms.context(ClimaComms.CUDADevice()) + end +end + @testset "logging" begin include("logging.jl") end