diff --git a/src/Fields/Fields.jl b/src/Fields/Fields.jl index c8c0ab7baf..171daf699e 100644 --- a/src/Fields/Fields.jl +++ b/src/Fields/Fields.jl @@ -374,7 +374,7 @@ function interpcoord(elemrange, x::Real) end """ - Spaces.weighted_dss!(f::Field[, ghost_buffer = Spaces.create_dss_buffer(field)]) + Spaces.weighted_dss!(f::Field, dss_buffer = Spaces.create_dss_buffer(field)) Apply weighted direct stiffness summation (DSS) to `f`. This operates in-place (i.e. it modifies the `f`). `ghost_buffer` contains the necessary information diff --git a/src/Fields/fieldvector.jl b/src/Fields/fieldvector.jl index 158f317b80..b57dbb6aa6 100644 --- a/src/Fields/fieldvector.jl +++ b/src/Fields/fieldvector.jl @@ -1,4 +1,5 @@ import BlockArrays +import ClimaCore.Utilities.UnrolledFunctions: unrolled_map, unrolled_foreach """ @@ -184,6 +185,43 @@ end return dest end +""" + Spaces.create_dss_buffer(fv::FieldVector) + +Create a NamedTuple of buffers for communicating neighbour information of +each Field in `fv`. In this NamedTuple, the name of each field is mapped +to the buffer. +""" +function Spaces.create_dss_buffer(fv::FieldVector) + NamedTuple{propertynames(fv)}( + unrolled_map( + key -> Spaces.create_dss_buffer(getproperty(fv, key)), + propertynames(fv), + ), + ) +end + +""" + Spaces.weighted_dss!(fv::FieldVector, dss_buffer = Spaces.create_dss_buffer(fv)) + +Apply weighted direct stiffness summation (DSS) to each field in `fv`. +Reuse the same `dss_buffer` for all fields in `fv`, which is constructed using +the first field in `fv` if it isn't passed explicitly. +""" +# TODO distribute `propertynames(fv)` over processes to parallelize `unrolled_foreach` +function Spaces.weighted_dss!( + fv::FieldVector, + dss_buffer = Spaces.create_dss_buffer(fv), +) + unrolled_foreach( + key -> Spaces.weighted_dss!( + getproperty(fv, key), + getproperty(dss_buffer, key), + ), + propertynames(fv), + ) +end + # Recursively call transform_bc_args() on broadcast arguments in a way that is statically reducible by the optimizer # see Base.Broadcast.preprocess_args @inline transform_bc_args(args::Tuple, inds...) = ( diff --git a/test/Fields/field_opt.jl b/test/Fields/field_opt.jl index 24955cdbb5..0f7d3386bc 100644 --- a/test/Fields/field_opt.jl +++ b/test/Fields/field_opt.jl @@ -393,4 +393,41 @@ using JET @test_opt ifelsekernel!(S, ρ) end +@testset "dss of FieldVectors" begin + function field_vec(center_space, face_space) + Y = Fields.FieldVector( + c = map(Fields.coordinate_field(center_space)) do coord + FT = Spaces.undertype(center_space) + (; ρ = FT(0), uₕ = Geometry.Covariant12Vector(FT(0), FT(0))) + end, + f = map(Fields.coordinate_field(face_space)) do coord + FT = Spaces.undertype(face_space) + (; w = Geometry.Covariant3Vector(FT(0))) + end, + ) + return Y + end + + fv = field_vec(toy_sphere(Float64)...) + + # Test that dss_buffer is created and has the correct keys and buffer types + dss_buffer = Spaces.create_dss_buffer(fv) + @test haskey(dss_buffer, :c) + @test haskey(dss_buffer, :f) + @test getproperty(dss_buffer, :c) isa Topologies.DSSBuffer + @test getproperty(dss_buffer, :f) isa Topologies.DSSBuffer + + c_copy = copy(getproperty(fv, :c)) + f_copy = copy(getproperty(fv, :f)) + + # Test weighted_dss! with and without preallocated buffer + p = @allocated Spaces.weighted_dss!(fv, dss_buffer) # DSS2 + @test getproperty(fv, :c) ≈ c_copy + @test getproperty(fv, :f) ≈ f_copy + + Spaces.weighted_dss!(fv) + @test getproperty(fv, :c) ≈ c_copy + @test getproperty(fv, :f) ≈ f_copy +end + nothing