Skip to content

Commit

Permalink
Define CPU<->GPU adaptations
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jan 9, 2025
1 parent 0371101 commit 5a42ce5
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 0 deletions.
11 changes: 11 additions & 0 deletions ext/cuda/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,14 @@ Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
topology::Topologies.IntervalTopology,
) = Topologies.DeviceIntervalTopology(topology.boundaries)

function Adapt.adapt(
to::Type{CUDA.CuArray},
context::ClimaComms.AbstractCommsContext,
)
return context(adapt(to, ClimaComms.device(context)))
end

function Adapt.adapt(::Type{CUDA.CuArray}, device::ClimaComms.AbstractCPUDevice)
return ClimaComms.CUDADevice()
end
1 change: 1 addition & 0 deletions src/ClimaCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ include("CommonGrids/CommonGrids.jl")
include("CommonSpaces/CommonSpaces.jl")

include("deprecated.jl")
include("adapt.jl")

end # module
10 changes: 10 additions & 0 deletions src/adapt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import Adapt
import ClimaComms

function Adapt.adapt(to::Type{Array}, context::ClimaComms.AbstractCommsContext)
return context(adapt(to, ClimaComms.device(context)))
end

function Adapt.adapt(::Type{Array}, device::ClimaComms.AbstractCPUDevice)
return ClimaComms.CPUSingleThreaded()
end
39 changes: 39 additions & 0 deletions test/Spaces/unit_spaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,45 @@ end
end
end

using ClimaCore.CommonSpaces
using ClimaCore.Grids
using ClimaCore.DataLayouts: ToCPU, ToCUDA
using Adapt
@testset "Adapt between CPU<->CPU" begin
cpu_space_in = ExtrudedCubedSphereSpace(;
device = ClimaComms.CPUSingleThreaded(),
z_elem = 10,
z_min = 0,
z_max = 1,
radius = 10,
h_elem = 10,
n_quad_points = 4,
staggering = Grids.CellCenter(),
)
cpu_space_out = Adapt.adapt(Array, cpu_space_in)
@test cpu_space_in === cpu_space_out
end

if ClimaComms.device() isa ClimaComms.CUDADevice
@testset "Adapt between CPU->GPU->CPU" begin
cpu_space_in = ExtrudedCubedSphereSpace(;
device = ClimaComms.CPUSingleThreaded(),
z_elem = 10,
z_min = 0,
z_max = 1,
radius = 10,
h_elem = 10,
n_quad_points = 4,
staggering = Grids.CellCenter(),
)
cpu_space_out = Adapt.adapt(Array, cpu_space_in)
@test cpu_space_in === cpu_space_out
@static if ClimaComms.device() isa ClimaComms.CUDADevice
gpu_space_out = Adapt.adapt(CUDA.CuArray, cpu_space_in)
@test parent(Spaces.coordinates_data(space)) isa CUDA.CuArray
end
end
end

#=
@testset "dss on 2×2 rectangular mesh (unstructured)" begin
Expand Down

0 comments on commit 5a42ce5

Please sign in to comment.