From 4142535c2c2ec0dd5d7bd6bc6843c34827721e21 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Thu, 2 Jan 2025 10:49:22 -0500 Subject: [PATCH] Fix adapt dispatch for KernelAdaptor --- ext/cuda/adapt.jl | 30 ++++++++++++++++++++++++++++++ src/Grids/extruded.jl | 9 --------- src/Grids/finitedifference.jl | 8 -------- src/Grids/spectralelement.jl | 7 ------- 4 files changed, 30 insertions(+), 24 deletions(-) diff --git a/ext/cuda/adapt.jl b/ext/cuda/adapt.jl index 6d3162ae64..1e366cf709 100644 --- a/ext/cuda/adapt.jl +++ b/ext/cuda/adapt.jl @@ -81,3 +81,33 @@ function Adapt.adapt(to::ToCUDA, data::DataLayouts.AbstractData) Adapt.adapt(CUDA.CuArray, parent(data)), ) end + +Adapt.adapt_structure( + to::CUDA.KernelAdaptor, + grid::Grids.ExtrudedFiniteDifferenceGrid, +) = Grids.DeviceExtrudedFiniteDifferenceGrid( + Adapt.adapt(to, Grids.vertical_topology(grid)), + Adapt.adapt(to, grid.horizontal_grid.quadrature_style), + Adapt.adapt(to, grid.global_geometry), + Adapt.adapt(to, grid.center_local_geometry), + Adapt.adapt(to, grid.face_local_geometry), +) + +Adapt.adapt_structure( + to::CUDA.KernelAdaptor, + grid::Grids.FiniteDifferenceGrid, +) = Grids.DeviceFiniteDifferenceGrid( + Adapt.adapt(to, grid.topology), + Adapt.adapt(to, grid.global_geometry), + Adapt.adapt(to, grid.center_local_geometry), + Adapt.adapt(to, grid.face_local_geometry), +) + +Adapt.adapt_structure( + to::CUDA.KernelAdaptor, + grid::Grids.SpectralElementGrid2D, +) = Grids.DeviceSpectralElementGrid2D( + Adapt.adapt(to, grid.quadrature_style), + Adapt.adapt(to, grid.global_geometry), + Adapt.adapt(to, grid.local_geometry), +) diff --git a/src/Grids/extruded.jl b/src/Grids/extruded.jl index 220ca72a7e..aa38df7983 100644 --- a/src/Grids/extruded.jl +++ b/src/Grids/extruded.jl @@ -166,15 +166,6 @@ local_geometry_type( ::Type{DeviceExtrudedFiniteDifferenceGrid{VT, Q, GG, CLG, FLG}}, ) where {VT, Q, GG, CLG, FLG} = eltype(CLG) # calls eltype from DataLayouts -Adapt.adapt_structure(to, grid::ExtrudedFiniteDifferenceGrid) = - DeviceExtrudedFiniteDifferenceGrid( - Adapt.adapt(to, vertical_topology(grid)), - Adapt.adapt(to, grid.horizontal_grid.quadrature_style), - Adapt.adapt(to, grid.global_geometry), - Adapt.adapt(to, grid.center_local_geometry), - Adapt.adapt(to, grid.face_local_geometry), - ) - quadrature_style(grid::DeviceExtrudedFiniteDifferenceGrid) = grid.quadrature_style vertical_topology(grid::DeviceExtrudedFiniteDifferenceGrid) = diff --git a/src/Grids/finitedifference.jl b/src/Grids/finitedifference.jl index 57534b3a99..c737b0b780 100644 --- a/src/Grids/finitedifference.jl +++ b/src/Grids/finitedifference.jl @@ -198,14 +198,6 @@ local_geometry_type( ::Type{DeviceFiniteDifferenceGrid{T, GG, CLG, FLG}}, ) where {T, GG, CLG, FLG} = eltype(CLG) # calls eltype from DataLayouts -Adapt.adapt_structure(to, grid::FiniteDifferenceGrid) = - DeviceFiniteDifferenceGrid( - Adapt.adapt(to, grid.topology), - Adapt.adapt(to, grid.global_geometry), - Adapt.adapt(to, grid.center_local_geometry), - Adapt.adapt(to, grid.face_local_geometry), - ) - topology(grid::DeviceFiniteDifferenceGrid) = grid.topology vertical_topology(grid::DeviceFiniteDifferenceGrid) = grid.topology diff --git a/src/Grids/spectralelement.jl b/src/Grids/spectralelement.jl index 6c2aeaa66f..b61fbb90fc 100644 --- a/src/Grids/spectralelement.jl +++ b/src/Grids/spectralelement.jl @@ -620,13 +620,6 @@ end ClimaComms.context(grid::DeviceSpectralElementGrid2D) = DeviceSideContext() ClimaComms.device(grid::DeviceSpectralElementGrid2D) = DeviceSideDevice() -Adapt.adapt_structure(to, grid::SpectralElementGrid2D) = - DeviceSpectralElementGrid2D( - Adapt.adapt(to, grid.quadrature_style), - Adapt.adapt(to, grid.global_geometry), - Adapt.adapt(to, grid.local_geometry), - ) - ## aliases const RectilinearSpectralElementGrid2D = SpectralElementGrid2D{<:Topologies.RectilinearTopology2D}