Skip to content

Commit

Permalink
Switch to using Unrolled.jl and fix other FieldName compilation issues
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Dec 6, 2023
1 parent 3ac7026 commit 62fb219
Show file tree
Hide file tree
Showing 21 changed files with 472 additions and 444 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
WeakValueDicts = "897b6980-f191-5a31-bcb0-bf3c4585e0c1"

[weakdeps]
Expand Down Expand Up @@ -62,5 +63,6 @@ SparseArrays = "1"
Static = "0.4, 0.5, 0.6, 0.7, 0.8"
StaticArrays = "1"
Statistics = "1"
Unrolled = "0.1"
WeakValueDicts = "0.1"
julia = "1.9"
8 changes: 7 additions & 1 deletion benchmarks/bickleyjet/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.5.6"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "WeakValueDicts"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "Unrolled", "WeakValueDicts"]
path = "../.."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.11.1"
Expand Down Expand Up @@ -1265,6 +1265,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd"
uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728"
version = "1.6.3"

[[deps.Unrolled]]
deps = ["MacroTools"]
git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b"
uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
version = "0.1.5"

[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
Expand Down
8 changes: 7 additions & 1 deletion docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.5.6"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "WeakValueDicts"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "Unrolled", "WeakValueDicts"]
path = ".."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.11.2"
Expand Down Expand Up @@ -2577,6 +2577,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd"
uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728"
version = "1.6.3"

[[deps.Unrolled]]
deps = ["MacroTools"]
git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b"
uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
version = "0.1.5"

[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
Expand Down
8 changes: 8 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
CurrentModule = ClimaCore
```

## Utilities

```@docs
Utilities.PlusHalf
Utilities.half
Utilities.UnrolledFunctions
```

## DataLayouts

```@docs
Expand Down
1 change: 1 addition & 0 deletions docs/src/matrix_fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ FieldNameSet
FieldNameDict
field_vector_view
concrete_field_vector
is_lazy
lazy_main_diagonal
lazy_mul
LazySchurComplement
Expand Down
8 changes: 7 additions & 1 deletion examples/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.5.6"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "WeakValueDicts"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "Unrolled", "WeakValueDicts"]
path = ".."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.11.2"
Expand Down Expand Up @@ -2128,6 +2128,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd"
uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728"
version = "1.6.3"

[[deps.Unrolled]]
deps = ["MacroTools"]
git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b"
uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
version = "0.1.5"

[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
Expand Down
8 changes: 7 additions & 1 deletion perf/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.5.6"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "WeakValueDicts"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "Unrolled", "WeakValueDicts"]
path = ".."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.11.2"
Expand Down Expand Up @@ -2194,6 +2194,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd"
uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728"
version = "1.6.3"

[[deps.Unrolled]]
deps = ["MacroTools"]
git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b"
uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
version = "0.1.5"

[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
Expand Down
2 changes: 1 addition & 1 deletion src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import ..Topologies
import ..Grids: ColumnIndex
import ..Spaces: Spaces, AbstractSpace, AbstractPointSpace
import ..Geometry: Geometry, Cartesian12Vector
import ..Utilities: PlusHalf, half
import ..Utilities: PlusHalf, half, UnrolledFunctions

using ..RecursiveApply
using CUDA
Expand Down
22 changes: 6 additions & 16 deletions src/Fields/fieldvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,27 +260,17 @@ LinearAlgebra.ldiv!(A::LinearAlgebra.LU, x::FieldVector) =
x .= LinearAlgebra.ldiv!(A, Vector(x))

function LinearAlgebra.norm_sqr(x::FieldVector)
Base.sum(value -> LinearAlgebra.norm_sqr(backing_array(value)), _values(x))
value_norm_sqrs = UnrolledFunctions.unrolled_map(_values(x)) do value
LinearAlgebra.norm_sqr(backing_array(value))
end
return sum(value_norm_sqrs; init = zero(eltype(x)))
end
function LinearAlgebra.norm(x::FieldVector)
sqrt(LinearAlgebra.norm_sqr(x))
end

import ClimaComms

ClimaComms.array_type(x::FieldVector) = _array_type(x)

@inline _array_type(x::FieldVector) = _array_type(x, propertynames(x))
@inline _array_type(x::FieldVector, pns::Tuple{}) = Any

@inline _array_type(x::Field) = ClimaComms.array_type(x)
@inline _array_type(x::FieldVector, sym::Symbol) =
_array_type(getproperty(x, sym))

@inline _array_type(x::FieldVector, pns::Tuple{Symbol}) =
_array_type(getproperty(x, first(pns)))

@inline _array_type(x::FieldVector, pns::Tuple) = promote_type(
_array_type(getproperty(x, first(pns))),
_array_type(x, Base.tail(pns)),
ClimaComms.array_type(x::FieldVector) = promote_type(
UnrolledFunctions.unrolled_map(ClimaComms.array_type, _values(x))...,
)
6 changes: 3 additions & 3 deletions src/MatrixFields/MatrixFields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,14 @@ import ..Spaces
import ..Fields
import ..Operators

using ..Utilities.UnrolledFunctions

export DiagonalMatrixRow,
BidiagonalMatrixRow,
TridiagonalMatrixRow,
QuaddiagonalMatrixRow,
PentadiagonalMatrixRow
export FieldVectorKeys, FieldVectorView, FieldVectorViewBroadcasted
export FieldMatrixKeys, FieldMatrix, FieldMatrixBroadcasted
export FieldVectorKeys, FieldMatrixKeys, FieldVectorView, FieldMatrix
export , FieldMatrixSolver, field_matrix_solve!

# Types that are teated as single values when using matrix fields.
Expand All @@ -94,7 +95,6 @@ include("matrix_multiplication.jl")
include("lazy_operators.jl")
include("operator_matrices.jl")
include("field2arrays.jl")
include("unrolled_functions.jl")
include("field_name.jl")
include("field_name_set.jl")
include("field_name_dict.jl")
Expand Down
29 changes: 12 additions & 17 deletions src/MatrixFields/field_matrix_iterative_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,9 @@ is_diagonal(P_alg) = isnothing(solver_algorithm(P_alg))
"""
lazy_preconditioner(P_alg, A)
Constructs an un-materialized `FieldMatrixBroadcasted` (or just a `FieldMatrix`
when possible) that approximates `A` according to the `PreconditionerAlgorithm`
`P_alg`. If `P_alg` is `nothing` instead of a `PreconditionerAlgorithm`, this
returns `one(A)`.
Constructs a lazy `FieldMatrix` (or a concrete one when possible) that
approximates `A` according to the `PreconditionerAlgorithm` `P_alg`. If `P_alg`
is `nothing` instead of a `PreconditionerAlgorithm`, this returns `one(A)`.
"""
lazy_preconditioner(::Nothing, A::FieldMatrix) = one(A)

Expand All @@ -48,9 +47,8 @@ preconditioner generated by the `PreconditionerAlgorithm` `P_alg` for `A`.
function preconditioner_cache(P_alg, A, b)
is_diagonal(P_alg) && return (;)
lazy_P = lazy_preconditioner(P_alg, A)
is_lazy_P_concrete = !(lazy_P isa FieldMatrixBroadcasted)
P = is_lazy_P_concrete ? lazy_P : Base.Broadcast.materialize(lazy_P)
P_if_needed = is_lazy_P_concrete ? (;) : (; P)
P = is_lazy(lazy_P) ? Base.Broadcast.materialize(lazy_P) : lazy_P
P_if_needed = is_lazy(lazy_P) ? (; P) : (;)
x = similar_to_x(P, b)
alg = solver_algorithm(P_alg)
cache = field_matrix_solver_cache(alg, P, b)
Expand Down Expand Up @@ -79,27 +77,24 @@ end
"""
lazy_or_concrete_preconditioner(P_alg, P_cache, A)
A wrapper for [`lazy_preconditioner`](@ref) that turns the un-materialized
`FieldMatrixBroadcasted` `P` into a concrete `FieldMatrix` when the
`PreconditionerAlgorithm` `P_alg` requires a `FieldMatrixSolverAlgorithm` to
invert it.
A wrapper for [`lazy_preconditioner`](@ref) that turns the lazy `FieldMatrix`
`P` into a concrete `FieldMatrix` when the `PreconditionerAlgorithm` `P_alg`
requires a `FieldMatrixSolverAlgorithm` to invert it.
"""
function lazy_or_concrete_preconditioner(P_alg, P_cache, A)
isnothing(P_alg) && return nothing
lazy_P = lazy_preconditioner(P_alg, A)
is_lazy_P_concrete = !(lazy_P isa FieldMatrixBroadcasted)
(is_diagonal(P_alg) || is_lazy_P_concrete) && return lazy_P
(is_diagonal(P_alg) || !is_lazy(lazy_P)) && return lazy_P
@. P_cache.P = lazy_P
return P_cache.P
end

"""
apply_preconditioner(P_alg, P_cache, P, lazy_b)
Constructs an un-materialized `FieldMatrixBroadcasted` (or just a `FieldMatrix`
when possible) that represents the product `@. inv(P) * b`. Here, `lazy_b` is an
un-materialized `FieldVectorViewBroadcasted` (or a `FieldVectorView`) that
represents `b`.
Constructs a lazy `FieldMatrix` (or a concrete one when possible) that
represents the product `@. inv(P) * b`. Here, `lazy_b` is a (possibly lazy)
`FieldVectorView` that represents `b`.
"""
function apply_preconditioner(P_alg, P_cache, P, lazy_b)
isnothing(P_alg) && return lazy_b
Expand Down
15 changes: 7 additions & 8 deletions src/MatrixFields/field_matrix_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ function FieldMatrixSolver(
b::Fields.FieldVector,
)
b_view = field_vector_view(b)
A_with_tree = FieldMatrix(pairs(A)...; keys(b_view).name_tree)
A_with_tree = replace_name_tree(A, keys(b_view).name_tree)
cache = field_matrix_solver_cache(alg, A_with_tree, b_view)
check_field_matrix_solver(alg, cache, A_with_tree, b_view)
return FieldMatrixSolver(alg, cache)
Expand All @@ -84,7 +84,7 @@ function field_matrix_solve!(
"The linear system cannot be solved because x and b have incompatible \
keys: $(set_string(keys(x_view))) vs. $(set_string(keys(b_view)))",
)
A_with_tree = FieldMatrix(pairs(A)...; keys(b_view).name_tree)
A_with_tree = replace_name_tree(A, keys(b_view).name_tree)
check_field_matrix_solver(alg, cache, A_with_tree, b_view)
run_field_matrix_solver!(alg, cache, x_view, A_with_tree, b_view)
return x
Expand Down Expand Up @@ -120,7 +120,7 @@ function similar_to_x(A, b)
entries = map(matrix_row_keys(keys(A))) do name
similar(b[name], x_eltype(A[name, name], b[name]))
end
return FieldVectorView(matrix_row_keys(keys(A)), entries)
return FieldNameDict(matrix_row_keys(keys(A)), entries)
end

################################################################################
Expand All @@ -135,10 +135,9 @@ lazy_sub(As...) = Base.Broadcast.broadcasted(-, As...)
"""
lazy_mul(A, args...)
Constructs an un-materialized `FieldMatrixBroadcasted` that represents the
product `@. *(A, args...)`. This involves regular broadcasting when `A` is a
`FieldMatrix` or `FieldMatrixBroadcasted`, but it has more complex behavior for
other objects like the [`LazySchurComplement`](@ref).
Constructs a lazy `FieldMatrix` that represents the product `@. *(A, args...)`.
This involves regular broadcasting when `A` is a `FieldMatrix`, but it has more
complex behavior for other objects like the [`LazySchurComplement`](@ref).
"""
lazy_mul(A, args...) = Base.Broadcast.broadcasted(*, A, args...)

Expand Down Expand Up @@ -234,7 +233,7 @@ function field_matrix_solver_cache(::BlockDiagonalSolve, A, b)
caches = map(matrix_row_keys(keys(A))) do name
single_field_solver_cache(A[name, name], b[name])
end
return FieldNameDict{FieldName}(matrix_row_keys(keys(A)), caches)
return FieldNameDict(matrix_row_keys(keys(A)), caches)
end

function check_field_matrix_solver(::BlockDiagonalSolve, _, A, b)
Expand Down
51 changes: 25 additions & 26 deletions src/MatrixFields/field_name.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,20 @@ is_child_name(
::FieldName{parent_name_chain},
) where {child_name_chain, parent_name_chain} =
length(child_name_chain) >= length(parent_name_chain) &&
child_name_chain[1:length(parent_name_chain)] == parent_name_chain
unrolled_take(child_name_chain, Val(length(parent_name_chain))) ==
parent_name_chain

names_are_overlapping(name1, name2) =
is_overlapping_name(name1, name2) =
is_child_name(name1, name2) || is_child_name(name2, name1)

extract_internal_name(
child_name::FieldName{child_name_chain},
parent_name::FieldName{parent_name_chain},
) where {child_name_chain, parent_name_chain} =
is_child_name(child_name, parent_name) ?
FieldName(child_name_chain[(length(parent_name_chain) + 1):end]...) :
error("$child_name is not a child name of $parent_name")
FieldName(
unrolled_drop(child_name_chain, Val(length(parent_name_chain)))...,
) : error("$child_name is not a child name of $parent_name")

append_internal_name(
::FieldName{name_chain},
Expand Down Expand Up @@ -118,41 +120,38 @@ struct FieldNameTreeNode{V <: FieldName, S <: NTuple{<:Any, FieldNameTree}} <:
subtrees::S
end

FieldNameTree(x) = make_subtree_at_name(x, @name())
function make_subtree_at_name(x, name)
FieldNameTree(x) = subtree_at_name(x, @name())
function subtree_at_name(x, name)
internal_names = top_level_names(get_field(x, name))
isempty(internal_names) && return FieldNameTreeLeaf(name)
subsubtrees = unrolled_map(internal_names) do internal_name
make_subtree_at_name(x, append_internal_name(name, internal_name))
return if isempty(internal_names)
FieldNameTreeLeaf(name)
else
subsubtrees_at_name = unrolled_map(internal_names) do internal_name
subtree_at_name(x, append_internal_name(name, internal_name))
end
FieldNameTreeNode(name, subsubtrees_at_name)
end
return FieldNameTreeNode(name, subsubtrees)
end

is_valid_name(name, tree::FieldNameTreeLeaf) = name == tree.name
is_valid_name(name, tree::FieldNameTreeNode) =
is_valid_name(name, tree) =
name == tree.name ||
is_child_name(name, tree.name) &&
tree isa FieldNameTreeNode &&
unrolled_any(subtree -> is_valid_name(name, subtree), tree.subtrees)

function child_names(name, tree)
is_valid_name(name, tree) || error("$name is not a valid name")
subtree = get_subtree_at_name(name, tree)
subtree isa FieldNameTreeNode ||
error("FieldNameTree does not contain any child names for $name")
subtree isa FieldNameTreeNode || error("$name does not have child names")
return unrolled_map(subsubtree -> subsubtree.name, subtree.subtrees)
end
get_subtree_at_name(name, tree::FieldNameTreeLeaf) =
name == tree.name ? tree :
error("FieldNameTree does not contain the name $name")
get_subtree_at_name(name, tree::FieldNameTreeNode) =
get_subtree_at_name(name, tree) =
if name == tree.name
tree
elseif is_valid_name(name, tree)
subtree_that_contains_name = unrolled_findonly(tree.subtrees) do subtree
is_child_name(name, subtree.name)
end
get_subtree_at_name(name, subtree_that_contains_name)
else
error("FieldNameTree does not contain the name $name")
subtree = unrolled_findonly(tree.subtrees) do subtree
is_valid_name(name, subtree)
end
get_subtree_at_name(name, subtree)
end

################################################################################
Expand All @@ -175,7 +174,7 @@ if hasfield(Method, :recursion_relation)
for m in methods(wrapped_prop_names)
m.recursion_relation = dont_limit
end
for m in methods(make_subtree_at_name)
for m in methods(subtree_at_name)
m.recursion_relation = dont_limit
end
for m in methods(is_valid_name)
Expand Down
Loading

0 comments on commit 62fb219

Please sign in to comment.