From a96619f3e0ad1c8dbe29c935986153335002bdf8 Mon Sep 17 00:00:00 2001 From: Julia Sloan Date: Fri, 20 Sep 2024 11:59:13 -0700 Subject: [PATCH] extend dss functions for FieldVectors --- NEWS.md | 3 +++ src/Fields/Fields.jl | 2 +- src/Fields/fieldvector.jl | 36 +++++++++++++++++++++++++++++++ test/Fields/field_opt.jl | 45 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 85 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 004b157664..bca8ccb1a3 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,9 @@ ClimaCore.jl Release Notes main ------- +- Extended `create_dss_buffer` and `weighted_dss!` for `FieldVector`s, rather than +just `Field`s. PR [#2000](https://github.com/CliMA/ClimaCore.jl/pull/2000). + v0.14.16 ------- diff --git a/src/Fields/Fields.jl b/src/Fields/Fields.jl index c8c0ab7baf..171daf699e 100644 --- a/src/Fields/Fields.jl +++ b/src/Fields/Fields.jl @@ -374,7 +374,7 @@ function interpcoord(elemrange, x::Real) end """ - Spaces.weighted_dss!(f::Field[, ghost_buffer = Spaces.create_dss_buffer(field)]) + Spaces.weighted_dss!(f::Field, dss_buffer = Spaces.create_dss_buffer(field)) Apply weighted direct stiffness summation (DSS) to `f`. This operates in-place (i.e. it modifies the `f`). `ghost_buffer` contains the necessary information diff --git a/src/Fields/fieldvector.jl b/src/Fields/fieldvector.jl index 158f317b80..b86962782a 100644 --- a/src/Fields/fieldvector.jl +++ b/src/Fields/fieldvector.jl @@ -1,4 +1,5 @@ import BlockArrays +import ClimaCore.Utilities.UnrolledFunctions: unrolled_map, unrolled_foreach """ @@ -184,6 +185,41 @@ end return dest end +""" + Spaces.create_dss_buffer(fv::FieldVector) + +Create a NamedTuple of buffers for communicating neighbour information of +each Field in `fv`. In this NamedTuple, the name of each field is mapped +to the buffer. +""" +function Spaces.create_dss_buffer(fv::FieldVector) + NamedTuple{propertynames(fv)}( + unrolled_map( + key -> Spaces.create_dss_buffer(getproperty(fv, key)), + propertynames(fv), + ), + ) +end + +""" + Spaces.weighted_dss!(fv::FieldVector, dss_buffer = Spaces.create_dss_buffer(fv)) + +Apply weighted direct stiffness summation (DSS) to each field in `fv`. +If a `dss_buffer` object is not provided, a buffer will be created for each +field in `fv`. +Note that using the `Pair` interface here parallelizes the `weighted_dss!` calls. +""" +function Spaces.weighted_dss!( + fv::FieldVector, + dss_buffer = Spaces.create_dss_buffer(fv), +) + pairs = map(propertynames(fv)) do pn + Pair(getproperty(fv, key), getproperty(dss_buffer, key)) + end + Spaces.weighted_dss!(pairs...) +end + + # Recursively call transform_bc_args() on broadcast arguments in a way that is statically reducible by the optimizer # see Base.Broadcast.preprocess_args @inline transform_bc_args(args::Tuple, inds...) = ( diff --git a/test/Fields/field_opt.jl b/test/Fields/field_opt.jl index 24955cdbb5..b57c4f3206 100644 --- a/test/Fields/field_opt.jl +++ b/test/Fields/field_opt.jl @@ -393,4 +393,49 @@ using JET @test_opt ifelsekernel!(S, ρ) end +@testset "dss of FieldVectors" begin + function field_vec(center_space, face_space) + Y = Fields.FieldVector( + c = map(Fields.coordinate_field(center_space)) do coord + FT = Spaces.undertype(center_space) + (; + ρ = FT(coord.lat + coord.long), + uₕ = Geometry.Covariant12Vector( + FT(coord.lat), + FT(coord.long), + ), + ) + end, + f = map(Fields.coordinate_field(face_space)) do coord + FT = Spaces.undertype(face_space) + (; w = Geometry.Covariant3Vector(FT(coord.lat + coord.long))) + end, + ) + return Y + end + + fv = field_vec(toy_sphere(Float64)...) + + c_copy = copy(getproperty(fv, :c)) + f_copy = copy(getproperty(fv, :f)) + + # Test that dss_buffer is created and has the correct keys + dss_buffer = Spaces.create_dss_buffer(fv) + @test haskey(dss_buffer, :c) + @test haskey(dss_buffer, :f) + + # Test weighted_dss! with and without preallocated buffer + Spaces.weighted_dss!(fv, dss_buffer) + @test getproperty(fv, :c) ≈ Spaces.weighted_dss!(c_copy) + @test getproperty(fv, :f) ≈ Spaces.weighted_dss!(f_copy) + + fv = field_vec(toy_sphere(Float64)...) + c_copy = copy(getproperty(fv, :c)) + f_copy = copy(getproperty(fv, :f)) + + Spaces.weighted_dss!(fv) + @test getproperty(fv, :c) ≈ Spaces.weighted_dss!(c_copy) + @test getproperty(fv, :f) ≈ Spaces.weighted_dss!(f_copy) +end + nothing