Skip to content

Commit

Permalink
Define CPU<->GPU adaptations
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Dec 23, 2024
1 parent 9c6d01c commit 47df9ce
Show file tree
Hide file tree
Showing 17 changed files with 173 additions and 3 deletions.
1 change: 1 addition & 0 deletions ext/ClimaCoreCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,6 @@ include(joinpath("cuda", "operators_thomas_algorithm.jl"))
include(joinpath("cuda", "matrix_fields_single_field_solve.jl"))
include(joinpath("cuda", "matrix_fields_multiple_field_solve.jl"))
include(joinpath("cuda", "operators_spectral_element.jl"))
include(joinpath("cuda", "adapt.jl"))

end
83 changes: 83 additions & 0 deletions ext/cuda/adapt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import Adapt
import ClimaCore.DataLayouts: ToCUDA, ToCPU

function Adapt.adapt(to::ToCUDA, space::Spaces.FiniteDifferenceSpace)
Spaces.FiniteDifferenceSpace(
Adapt.adapt(CUDA.CuArray, Spaces.grid(space)),
space.staggering,
)
end

function Adapt.adapt(to::ToCUDA, space::Spaces.SpectralElementSpace1D)
Spaces.SpectralElementSpace1D(Adapt.adapt(CUDA.CuArray, Spaces.grid(space)))
end

function Adapt.adapt(to::ToCUDA, space::Spaces.SpectralElementSpace2D)
Spaces.SpectralElementSpace2D(Adapt.adapt(CUDA.CuArray, Spaces.grid(space)))
end

function Adapt.adapt(to::ToCUDA, space::Spaces.SpectralElementSpaceSlab)
Spaces.SpectralElementSpaceSlab(
space.quadrature_style,
Adapt.adapt(CUDA.CuArray, space.local_geometry),
)
end

function Adapt.adapt(::ToCUDA, context::ClimaComms.AbstractCommsContext)
return context(ClimaComms.CUDADevice())
end

function Adapt.adapt(to::ToCUDA, space::Grids.LevelGrid)
return Grids.LevelGrid(
Adapt.adapt(CUDA.CuArray, space.full_grid),
Adapt.adapt(CUDA.CuArray, space.level),
)
end

function Adapt.adapt(to::ToCUDA, space::Grids.SpectralElementGrid1D)
return Grids.SpectralElementGrid1D(
Adapt.adapt(CUDA.CuArray, space.topology),
Adapt.adapt(CUDA.CuArray, space.quadrature_style),
Adapt.adapt(CUDA.CuArray, space.global_geometry),
Adapt.adapt(CUDA.CuArray, space.local_geometry),
Adapt.adapt(CUDA.CuArray, space.dss_weights),
)
end

function Adapt.adapt(::ToCUDA, space::Grids.SpectralElementGrid2D)
return Grids.SpectralElementGrid2D(
Adapt.adapt(CUDA.CuArray, space.topology),
Adapt.adapt(CUDA.CuArray, space.quadrature_style),
Adapt.adapt(CUDA.CuArray, space.global_geometry),
Adapt.adapt(CUDA.CuArray, space.local_geometry),
Adapt.adapt(CUDA.CuArray, space.local_dss_weights),
Adapt.adapt(CUDA.CuArray, space.internal_surface_geometry),
Adapt.adapt(CUDA.CuArray, space.boundary_surface_geometries),
space.enable_bubble,
)
end

function Adapt.adapt(to::ToCUDA, grid::Grids.FiniteDifferenceGrid)
return Grids.FiniteDifferenceGrid(
Adapt.adapt(CUDA.CuArray, grid.topology),
Adapt.adapt(CUDA.CuArray, grid.global_geometry),
Adapt.adapt(CUDA.CuArray, grid.center_local_geometry),
Adapt.adapt(CUDA.CuArray, grid.face_local_geometry),
)
end
function Adapt.adapt(to::ToCUDA, grid::Grids.ExtrudedFiniteDifferenceGrid)
return Grids.ExtrudedFiniteDifferenceGrid(
Adapt.adapt(CUDA.CuArray, grid.horizontal_grid),
Adapt.adapt(CUDA.CuArray, grid.vertical_grid),
Adapt.adapt(CUDA.CuArray, grid.hypsography),
Adapt.adapt(CUDA.CuArray, grid.global_geometry),
Adapt.adapt(CUDA.CuArray, grid.center_local_geometry),
Adapt.adapt(CUDA.CuArray, grid.face_local_geometry),
)
end

function Adapt.adapt(to::ToCUDA, data::DataLayouts.AbstractData)
DataLayouts.union_all(DataLayouts.singleton(data))(
Adapt.adapt(CUDA.CuArray, parent(data)),
)
end
4 changes: 4 additions & 0 deletions src/ClimaCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,8 @@ include("CommonSpaces/CommonSpaces.jl")

include("deprecated.jl")

function Adapt.adapt(::ToCPU, context::ClimaComms.AbstractCommsContext)
return context(ClimaComms.CPUSingleThreaded())
end

end # module
4 changes: 4 additions & 0 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2229,4 +2229,8 @@ include("mapreduce.jl")

include("struct_linear_indexing.jl")

function Adapt.adapt(to::ToCPU, data::AbstractData)
union_all(singleton(data))(Adapt.adapt(Array, parent(data)))
end

end # module
2 changes: 2 additions & 0 deletions src/Domains/Domains.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module Domains

import ..DataLayouts: ToCPU, ToCUDA
import Adapt
import ..Geometry: Geometry, float_type
import IntervalSets
export RectangleDomain
Expand Down
2 changes: 2 additions & 0 deletions src/Grids/Grids.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module Grids

import ..DataLayouts: ToCPU, ToCUDA
import Adapt
import ClimaComms, Adapt, ForwardDiff, LinearAlgebra
import LinearAlgebra: det, norm
import ..DataLayouts: slab_index, vindex
Expand Down
11 changes: 11 additions & 0 deletions src/Grids/extruded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ mutable struct ExtrudedFiniteDifferenceGrid{
face_local_geometry::FLG
end

function Adapt.adapt(to::ToCPU, grid::ExtrudedFiniteDifferenceGrid)
return ExtrudedFiniteDifferenceGrid(
Adapt.adapt(Array, grid.horizontal_grid),
Adapt.adapt(Array, grid.vertical_grid),
Adapt.adapt(Array, grid.hypsography),
Adapt.adapt(Array, grid.global_geometry),
Adapt.adapt(Array, grid.center_local_geometry),
Adapt.adapt(Array, grid.face_local_geometry),
)
end

local_geometry_type(
::Type{ExtrudedFiniteDifferenceGrid{H, V, A, GG, CLG, FLG}},
) where {H, V, A, GG, CLG, FLG} = eltype(CLG) # calls eltype from DataLayouts
Expand Down
8 changes: 8 additions & 0 deletions src/Grids/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ mutable struct FiniteDifferenceGrid{
face_local_geometry::FLG
end

function Adapt.adapt(to::ToCPU, grid::FiniteDifferenceGrid)
return FiniteDifferenceGrid(
Adapt.adapt(Array, grid.topology),
Adapt.adapt(Array, grid.global_geometry),
Adapt.adapt(Array, grid.center_local_geometry),
Adapt.adapt(Array, grid.face_local_geometry),
)
end

function FiniteDifferenceGrid(topology::Topologies.IntervalTopology)
get!(Cache.OBJECT_CACHE, (FiniteDifferenceGrid, topology)) do
Expand Down
7 changes: 7 additions & 0 deletions src/Grids/level.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ struct LevelGrid{
level::L
end

function Adapt.adapt(to::ToCPU, grid::LevelGrid)
return LevelGrid(
Adapt.adapt(Array, grid.full_grid),
Adapt.adapt(Array, grid.level),
)
end

quadrature_style(levelgrid::LevelGrid) =
quadrature_style(levelgrid.full_grid.horizontal_grid)

Expand Down
23 changes: 23 additions & 0 deletions src/Grids/spectralelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ mutable struct SpectralElementGrid1D{
dss_weights::D
end

function Adapt.adapt(to::ToCPU, space::SpectralElementGrid1D)
return SpectralElementGrid1D(
Adapt.adapt(Array, space.topology),
Adapt.adapt(Array, space.quadrature_style),
Adapt.adapt(Array, space.global_geometry),
Adapt.adapt(Array, space.local_geometry),
Adapt.adapt(Array, space.dss_weights),
)
end

local_geometry_type(
::Type{SpectralElementGrid1D{T, Q, GG, LG}},
) where {T, Q, GG, LG} = eltype(LG) # calls eltype from DataLayouts
Expand Down Expand Up @@ -140,6 +150,19 @@ mutable struct SpectralElementGrid2D{
enable_bubble::Bool
end

function Adapt.adapt(::ToCPU, grid::SpectralElementGrid2D)
return SpectralElementGrid2D(
Adapt.adapt(Array, grid.topology),
Adapt.adapt(Array, grid.quadrature_style),
Adapt.adapt(Array, grid.global_geometry),
Adapt.adapt(Array, grid.local_geometry),
Adapt.adapt(Array, grid.local_dss_weights),
Adapt.adapt(Array, grid.internal_surface_geometry),
Adapt.adapt(Array, grid.boundary_surface_geometries),
grid.enable_bubble,
)
end

local_geometry_type(
::Type{SpectralElementGrid2D{T, Q, GG, LG, D, IS, BS}},
) where {T, Q, GG, LG, D, IS, BS} = eltype(LG) # calls eltype from DataLayouts
Expand Down
2 changes: 2 additions & 0 deletions src/Meshes/Meshes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ module Meshes

import RootSolvers

import ..DataLayouts: ToCPU, ToCUDA
import Adapt
export RectilinearMesh,
EquiangularCubedSphere,
EquidistantCubedSphere,
Expand Down
1 change: 1 addition & 0 deletions src/Spaces/Spaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using Adapt

import ..slab, ..column, ..level
import ..Utilities: PlusHalf, half
import ..DataLayouts: ToCPU, ToCUDA
import ..DataLayouts,
..Geometry, ..Domains, ..Meshes, ..Topologies, ..Grids, ..Quadratures

Expand Down
7 changes: 7 additions & 0 deletions src/Spaces/extruded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ struct ExtrudedFiniteDifferenceSpace{
staggering::S
end

function Adapt.adapt(to::ToCPU, space::ExtrudedFiniteDifferenceSpace)
return ExtrudedFiniteDifferenceSpace(
Adapt.adapt(Array, grid(space)),
space.staggering,
)
end

local_geometry_type(::Type{ExtrudedFiniteDifferenceSpace{G, S}}) where {G, S} =
local_geometry_type(G)

Expand Down
7 changes: 7 additions & 0 deletions src/Spaces/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ FiniteDifferenceSpace(
staggering::Staggering,
) = FiniteDifferenceSpace(Grids.FiniteDifferenceGrid(topology), staggering)

function Adapt.adapt(to::ToCPU, space::FiniteDifferenceSpace)
return FiniteDifferenceSpace(
Adapt.adapt(Array, grid(space)),
space.staggering,
)
end

local_geometry_type(::Type{FiniteDifferenceSpace{G, S}}) where {G, S} =
local_geometry_type(G)

Expand Down
7 changes: 7 additions & 0 deletions src/Spaces/pointspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ struct PointSpace{
local_geometry::LG
end

function Adapt.adapt(to::ToCPU, space::PointSpace)
return PointSpace(
Adapt.adapt(to, context),
Adapt.adapt(Array, space.local_geometry),
)
end

local_geometry_type(::Type{PointSpace{C, LG}}) where {C, LG} = eltype(LG) # calls eltype from DataLayouts

ClimaComms.device(space::PointSpace) = ClimaComms.device(space.context)
Expand Down
2 changes: 2 additions & 0 deletions src/Topologies/Topologies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ module Topologies

import ClimaComms, Adapt

import ..DataLayouts: ToCPU, ToCUDA
import Adapt
import ..ClimaCore
import ..Utilities: Cache, cart_ind, linear_ind
import ..Geometry
Expand Down
5 changes: 2 additions & 3 deletions test/Spaces/unit_spaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ using Adapt
radius = 10,
h_elem = 10,
n_quad_points = 4,
staggering = Grids.CellCenter()
staggering = Grids.CellCenter(),
)
cpu_space_out = Adapt.adapt(ToCPU(), cpu_space_in)
@test cpu_space_in === cpu_space_out
Expand All @@ -353,14 +353,13 @@ if ClimaComms.device() isa ClimaComms.CUDADevice
radius = 10,
h_elem = 10,
n_quad_points = 4,
staggering = Grids.CellCenter()
staggering = Grids.CellCenter(),
)
cpu_space_out = Adapt.adapt(ToCPU(), cpu_space_in)
@test cpu_space_in === cpu_space_out
gpu_space_out = Adapt.adapt(ToCUDA(), cpu_space_in)
gpu_array_type = ClimaComms.array_type(ClimaComms.CUDADevice())
@test parent(Spaces.coordinates_data(space)) isa gpu_array_type

end
end

Expand Down

0 comments on commit 47df9ce

Please sign in to comment.