Skip to content

Commit

Permalink
extend dss functions for FieldVectors
Browse files Browse the repository at this point in the history
  • Loading branch information
juliasloan25 committed Sep 20, 2024
1 parent bd20629 commit 7edba4c
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ function interpcoord(elemrange, x::Real)
end

"""
Spaces.weighted_dss!(f::Field[, ghost_buffer = Spaces.create_dss_buffer(field)])
Spaces.weighted_dss!(f::Field, dss_buffer = Spaces.create_dss_buffer(field))
Apply weighted direct stiffness summation (DSS) to `f`. This operates in-place
(i.e. it modifies the `f`). `ghost_buffer` contains the necessary information
Expand Down
38 changes: 38 additions & 0 deletions src/Fields/fieldvector.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import BlockArrays
import ClimaCore.Utilities.UnrolledFunctions: unrolled_map, unrolled_foreach


"""
Expand Down Expand Up @@ -184,6 +185,43 @@ end
return dest
end

"""
Spaces.create_dss_buffer(fv::FieldVector)
Create a NamedTuple of buffers for communicating neighbour information of
each Field in `fv`. In this NamedTuple, the name of each field is mapped
to the buffer.
"""
function Spaces.create_dss_buffer(fv::FieldVector)
NamedTuple{propertynames(fv)}(
unrolled_map(
key -> Spaces.create_dss_buffer(getproperty(fv, key)),
propertynames(fv),
),
)
end

"""
Spaces.weighted_dss!(fv::FieldVector, dss_buffer = Spaces.create_dss_buffer(fv))
Apply weighted direct stiffness summation (DSS) to each field in `fv`.
Reuse the same `dss_buffer` for all fields in `fv`, which is constructed using
the first field in `fv` if it isn't passed explicitly.
"""
# TODO distribute `propertynames(fv)` over processes to parallelize `unrolled_foreach`
function Spaces.weighted_dss!(
fv::FieldVector,
dss_buffer = Spaces.create_dss_buffer(fv),
)
unrolled_foreach(
key -> Spaces.weighted_dss!(
getproperty(fv, key),
getproperty(dss_buffer, key),
),
propertynames(fv),
)
end

# Recursively call transform_bc_args() on broadcast arguments in a way that is statically reducible by the optimizer
# see Base.Broadcast.preprocess_args
@inline transform_bc_args(args::Tuple, inds...) = (
Expand Down
37 changes: 37 additions & 0 deletions test/Fields/field_opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -393,4 +393,41 @@ using JET
@test_opt ifelsekernel!(S, ρ)
end

@testset "dss of FieldVectors" begin
function field_vec(center_space, face_space)
Y = Fields.FieldVector(
c = map(Fields.coordinate_field(center_space)) do coord
FT = Spaces.undertype(center_space)
(; ρ = FT(0), uₕ = Geometry.Covariant12Vector(FT(0), FT(0)))
end,
f = map(Fields.coordinate_field(face_space)) do coord
FT = Spaces.undertype(face_space)
(; w = Geometry.Covariant3Vector(FT(0)))
end,
)
return Y
end

fv = field_vec(toy_sphere(Float64)...)

# Test that dss_buffer is created and has the correct keys and buffer types
dss_buffer = Spaces.create_dss_buffer(fv)
@test haskey(dss_buffer, :c)
@test haskey(dss_buffer, :f)
@test getproperty(dss_buffer, :c) isa Topologies.DSSBuffer
@test getproperty(dss_buffer, :f) isa Topologies.DSSBuffer

c_copy = copy(getproperty(fv, :c))
f_copy = copy(getproperty(fv, :f))

# Test weighted_dss! with and without preallocated buffer
p = @allocated Spaces.weighted_dss!(fv, dss_buffer) # DSS2
@test getproperty(fv, :c) c_copy
@test getproperty(fv, :f) f_copy

Spaces.weighted_dss!(fv)
@test getproperty(fv, :c) c_copy
@test getproperty(fv, :f) f_copy
end

nothing

0 comments on commit 7edba4c

Please sign in to comment.