Skip to content

Commit

Permalink
interpolate_array: add support for GPUs
Browse files Browse the repository at this point in the history
[skip ci]
  • Loading branch information
Sbozzolo committed Nov 8, 2023
1 parent cf10bd9 commit 008b886
Show file tree
Hide file tree
Showing 5 changed files with 740 additions and 330 deletions.
13 changes: 13 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,25 @@ steps:
command: "srun julia --color=yes --check-bounds=yes --project=test test/Remapping/distributed_remapping.jl"
env:
CLIMACOMMS_CONTEXT: "MPI"
CLIMACOMMS_DEVICE: "CPU"
agents:
slurm_ntasks: 2

- label: "Unit: distributed remapping (1 process)"
key: distributed_remapping_1proc
command: "julia --color=yes --check-bounds=yes --project=test test/Remapping/distributed_remapping.jl"
env:
CLIMACOMMS_DEVICE: "CPU"

- label: "Unit: distributed remapping with CUDA"
key: distributed_remapping_gpu
command: "srun julia --color=yes --check-bounds=yes --project=test test/Remapping/distributed_remapping.jl"
env:
CLIMACOMMS_CONTEXT: "MPI"
CLIMACOMMS_DEVICE: "CUDA"
agents:
slurm_ntasks: 2
slurm_gpus: 1

- label: "Unit: distributed gather"
key: unit_distributed_gather4
Expand Down
1 change: 1 addition & 0 deletions src/Remapping/Remapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import ..DataLayouts,
..Hypsography

using ..RecursiveApply
using CUDA

include("interpolate_array.jl")
include("distributed_remapping.jl")
Expand Down
15 changes: 11 additions & 4 deletions src/Remapping/distributed_remapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,17 @@ function interpolate(

if length(remapper.target_zcoords) == 0
out_local_array = zeros(FT, size(remapper.local_target_hcoords_bitmask))
interpolated_values = [
interpolate_slab(field, Fields.SlabIndex(nothing, gidx), weights) for (gidx, weights) in
zip(remapper.local_indices, remapper.interpolation_coeffs)
]

interpolated_values = zeros(FT, length(remapper.local_indices))
slab_indices =
[Fields.SlabIndex(nothing, gidx) for gidx in remapper.local_indices]

interpolate_slab!(
interpolated_values,
field,
slab_indices,
remapper.interpolation_coeffs,
)

# out_local_array[remapper.local_target_hcoords_bitmask] returns a view on space we
# want to write on
Expand Down
Loading

0 comments on commit 008b886

Please sign in to comment.