Skip to content

Commit

Permalink
Define CPU<->GPU adaptations
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski authored and Charlie Kawczynski committed Jan 9, 2025
1 parent 0371101 commit 4ec39fa
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 3 deletions.
8 changes: 8 additions & 0 deletions ext/cuda/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,11 @@ Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
topology::Topologies.IntervalTopology,
) = Topologies.DeviceIntervalTopology(topology.boundaries)

function Adapt.adapt(to::Type{CUDA.CuArray}, context::ClimaComms.AbstractCommsContext)
return context(adapt(to, ClimaComms.device(context)))
end

function Adapt.adapt(::Type{CUDA.CuArray}, device::ClimaComms.AbstractCPUDevice)
return ClimaComms.CUDADevice()
end
1 change: 1 addition & 0 deletions src/ClimaCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ include("CommonGrids/CommonGrids.jl")
include("CommonSpaces/CommonSpaces.jl")

include("deprecated.jl")
include("adapt.jl")

end # module
2 changes: 2 additions & 0 deletions src/Grids/extruded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ mutable struct ExtrudedFiniteDifferenceGrid{
face_local_geometry::FLG
end

Adapt.@adapt_structure ExtrudedFiniteDifferenceGrid

local_geometry_type(
::Type{ExtrudedFiniteDifferenceGrid{H, V, A, GG, CLG, FLG}},
) where {H, V, A, GG, CLG, FLG} = eltype(CLG) # calls eltype from DataLayouts
Expand Down
2 changes: 1 addition & 1 deletion src/Grids/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ mutable struct FiniteDifferenceGrid{
center_local_geometry::CLG
face_local_geometry::FLG
end

Adapt.@adapt_structure FiniteDifferenceGrid

function FiniteDifferenceGrid(topology::Topologies.IntervalTopology)
get!(Cache.OBJECT_CACHE, (FiniteDifferenceGrid, topology)) do
Expand Down
4 changes: 4 additions & 0 deletions src/Grids/spectralelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ mutable struct SpectralElementGrid1D{
dss_weights::D
end

Adapt.@adapt_structure SpectralElementGrid1D

local_geometry_type(
::Type{SpectralElementGrid1D{T, Q, GG, LG}},
) where {T, Q, GG, LG} = eltype(LG) # calls eltype from DataLayouts
Expand Down Expand Up @@ -140,6 +142,8 @@ mutable struct SpectralElementGrid2D{
enable_bubble::Bool
end

Adapt.@adapt_structure SpectralElementGrid2D

local_geometry_type(
::Type{SpectralElementGrid2D{T, Q, GG, LG, D, IS, BS}},
) where {T, Q, GG, LG, D, IS, BS} = eltype(LG) # calls eltype from DataLayouts
Expand Down
10 changes: 10 additions & 0 deletions src/adapt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import Adapt
import ClimaComms

function Adapt.adapt(to::Type{Array}, context::ClimaComms.AbstractCommsContext)
return context(adapt(to, ClimaComms.device(context)))
end

function Adapt.adapt(::Type{Array}, device::ClimaComms.AbstractCPUDevice)
return ClimaComms.CPUSingleThreaded()
end
115 changes: 114 additions & 1 deletion test/Fields/unit_field.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#=
julia --check-bounds=yes --project
julia --project
julia --project=.buildkite
using Revise; include(joinpath("test", "Fields", "unit_field.jl"))
=#
using Test
Expand Down Expand Up @@ -702,6 +702,119 @@ end
nothing
end

using ClimaCore.CommonSpaces
using ClimaCore.Grids
using Adapt

function test_adapt(cpu_space_in)
test_adapt_space(cpu_space_in)
cpu_f_in = Fields.Field(Float64, cpu_space_in)
cpu_f_out = Adapt.adapt(Array, cpu_f_in)
@test parent(Spaces.local_geometry_data(axes(cpu_f_out))) isa Array
@test parent(Fields.field_values(cpu_f_out)) isa Array

@static if ClimaComms.device() isa ClimaComms.CUDADevice
# cpu -> gpu
gpu_f_out = Adapt.adapt(CUDA.CuArray, cpu_f_in)
@test parent(Fields.field_values(gpu_f_out)) isa CUDA.CuArray
# gpu -> gpu
cpu_f_out = Adapt.adapt(Array, gpu_f_out)
@test parent(Fields.field_values(cpu_f_out)) isa Array
end
end

function test_adapt_space(cpu_space_in)
# cpu -> cpu
cpu_space_out = Adapt.adapt(Array, cpu_space_in)
@test parent(Spaces.local_geometry_data(cpu_space_out)) isa Array

@static if ClimaComms.device() isa ClimaComms.CUDADevice
# cpu -> gpu
gpu_space_out = Adapt.adapt(CUDA.CuArray, cpu_space_in)
@test parent(Spaces.local_geometry_data(gpu_space_out)) isa CUDA.CuArray
# gpu -> gpu
cpu_space_out = Adapt.adapt(Array, gpu_space_out)
@test parent(Spaces.local_geometry_data(cpu_space_out)) isa Array
end
end

@testset "Test Adapt" begin
space = ExtrudedCubedSphereSpace(;
device = ClimaComms.CPUSingleThreaded(),
z_elem = 10,
z_min = 0,
z_max = 1,
radius = 10,
h_elem = 10,
n_quad_points = 4,
staggering = Grids.CellCenter(),
)
test_adapt(space)

space = CubedSphereSpace(;
device = ClimaComms.CPUSingleThreaded(),
radius = 10,
n_quad_points = 4,
h_elem = 10,
)
test_adapt(space)

space = ColumnSpace(;
device = ClimaComms.CPUSingleThreaded(),
z_elem = 10,
z_min = 0,
z_max = 10,
staggering = CellCenter()
)
test_adapt(space)

space = Box3DSpace(;
device = ClimaComms.CPUSingleThreaded(),
z_elem = 10,
x_min = 0,
x_max = 1,
y_min = 0,
y_max = 1,
z_min = 0,
z_max = 10,
periodic_x = false,
periodic_y = false,
n_quad_points = 4,
x_elem = 3,
y_elem = 4,
staggering = CellCenter()
)
test_adapt(space)

space = SliceXZSpace(;
device = ClimaComms.CPUSingleThreaded(),
z_elem = 10,
x_min = 0,
x_max = 1,
z_min = 0,
z_max = 1,
periodic_x = false,
n_quad_points = 4,
x_elem = 4,
staggering = CellCenter()
)
test_adapt(space)

space = RectangleXYSpace(;
device = ClimaComms.CPUSingleThreaded(),
x_min = 0,
x_max = 1,
y_min = 0,
y_max = 1,
periodic_x = false,
periodic_y = false,
n_quad_points = 4,
x_elem = 3,
y_elem = 4,
)
test_adapt(space)
end

@testset "Memoization of spaces" begin
space1 = spectral_space_2D()
space2 = spectral_space_2D()
Expand Down
1 change: 0 additions & 1 deletion test/Spaces/unit_spaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,6 @@ end
end
end


#=
@testset "dss on 2×2 rectangular mesh (unstructured)" begin
FT = Float64
Expand Down

0 comments on commit 4ec39fa

Please sign in to comment.