Skip to content

Commit

Permalink
Pre-allocate memory for nodes in interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
Sbozzolo committed Oct 23, 2023
1 parent 3002686 commit c701b65
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 17 deletions.
14 changes: 13 additions & 1 deletion src/Remapping/distributed_remapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,20 @@ function interpolate(

if length(remapper.target_zcoords) == 0
out_local_array = zeros(FT, size(remapper.local_target_hcoords_bitmask))

# We need to prepare the memory area where we save the node values
example_weights = remapper.interpolation_coeffs[1]
# We need to call `length.(weights)...` because `weights` is a 1- or 2- tuple of
# `SMatrix`s and we want `node_mem` to have the same shape
node_mem = zeros(FT, length.(example_weights)...)

interpolated_values = [
interpolate_slab(field, Fields.SlabIndex(nothing, gidx), weights) for (gidx, weights) in
interpolate_slab(
field,
Fields.SlabIndex(nothing, gidx),
weights;
node_mem,
) for (gidx, weights) in
zip(remapper.local_indices, remapper.interpolation_coeffs)
]

Expand Down
59 changes: 43 additions & 16 deletions src/Remapping/interpolate_array.jl
Original file line number Diff line number Diff line change
@@ -1,44 +1,48 @@
# node_mem is a pre-allocated area of memory where the values of the fields are written to.
# node_mem has to be `Nq` or `Nq` x `Nq` in size and the same eltype as `field`.
function interpolate_slab(
field::Fields.Field,
slabidx::Fields.SlabIndex,
(I1,)::Tuple{<:AbstractArray},
(I1,)::Tuple{<:AbstractArray};
node_mem,
)
space = axes(field)
QS = Spaces.quadrature_style(space)
Nq = Spaces.Quadratures.degrees_of_freedom(QS)

nodes = [
node_mem .= [
Operators.get_node(space, field, CartesianIndex((i,)), slabidx) for
i in 1:Nq
]

return I1 nodes
return I1 node_mem
end

function interpolate_slab(
field::Fields.Field,
slabidx::Fields.SlabIndex,
(I1, I2)::Tuple{<:AbstractArray, <:AbstractArray},
(I1, I2)::Tuple{<:AbstractArray, <:AbstractArray};
node_mem,
)
space = axes(field)
QS = Spaces.quadrature_style(space)
Nq = Spaces.Quadratures.degrees_of_freedom(QS)

nodes = [
node_mem .= [
Operators.get_node(space, field, CartesianIndex((i, j)), slabidx)
for i in 1:Nq, j in 1:Nq
]

x = (I1 * nodes) I2
return x
return (I1 * node_mem) I2
end

"""
interpolate_slab_level(
field::Fields.Field,
h::Integer,
Is::Tuple,
zcoord,
zcoord;
node_mem
)
Vertically interpolate the given `field` on `zcoord`.
Expand All @@ -50,12 +54,18 @@ element in a column, no interpolation is performed and the value at the cell cen
returned. Effectively, this means that the interpolation is first-order accurate across the
column, but zeroth-order accurate close to the boundaries.
- `node_mem` has to a pre-allocated area of memory where the values of the fields are
written to. `node_mem` has to be `Nq` or `Nq` x `Nq` in size and the same eltype as
`field`, where `Nq` is the number of degrees of freedom associated to the horizontal
space.
"""
function interpolate_slab_level(
field::Fields.Field,
h::Integer,
Is::Tuple,
zcoord,
zcoord;
node_mem,
)
space = axes(field)
vert_topology = Spaces.vertical_topology(space)
Expand Down Expand Up @@ -86,8 +96,8 @@ function interpolate_slab_level(
ξ3 = ξ3 - 1
end
end
f_lo = interpolate_slab(field, Fields.SlabIndex(v_lo, h), Is)
f_hi = interpolate_slab(field, Fields.SlabIndex(v_hi, h), Is)
f_lo = interpolate_slab(field, Fields.SlabIndex(v_lo, h), Is; node_mem)
f_hi = interpolate_slab(field, Fields.SlabIndex(v_hi, h), Is; node_mem)
return ((1 - ξ3) * f_lo + (1 + ξ3) * f_hi) / 2
end

Expand Down Expand Up @@ -126,9 +136,15 @@ function interpolate_array(
horz_mesh = horz_topology.mesh

T = eltype(field)

array = zeros(T, length(xpts), length(zpts))

FT = Spaces.undertype(space)
QS = Spaces.quadrature_style(space)
Nq = Spaces.Quadratures.degrees_of_freedom(QS)

node_mem = zeros(FT, Nq, Nq)

for (ix, xcoord) in enumerate(xpts)
hcoord = xcoord
helem = Meshes.containing_element(horz_mesh, hcoord)
Expand All @@ -138,7 +154,8 @@ function interpolate_array(
h = helem

for (iz, zcoord) in enumerate(zpts)
array[ix, iz] = interpolate_slab_level(field, h, weights, zcoord)
array[ix, iz] =
interpolate_slab_level(field, h, weights, zcoord; node_mem)
end
end
return array
Expand All @@ -160,6 +177,11 @@ function interpolate_array(
array = zeros(T, length(xpts), length(ypts), length(zpts))

FT = Spaces.undertype(space)
QS = Spaces.quadrature_style(space)
Nq = Spaces.Quadratures.degrees_of_freedom(QS)

node_mem = zeros(FT, Nq)

for (iy, ycoord) in enumerate(ypts), (ix, xcoord) in enumerate(xpts)
hcoord = Geometry.product_coordinates(xcoord, ycoord)
helem = Meshes.containing_element(horz_mesh, hcoord)
Expand All @@ -171,7 +193,7 @@ function interpolate_array(

for (iz, zcoord) in enumerate(zpts)
array[ix, iy, iz] =
interpolate_slab_level(field, h, weights, zcoord)
interpolate_slab_level(field, h, weights, zcoord; node_mem)
end
end
return array
Expand Down Expand Up @@ -243,6 +265,10 @@ function interpolate_column(
physical_z = false
end

# We need to call `length.(weights)...` because `weights` is a 1- or 2- tuple of
# `SMatrix`s and we want `node_mem` to have the same shape
node_mem = zeros(FT, length.(weights)...)

# If we physical_z, we have to move the z coordinates from physical to
# reference ones.
if physical_z
Expand All @@ -253,7 +279,8 @@ function interpolate_column(
z_surface = interpolate_slab(
space.hypsography.surface,
Fields.SlabIndex(nothing, gidx),
weights,
weights;
node_mem,
)
z_top = Spaces.vertical_topology(space).mesh.domain.coord_max.z
zpts_ref = [
Expand All @@ -265,7 +292,7 @@ function interpolate_column(
end

return [
z.z >= 0 ? interpolate_slab_level(field, gidx, weights, z) : FT(NaN) for
z in zpts_ref
z.z >= 0 ? interpolate_slab_level(field, gidx, weights, z; node_mem) :
FT(NaN) for z in zpts_ref
]
end

0 comments on commit c701b65

Please sign in to comment.