From b7d39ea4859df00c31b76e47f514d3c9d3b30923 Mon Sep 17 00:00:00 2001 From: Dennis Yatunin Date: Wed, 16 Aug 2023 17:39:38 -0700 Subject: [PATCH] Add FieldMatrix and linear solvers --- .buildkite/pipeline.yml | 19 +- docs/src/matrix_fields.md | 18 + src/MatrixFields/MatrixFields.jl | 36 +- src/MatrixFields/band_matrix_row.jl | 6 + src/MatrixFields/field2arrays.jl | 4 +- src/MatrixFields/field_matrix_solver.jl | 308 ++++++++++ src/MatrixFields/field_name.jl | 187 ++++++ src/MatrixFields/field_name_dict.jl | 471 +++++++++++++++ src/MatrixFields/field_name_set.jl | 426 +++++++++++++ src/MatrixFields/operator_matrices.jl | 2 +- src/MatrixFields/single_field_solver.jl | 211 +++++++ src/MatrixFields/unrolled_functions.jl | 91 +++ src/RecursiveApply/RecursiveApply.jl | 3 - test/MatrixFields/field_matrix_solvers.jl | 381 ++++++++++++ test/MatrixFields/field_names.jl | 591 +++++++++++++++++++ test/MatrixFields/matrix_field_test_utils.jl | 23 +- test/runtests.jl | 2 + 17 files changed, 2750 insertions(+), 29 deletions(-) create mode 100644 src/MatrixFields/field_matrix_solver.jl create mode 100644 src/MatrixFields/field_name.jl create mode 100644 src/MatrixFields/field_name_dict.jl create mode 100644 src/MatrixFields/field_name_set.jl create mode 100644 src/MatrixFields/single_field_solver.jl create mode 100644 src/MatrixFields/unrolled_functions.jl create mode 100644 test/MatrixFields/field_matrix_solvers.jl create mode 100644 test/MatrixFields/field_names.jl diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index dbb5b4865b..1cd158f2c2 100755 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -556,7 +556,7 @@ steps: soft_fail: true agents: slurm_gpus: 1 - slurm_mem: 40GB + slurm_mem: 80GB - label: "Unit: operator matrices (CPU)" key: unit_operator_matrices_cpu @@ -570,6 +570,23 @@ steps: slurm_gpus: 1 slurm_mem: 40GB + - label: "Unit: field names" + key: unit_field_names + command: "julia --color=yes --check-bounds=yes --project=test test/MatrixFields/field_names.jl" + + - label: "Unit: field matrix solvers (CPU)" + key: unit_field_matrix_solvers_cpu + command: "julia --color=yes --check-bounds=yes --project=test test/MatrixFields/field_matrix_solvers.jl" + agents: + slurm_mem: 40GB + + - label: "Unit: field matrix solvers (GPU)" + key: unit_field_matrix_solvers_gpu + command: "julia --color=yes --project=test test/MatrixFields/field_matrix_solvers.jl" + agents: + slurm_gpus: 1 + slurm_mem: 40GB + - group: "Unit: Hypsography" steps: diff --git a/docs/src/matrix_fields.md b/docs/src/matrix_fields.md index ec912341c1..3a1fa51b78 100644 --- a/docs/src/matrix_fields.md +++ b/docs/src/matrix_fields.md @@ -26,6 +26,18 @@ MultiplyColumnwiseBandMatrixField operator_matrix ``` +# Linear Solvers + +```@docs +FieldMatrixSolverAlgorithm +FieldMatrixSolver +field_matrix_solve! +BlockDiagonalSolve +BlockLowerTriangularSolve +SchurComplementSolve +ApproximateFactorizationSolve +``` + ## Internals ```@docs @@ -39,6 +51,12 @@ matrix_shape column_axes AbstractLazyOperator replace_lazy_operator +FieldName +@name +FieldNameTree +FieldNameSet +FieldNameDict +field_vector_view ``` ## Utilities diff --git a/src/MatrixFields/MatrixFields.jl b/src/MatrixFields/MatrixFields.jl index 11e75e8339..a29920aaee 100644 --- a/src/MatrixFields/MatrixFields.jl +++ b/src/MatrixFields/MatrixFields.jl @@ -24,41 +24,41 @@ for them: `BandedMatrices.jl` - Custom printing, e.g., `matrix_field` gets displayed as a `BandedMatrix`, specifically, as the `BandedMatrix` that corresponds to its first column + +TODO: Add comments on `operator_matrix`, `FieldMatrix`, and `FieldMatrixSolver`. """ module MatrixFields -import CUDA: @allowscalar -import LinearAlgebra: UniformScaling, Adjoint, AdjointAbsVec +import CUDA +import LinearAlgebra: I, UniformScaling, Adjoint, AdjointAbsVec, mul!, inv import StaticArrays: SMatrix, SVector import BandedMatrices: BandedMatrix, band, _BandedMatrix +import ClimaComms import ..Utilities: PlusHalf, half import ..RecursiveApply: rmap, rmaptype, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv +import ..RecursiveApply: ⊠, ⊞, ⊟ import ..DataLayouts: AbstractData import ..Geometry import ..Spaces import ..Fields import ..Operators -export ⋅ export DiagonalMatrixRow, BidiagonalMatrixRow, TridiagonalMatrixRow, QuaddiagonalMatrixRow, PentadiagonalMatrixRow +export FieldVectorKeys, FieldVectorView, FieldVectorViewBroadcasted +export FieldMatrixKeys, FieldMatrix, FieldMatrixBroadcasted +export @name, ⋅, FieldMatrixSolver, field_matrix_solve! # Types that are teated as single values when using matrix fields. const SingleValue = Union{Number, Geometry.AxisTensor, Geometry.AdjointAxisTensor} include("band_matrix_row.jl") -include("rmul_with_projection.jl") -include("matrix_shape.jl") -include("matrix_multiplication.jl") -include("lazy_operators.jl") -include("operator_matrices.jl") -include("field2arrays.jl") const ColumnwiseBandMatrixField{V, S} = Fields.Field{ V, @@ -71,6 +71,19 @@ const ColumnwiseBandMatrixField{V, S} = Fields.Field{ }, } +include("rmul_with_projection.jl") +include("matrix_shape.jl") +include("matrix_multiplication.jl") +include("lazy_operators.jl") +include("operator_matrices.jl") +include("field2arrays.jl") +include("unrolled_functions.jl") +include("field_name.jl") +include("field_name_set.jl") +include("field_name_dict.jl") +include("field_matrix_solver.jl") +include("single_field_solver.jl") + function Base.show(io::IO, field::ColumnwiseBandMatrixField) print(io, eltype(field), "-valued Field") if eltype(eltype(field)) <: Number @@ -82,7 +95,10 @@ function Base.show(io::IO, field::ColumnwiseBandMatrixField) end column_field = Fields.column(field, 1, 1, 1) io = IOContext(io, :compact => true, :limit => true) - @allowscalar Base.print_array(io, column_field2array_view(column_field)) + CUDA.@allowscalar Base.print_array( + io, + column_field2array_view(column_field), + ) else # When a BandedMatrix with non-number entries is printed, it currently # either prints in an illegible format (e.g., if it has AxisTensor or diff --git a/src/MatrixFields/band_matrix_row.jl b/src/MatrixFields/band_matrix_row.jl index 807bea4001..8af47890c2 100644 --- a/src/MatrixFields/band_matrix_row.jl +++ b/src/MatrixFields/band_matrix_row.jl @@ -135,3 +135,9 @@ Base.:*(value::SingleValue, row::BandMatrixRow) = Base.:/(row::BandMatrixRow, value::Number) = map(entry -> rdiv(entry, value), row) + +inv(row::DiagonalMatrixRow) = DiagonalMatrixRow(inv(row[0])) +inv(::BandMatrixRow{ld, bw}) where {ld, bw} = error( + "The inverse of a matrix with $bw diagonals is (usually) a dense matrix, \ + so it cannot be represented using BandMatrixRows", +) diff --git a/src/MatrixFields/field2arrays.jl b/src/MatrixFields/field2arrays.jl index 4108d24c75..4a610ad65b 100644 --- a/src/MatrixFields/field2arrays.jl +++ b/src/MatrixFields/field2arrays.jl @@ -53,11 +53,11 @@ function column_field2array(field::Fields.FiniteDifferenceField) last_row = matrix_d < n_cols - n_rows ? n_rows : n_cols - matrix_d diagonal_data_view = view(diagonal_data, first_row:last_row) - @allowscalar copyto!(matrix_diagonal, diagonal_data_view) + CUDA.@allowscalar copyto!(matrix_diagonal, diagonal_data_view) end return matrix else # field represents a vector - return @allowscalar Array(column_field2array_view(field)) + return CUDA.@allowscalar Array(column_field2array_view(field)) end end diff --git a/src/MatrixFields/field_matrix_solver.jl b/src/MatrixFields/field_matrix_solver.jl new file mode 100644 index 0000000000..dbe2675e9d --- /dev/null +++ b/src/MatrixFields/field_matrix_solver.jl @@ -0,0 +1,308 @@ +""" + FieldMatrixSolverAlgorithm + +Description of how to solve an equation of the form `A * x = b` for `x`, where +`A` is a `FieldMatrix` and where `x` and `b` are both `FieldVector`s. Different +algorithms can be nested inside each other, enabling the construction of +specialized linear solvers that fully utilize the sparsity pattern of `A`. +""" +abstract type FieldMatrixSolverAlgorithm end + +""" + FieldMatrixSolver(alg, A, b) + +Combination of a `FieldMatrixSolverAlgorithm` and the cache that it requires to +solve the equation `A * x = b` for `x`. The values of `A` and `b` that get +passed to this constructor should be `similar` to the ones that get passed to +`field_matrix_solve!` in order to ensure that the cache gets allocated +correctly. +""" +struct FieldMatrixSolver{A <: FieldMatrixSolverAlgorithm, C} + alg::A + cache::C +end +function FieldMatrixSolver( + alg::FieldMatrixSolverAlgorithm, + A::FieldMatrix, + b::Fields.FieldVector, +) + b_view = field_vector_view(b) + cache = field_matrix_solver_cache(alg, A, b_view) + check_field_matrix_solver(alg, cache, A, b_view) + return FieldMatrixSolver(alg, cache) +end + +""" + field_matrix_solve!(solver, x, A, b) + +Solves the equation `A * x = b` for `x` using the given `FieldMatrixSolver`. +""" +function field_matrix_solve!( + solver::FieldMatrixSolver, + x::Fields.FieldVector, + A::FieldMatrix, + b::Fields.FieldVector, +) + x_view = field_vector_view(x) + b_view = field_vector_view(b) + keys(x_view) == keys(b_view) || error( + "The linear system cannot be solved because x and b have incompatible \ + keys: $(set_string(keys(x_view))) vs. $(set_string(keys(b_view)))", + ) + check_field_matrix_solver(solver.alg, solver.cache, A, b_view) + field_matrix_solve!(solver.alg, solver.cache, x_view, A, b_view) + return x +end + +function check_block_diagonal_matrix_has_no_missing_blocks(A, b) + rows_with_missing_blocks = + setdiff(keys(b), matrix_row_keys(matrix_diagonal_keys(keys(A)))) + missing_keys = corresponding_matrix_keys(rows_with_missing_blocks) + # The missing keys correspond to zeros, and det(A) = 0 when A is a block + # diagonal matrix with zeros along its diagonal. We can only solve A * x = b + # if det(A) != 0, so we throw an error whenever there are missing keys. + # Although it might still be the case that det(A) = 0 even if there are no + # missing keys, this cannot be inferred during compilation. + isempty(missing_keys) || + error("The linear system cannot be solved because A does not have any \ + entries at the following keys: $(set_string(missing_keys))") +end + +function partition_blocks(names₁, A, b, x = nothing) + keys₁ = FieldVectorKeys(names₁, keys(b).name_tree) + keys₂ = set_complement(keys₁) + A₁₁ = A[cartesian_product(keys₁, keys₁)] + A₁₂ = A[cartesian_product(keys₁, keys₂)] + A₂₁ = A[cartesian_product(keys₂, keys₁)] + A₂₂ = A[cartesian_product(keys₂, keys₂)] + return isnothing(x) ? (A₁₁, A₁₂, A₂₁, A₂₂, b[keys₁], b[keys₂]) : + (A₁₁, A₁₂, A₂₁, A₂₂, b[keys₁], b[keys₂], x[keys₁], x[keys₂]) +end + +################################################################################ + +""" + BlockDiagonalSolve() + +A `FieldMatrixSolverAlgorithm` for a block diagonal matrix `A`, which solves +each block's equation `Aᵢᵢ * xᵢ = bᵢ` in sequence. The equation for `xᵢ` is +solved as follows: +- If `Aᵢᵢ = λᵢ * I`, the equation is solved by setting `xᵢ .= inv(λᵢ) .* bᵢ`. +- If `Aᵢᵢ = Dᵢ`, where `Dᵢ` is a diagonal matrix, the equation is solved by + making a single pass over the data, setting each `xᵢ[n] = inv(Dᵢ[n]) * bᵢ[n]`. +- If `Aᵢᵢ = Lᵢ * Dᵢ * Uᵢ`, where `Dᵢ` is a diagonal matrix and where `Lᵢ` and + `Uᵢ` are unit lower and upper triangular matrices, respectively, the equation + is solved using Gauss-Jordan elimination, which makes two passes over the + data. The first pass multiplies both sides of the equation by `inv(Lᵢ * Dᵢ)`, + replacing `Aᵢᵢ` with `Uᵢ` and `bᵢ` with `Uᵢxᵢ`, which is also referred to as + putting `Aᵢᵢ` into "reduced row echelon form". The second pass solves + `Uᵢ * xᵢ = Uᵢxᵢ` for `xᵢ` using a unit upper triangular matrix solver, which + is also referred to as "back substitution". Only tri-diagonal and + penta-diagonal matrices `Aᵢᵢ` are currently supported. +- The general case of `Aᵢᵢ = inv(Pᵢ) * Lᵢ * Uᵢ`, where `Pᵢ` is a row permutation + matrix (i.e., LU factorization with partial pivoting), is not currently + supported. +""" +struct BlockDiagonalSolve <: FieldMatrixSolverAlgorithm end + +function field_matrix_solver_cache(::BlockDiagonalSolve, A, b) + caches = map(matrix_row_keys(keys(A))) do name + single_field_solver_cache(A[(name, name)], b[name]) + end + return FieldNameDict{FieldName}(matrix_row_keys(keys(A)), caches) +end + +function check_field_matrix_solver(::BlockDiagonalSolve, _, A, b) + check_block_diagonal_matrix( + A, + "BlockDiagonalSolve cannot be used because A", + ) + check_block_diagonal_matrix_has_no_missing_blocks(A, b) + foreach(matrix_row_keys(keys(A))) do name + check_single_field_solver(A[(name, name)], b[name]) + end +end + +field_matrix_solve!(::BlockDiagonalSolve, cache, x, A, b) = + foreach(matrix_row_keys(keys(A))) do name + single_field_solve!(cache[name], x[name], A[(name, name)], b[name]) + end + +""" + BlockLowerTriangularSolve(names₁...; [alg₁], [alg₂]) + +A `FieldMatrixSolverAlgorithm` for a block lower triangular matrix `A`, which +solves for `x` by executing the following steps: +1. Partition the entries in `A`, `x`, and `b` into the blocks `A₁₁`, `A₁₂`, + `A₂₁`, `A₂₂`, `x₁`, `x₂`, `b₁`, and `b₂`, based on the `FieldName`s in + `names₁`. In this notation, the subscript `₁` corresponds to `FieldName`s + that are covered by `names₁`, while the subscript `₂` corresponds to all + other `FieldNames`. A subscript in the first position refers to `FieldName`s + that are used as row indices, while a subscript in the second position refers + to column indices. This algorithm requires that the upper triangular block + `A₁₂` be empty. (Any upper triangular solve can also be expressed as a lower + triangular solve by swapping the subscripts `₁` and `₂`.) +2. Solve `A₁₁ * x₁ = b₁` for `x₁` using the algorithm `alg₁`, which is set to + `BlockDiagonalSolve()` by default. +3. Solve `A₂₂ * x₂ = b₂ - A₂₁ * x₁` for `x₂` using the algorithm `alg₂`, which + is set to `BlockDiagonalSolve()` by default. +""" +struct BlockLowerTriangularSolve{ + V <: NTuple{<:Any, FieldName}, + A1 <: FieldMatrixSolverAlgorithm, + A2 <: FieldMatrixSolverAlgorithm, +} <: FieldMatrixSolverAlgorithm + names₁::V + alg₁::A1 + alg₂::A2 +end +BlockLowerTriangularSolve( + names₁::FieldName...; + alg₁ = BlockDiagonalSolve(), + alg₂ = BlockDiagonalSolve(), +) = BlockLowerTriangularSolve(names₁, alg₁, alg₂) + +function field_matrix_solver_cache(alg::BlockLowerTriangularSolve, A, b) + A₁₁, _, A₂₁, A₂₂, b₁, b₂ = partition_blocks(alg.names₁, A, b) + cache₁ = field_matrix_solver_cache(alg.alg₁, A₁₁, b₁) + b₂′ = similar(b₂) + cache₂ = field_matrix_solver_cache(alg.alg₂, A₂₂, b₂′) + return (; cache₁, b₂′, cache₂) +end + +function check_field_matrix_solver(alg::BlockLowerTriangularSolve, cache, A, b) + A₁₁, A₁₂, _, A₂₂, b₁, _ = partition_blocks(alg.names₁, A, b) + isempty(keys(A₁₂)) || error( + "BlockLowerTriangularSolve cannot be used because A has entries at the \ + following upper triangular keys: $(set_string(keys(A₁₂)))", + ) + check_field_matrix_solver(alg.alg₁, cache.cache₁, A₁₁, b₁) + check_field_matrix_solver(alg.alg₂, cache.cache₂, A₂₂, cache.b₂′) +end + +function field_matrix_solve!(alg::BlockLowerTriangularSolve, cache, x, A, b) + A₁₁, _, A₂₁, A₂₂, b₁, b₂, x₁, x₂ = partition_blocks(alg.names₁, A, b, x) + field_matrix_solve!(alg.alg₁, cache.cache₁, x₁, A₁₁, b₁) + @. cache.b₂′ = b₂ - A₂₁ * x₁ + field_matrix_solve!(alg.alg₂, cache.cache₂, x₂, A₂₂, cache.b₂′) +end + +""" + SchurComplementSolve(names₁...; [alg₁]) + +A `FieldMatrixSolverAlgorithm` for a block matrix `A`, which solves for `x` by +executing the following steps: +1. Partition the entries in `A`, `x`, and `b` into the blocks `A₁₁`, `A₁₂`, + `A₂₁`, `A₂₂`, `x₁`, `x₂`, `b₁`, and `b₂`, based on the `FieldName`s in + `names₁`. In this notation, the subscript `₁` corresponds to `FieldName`s + that are covered by `names₁`, while the subscript `₂` corresponds to all + other `FieldNames`. A subscript in the first position refers to `FieldName`s + that are used as row indices, while a subscript in the second position refers + to column indices. This algorithm requires that the block `A₂₂` be a diagonal + matrix, which allows it to assume that `inv(A₂₂)` can be computed on the fly. +2. Solve `(A₁₁ - A₁₂ * inv(A₂₂) * A₂₁) * x₁ = b₁ - A₁₂ * inv(A₂₂) * b₂` for `x₁` + using the algorithm `alg₁`, which is set to `BlockDiagonalSolve()` by + default. The matrix `A₁₁ - A₁₂ * inv(A₂₂) * A₂₁` is called the "Schur + complement" of `A₂₂` in `A`. +3. Set `x₂` to `inv(A₂₂) * (b₂ - A₂₁ * x₁)`. +""" +struct SchurComplementSolve{ + V <: NTuple{<:Any, FieldName}, + A <: FieldMatrixSolverAlgorithm, +} <: FieldMatrixSolverAlgorithm + names₁::V + alg₁::A +end +SchurComplementSolve(names₁::FieldName...; alg₁ = BlockDiagonalSolve()) = + SchurComplementSolve(names₁, alg₁) + +function field_matrix_solver_cache(alg::SchurComplementSolve, A, b) + A₁₁, A₁₂, A₂₁, A₂₂, b₁, b₂ = partition_blocks(alg.names₁, A, b) + A₁₁′ = @. A₁₁ - A₁₂ * inv(A₂₂) * A₂₁ # A₁₁′ could have more blocks than A₁₁ + b₁′ = similar(b₁) + cache₁ = field_matrix_solver_cache(alg.alg₁, A₁₁′, b₁′) + return (; A₁₁′, b₁′, cache₁) +end + +function check_field_matrix_solver(alg::SchurComplementSolve, cache, A, b) + _, _, _, A₂₂, _, b₂ = partition_blocks(alg.names₁, A, b) + check_diagonal_matrix(A₂₂, "SchurComplementSolve cannot be used because A") + check_block_diagonal_matrix_has_no_missing_blocks(A₂₂, b₂) + check_field_matrix_solver(alg.alg₁, cache.cache₁, cache.A₁₁′, cache.b₁′) +end + +function field_matrix_solve!(alg::SchurComplementSolve, cache, x, A, b) + A₁₁, A₁₂, A₂₁, A₂₂, b₁, b₂, x₁, x₂ = partition_blocks(alg.names₁, A, b, x) + @. cache.A₁₁′ = A₁₁ - A₁₂ * inv(A₂₂) * A₂₁ + @. cache.b₁′ = b₁ - A₁₂ * inv(A₂₂) * b₂ + field_matrix_solve!(alg.alg₁, cache.cache₁, x₁, cache.A₁₁′, cache.b₁′) + @. x₂ = inv(A₂₂) * (b₂ - A₂₁ * x₁) +end + +""" + ApproximateFactorizationSolve(name_pairs₁...; [alg₁], [alg₂]) + +A `FieldMatrixSolverAlgorithm` for a block matrix `A`, which (approximately) +solves for `x` by executing the following steps: +1. Use the entries in `A = M + I = M₁ + M₂ + I` to compute `A₁ = M₁ + I` and + `A₂ = M₂ + I`, based on the pairs of `FieldName`s in `name_pairs₁`. In this + notation, the subscript `₁` refers to pairs of `FieldName`s that are covered + by `name_pairs₁`, while the subscript `₂` refers to all other pairs of + `FieldNames`s. This algorithm approximates the matrix `A` as the product + `A₁ * A₂`, which introduces an error that scales roughly with the norm of + `A₁ * A₂ - A = M₁ * M₂`. (More precisely, the error introduced by this + algorithm is `x_exact - x_approx = inv(A) * b - inv(A₁ * A₂) * b`.) +2. Solve `A₁ * A₂x = b` for `A₂x` using the algorithm `alg₁`, which is set to + `BlockDiagonalSolve()` by default. +3. Solve `A₂ * x = A₂x` for `x` using the algorithm `alg₂`, which is set to + `BlockDiagonalSolve()` by default. +""" +struct ApproximateFactorizationSolve{ + V <: NTuple{<:Any, FieldNamePair}, + A1 <: FieldMatrixSolverAlgorithm, + A2 <: FieldMatrixSolverAlgorithm, +} <: FieldMatrixSolverAlgorithm + name_pairs₁::V + alg₁::A1 + alg₂::A2 +end +ApproximateFactorizationSolve( + name_pairs₁::FieldNamePair...; + alg₁ = BlockDiagonalSolve(), + alg₂ = BlockDiagonalSolve(), +) = ApproximateFactorizationSolve(name_pairs₁, alg₁, alg₂) +# Note: This algorithm assumes that x is `similar` to b. In other words, it +# assumes that typeof(x) == typeof(b), rather than just keys(x) == keys(b). + +function approximate_factors(name_pairs₁, A, b) + keys₁ = FieldMatrixKeys(name_pairs₁, keys(b).name_tree) + keys₂ = set_complement(keys₁) + A₁ = A[keys₁] .+ one(A)[keys₂] # `one` can be used because x is similar to b + A₂ = A[keys₂] .+ one(A)[keys₁] + return A₁, A₂ +end + +function field_matrix_solver_cache(alg::ApproximateFactorizationSolve, A, b) + A₁, A₂ = approximate_factors(alg.name_pairs₁, A, b) + cache₁ = field_matrix_solver_cache(alg.alg₁, A₁, b) + A₂x = @. A₂ * b # x can be replaced with b because they are similar + cache₂ = field_matrix_solver_cache(alg.alg₂, A₂, A₂x) + return (; cache₁, A₂x, cache₂) +end + +function check_field_matrix_solver( + alg::ApproximateFactorizationSolve, + cache, + A, + b, +) + A₁, A₂ = approximate_factors(alg.name_pairs₁, A, b) + check_field_matrix_solver(alg.alg₁, cache.cache₁, A₁, b) + check_field_matrix_solver(alg.alg₂, cache.cache₂, A₂, cache.A₂x) +end + +function field_matrix_solve!(alg::ApproximateFactorizationSolve, cache, x, A, b) + A₁, A₂ = approximate_factors(alg.name_pairs₁, A, b) + field_matrix_solve!(alg.alg₁, cache.cache₁, cache.A₂x, A₁, b) + field_matrix_solve!(alg.alg₂, cache.cache₂, x, A₂, cache.A₂x) +end diff --git a/src/MatrixFields/field_name.jl b/src/MatrixFields/field_name.jl new file mode 100644 index 0000000000..f027b1c1e1 --- /dev/null +++ b/src/MatrixFields/field_name.jl @@ -0,0 +1,187 @@ +""" + FieldName(name_chain...) + +A singleton type that represents a chain of `getproperty` calls, which can be +used to access a property or sub-property of an object `x` using the function +`get_field(x, name)`. The entire object `x` can also be accessed with the empty +`FieldName()`. +""" +struct FieldName{name_chain} end +FieldName() = FieldName{()}() # This is required for type stability. +FieldName(name_chain...) = FieldName{name_chain}() + +""" + @name(expr) + +Shorthand for constructing a `FieldName`. Some examples include +- `name = @name()`, in which case `get_field(x, name)` returns `x` +- `name = @name(a)`, in which case `get_field(x, name)` returns `x.a` +- `name = @name(a.b.c)`, in which case `get_field(x, name)` returns `x.a.b.c` +- `name = @name(a.b.c.:(1).d)`, in which case `get_field(x, name)` returns + `x.a.b.c.:(1).d` + +This macro is preferred over the `FieldName` constructor because it checks +whether `expr` is a syntactically valid chain of `getproperty` calls before +calling the constructor. +""" +macro name() + return :(FieldName()) +end +macro name(expr) + return :(FieldName($(expr_to_name_chain(expr))...)) +end + +expr_to_name_chain(value) = error("$(repr(value)) is not a valid property name") +expr_to_name_chain(value::Union{Symbol, Integer}) = (value,) +expr_to_name_chain(quote_node::QuoteNode) = expr_to_name_chain(quote_node.value) +function expr_to_name_chain(expr::Expr) + expr.head == :. || error("$expr is not a valid property name") + arg1, arg2 = expr.args + return (expr_to_name_chain(arg1)..., expr_to_name_chain(arg2)...) +end + +# Show a FieldName with @name syntax, instead of the default constructor syntax. +function Base.show(io::IO, ::FieldName{name_chain}) where {name_chain} + quoted_names = map(name -> name isa Integer ? ":($name)" : name, name_chain) + print(io, "@name($(join(quoted_names, '.')))") +end + +extract_first(::FieldName{name_chain}) where {name_chain} = first(name_chain) +drop_first(::FieldName{name_chain}) where {name_chain} = + FieldName(Base.tail(name_chain)...) + +has_field(x, ::FieldName{()}) = true +has_field(x, name::FieldName) = + extract_first(name) in propertynames(x) && + has_field(getproperty(x, extract_first(name)), drop_first(name)) + +get_field(x, ::FieldName{()}) = x +get_field(x, name::FieldName) = + get_field(getproperty(x, extract_first(name)), drop_first(name)) + +broadcasted_has_field(::Type{X}, ::FieldName{()}) where {X} = true +broadcasted_has_field(::Type{X}, name::FieldName) where {X} = + extract_first(name) in fieldnames(X) && + broadcasted_has_field(fieldtype(X, extract_first(name)), drop_first(name)) + +broadcasted_get_field(x, ::FieldName{()}) = x +broadcasted_get_field(x, name::FieldName) = + broadcasted_get_field(getfield(x, extract_first(name)), drop_first(name)) + +is_child_name( + ::FieldName{child_name_chain}, + ::FieldName{parent_name_chain}, +) where {child_name_chain, parent_name_chain} = + length(child_name_chain) >= length(parent_name_chain) && + child_name_chain[1:length(parent_name_chain)] == parent_name_chain + +names_are_overlapping(name1, name2) = + is_child_name(name1, name2) || is_child_name(name2, name1) + +extract_internal_name( + child_name::FieldName{child_name_chain}, + parent_name::FieldName{parent_name_chain}, +) where {child_name_chain, parent_name_chain} = + is_child_name(child_name, parent_name) ? + FieldName(child_name_chain[(length(parent_name_chain) + 1):end]...) : + error("$child_name is not a child name of $parent_name") + +append_internal_name( + ::FieldName{name_chain}, + ::FieldName{internal_name_chain}, +) where {name_chain, internal_name_chain} = + FieldName(name_chain..., internal_name_chain...) + +top_level_names(x) = wrapped_prop_names(Val(propertynames(x))) +wrapped_prop_names(::Val{()}) = () +wrapped_prop_names(::Val{prop_names}) where {prop_names} = ( + FieldName(first(prop_names)), + wrapped_prop_names(Val(Base.tail(prop_names)))..., +) + +################################################################################ + +""" + FieldNameTree(x) + +Tree of `FieldName`s that can be used to access `x` with `get_field(x, name)`. +Check whether a `name` is valid by calling `is_valid_name(name, tree)`, +and extract the children of `name` by calling `child_names(name, tree)`. +""" +abstract type FieldNameTree end +struct FieldNameTreeLeaf{V <: FieldName} <: FieldNameTree + name::V +end +struct FieldNameTreeNode{V <: FieldName, S <: NTuple{<:Any, FieldNameTree}} <: + FieldNameTree + name::V + subtrees::S +end + +FieldNameTree(x) = make_subtree_at_name(x, @name()) +function make_subtree_at_name(x, name) + internal_names = top_level_names(get_field(x, name)) + isempty(internal_names) && return FieldNameTreeLeaf(name) + subsubtrees = unrolled_map(internal_names) do internal_name + make_subtree_at_name(x, append_internal_name(name, internal_name)) + end + return FieldNameTreeNode(name, subsubtrees) +end + +is_valid_name(name, tree::FieldNameTreeLeaf) = name == tree.name +is_valid_name(name, tree::FieldNameTreeNode) = + name == tree.name || + is_child_name(name, tree.name) && + unrolled_any(subtree -> is_valid_name(name, subtree), tree.subtrees) + +function child_names(name, tree) + subtree = get_subtree_at_name(name, tree) + subtree isa FieldNameTreeNode || + error("FieldNameTree does not contain any child names for $name") + return unrolled_map(subsubtree -> subsubtree.name, subtree.subtrees) +end +get_subtree_at_name(name, tree::FieldNameTreeLeaf) = + name == tree.name ? tree : + error("FieldNameTree does not contain the name $name") +get_subtree_at_name(name, tree::FieldNameTreeNode) = + if name == tree.name + tree + elseif is_valid_name(name, tree) + subtree_that_contains_name = unrolled_findonly(tree.subtrees) do subtree + is_child_name(name, subtree.name) + end + get_subtree_at_name(name, subtree_that_contains_name) + else + error("FieldNameTree does not contain the name $name") + end + +################################################################################ + +# This is required for type-stability as of Julia 1.9. +if hasfield(Method, :recursion_relation) + dont_limit = (args...) -> true + for m in methods(has_field) + m.recursion_relation = dont_limit + end + for m in methods(get_field) + m.recursion_relation = dont_limit + end + for m in methods(broadcasted_has_field) + m.recursion_relation = dont_limit + end + for m in methods(broadcasted_get_field) + m.recursion_relation = dont_limit + end + for m in methods(wrapped_prop_names) + m.recursion_relation = dont_limit + end + for m in methods(make_subtree_at_name) + m.recursion_relation = dont_limit + end + for m in methods(is_valid_name) + m.recursion_relation = dont_limit + end + for m in methods(get_subtree_at_name) + m.recursion_relation = dont_limit + end +end diff --git a/src/MatrixFields/field_name_dict.jl b/src/MatrixFields/field_name_dict.jl new file mode 100644 index 0000000000..485053daa9 --- /dev/null +++ b/src/MatrixFields/field_name_dict.jl @@ -0,0 +1,471 @@ +""" + FieldNameDict{T1, T2}(keys, entries) + FieldNameDict{T1, T2}(key_entry_pairs...) + +An `AbstractDict` that contains keys of type `T1` and entries of type `T2`, +where the keys are stored as a `FieldNameSet{T1}`. There are four commonly used +subtypes of `FieldNameDict`: +- `FieldMatrix`, which maps a set of `FieldMatrixKeys` to either + `ColumnwiseBandMatrixField`s or multiples of `LinearAlgebra.I`; this is the + only user-facing subtype of `FieldNameDict` +- `FieldVectorView`, which maps a set of `FieldVectorKeys` to `Field`s; this + subtype is automatically generated when a `FieldVector` is used in the same + operation as a `FieldMatrix` (e.g., when both appear in the same broadcast + expression or are passed to a `FieldMatrixSolver`) +- `FieldMatrixBroadcasted` and `FieldVectorViewBroadcasted`, which are the same + as `FieldMatrix` and `FieldVectorView`, except that they can also store + unevaluated broadcast expressions; these subtypes are automatically generated + when a `FieldMatrix` or a `FieldVectorView` is used in a broadcast expression + +The entry at a specific key can be extracted by calling `dict[key]`, and the +entries that correspond to all the keys in a `FieldNameSet` can be extracted by +calling `dict[set]`. If `dict` is a `FieldMatrix`, the corresponding identity +matrix can be computed by calling `one(dict)`. + +When broadcasting over `FieldNameDict`s, the following operations are supported: +- Addition and subtraction +- Multiplication, where the first argument must be a `FieldMatrix` (or + `FieldMatrixBroadcasted`) +- Inversion, where the argument must be a diagonal `FieldMatrix` (or + `FieldMatrixBroadcasted`), i.e., one in which every entry is either a + `ColumnwiseBandMatrixField` of `DiagonalMatrixRow`s or a multiple of + `LinearAlgebra.I` +""" +struct FieldNameDict{T1, T2, K <: FieldNameSet{T1}, E <: NTuple{<:Any, T2}} <: + AbstractDict{T1, T2} + keys::K + entries::E + + # This needs to be an inner constructor to prevent Julia from automatically + # generating a constructor that fails Aqua.detect_unbound_args_recursively. + FieldNameDict{T1, T2}( + keys::FieldNameSet{T1, <:NTuple{N, T1}}, + entries::NTuple{N, T2}, + ) where {T1, T2, N} = + new{T1, T2, typeof(keys), typeof(entries)}(keys, entries) +end +FieldNameDict{T1, T2}(key_entry_pairs::Pair{<:T1, <:T2}...) where {T1, T2} = + FieldNameDict{T1, T2}( + FieldNameSet{T1}(unrolled_map(pair -> pair[1], key_entry_pairs)), + unrolled_map(pair -> pair[2], key_entry_pairs), + ) +FieldNameDict{T1}(args...) where {T1} = FieldNameDict{T1, Any}(args...) + +const FieldVectorView = FieldNameDict{FieldName, Fields.Field} +const FieldVectorViewBroadcasted = + FieldNameDict{FieldName, Union{Fields.Field, Base.AbstractBroadcasted}} +const FieldMatrix = FieldNameDict{ + FieldNamePair, + Union{UniformScaling, ColumnwiseBandMatrixField}, +} +const FieldMatrixBroadcasted = FieldNameDict{ + FieldNamePair, + Union{UniformScaling, ColumnwiseBandMatrixField, Base.AbstractBroadcasted}, +} + +dict_type(::FieldNameDict{T1, T2}) where {T1, T2} = FieldNameDict{T1, T2} + +function Base.show(io::IO, dict::FieldNameDict) + strings = map((key, value) -> " $key => $value", pairs(dict)) + print(io, "$(dict_type(dict))($(join(strings, ",\n")))") +end + +Base.keys(dict::FieldNameDict) = dict.keys + +Base.values(dict::FieldNameDict) = dict.entries + +Base.pairs(dict::FieldNameDict) = + unrolled_map(unrolled_zip(keys(dict).values, values(dict))) do key_entry_tup + key_entry_tup[1] => key_entry_tup[2] + end + +Base.length(dict::FieldNameDict) = length(keys(dict)) + +Base.iterate(dict::FieldNameDict, index = 1) = iterate(pairs(dict), index) + +Base.:(==)(dict1::FieldNameDict, dict2::FieldNameDict) = + keys(dict1) == keys(dict2) && values(dict1) == values(dict2) + +function Base.getindex(dict::FieldNameDict, key) + key in keys(dict) || throw(KeyError(key)) + key′, entry′ = + unrolled_findonly(pair -> is_child_value(key, pair[1]), pairs(dict)) + return get_internal_entry(entry′, get_internal_key(key, key′)) +end + +get_internal_key(name1::FieldName, name2::FieldName) = + extract_internal_name(name1, name2) +get_internal_key(name_pair1::FieldNamePair, name_pair2::FieldNamePair) = ( + extract_internal_name(name_pair1[1], name_pair2[1]), + extract_internal_name(name_pair1[2], name_pair2[2]), +) + +unsupported_internal_entry_error(::T, key) where {T} = + error("Unsupported call to get_internal_entry(<$(T.name.name)>, $key)") + +get_internal_entry(entry, name::FieldName) = get_field(entry, name) +get_internal_entry(entry, name_pair::FieldNamePair) = + name_pair == (@name(), @name()) ? entry : + unsupported_internal_entry_error(entry, name_pair) +get_internal_entry(entry::UniformScaling, name_pair::FieldNamePair) = + name_pair[1] == name_pair[2] ? entry : + unsupported_internal_entry_error(entry, name_pair) +function get_internal_entry( + entry::ColumnwiseBandMatrixField, + name_pair::FieldNamePair, +) + # Ensure compatibility with RecursiveApply (i.e., with rmul). + # See note above matrix_product_keys in field_name_set.jl for more details. + T = eltype(eltype(entry)) + if name_pair == (@name(), @name()) + # multiplication case 1, either argument + entry + elseif broadcasted_has_field(T, name_pair[1]) && name_pair[2] == @name() + # multiplication case 2, second argument + Base.broadcasted(entry) do matrix_row + map(matrix_row) do matrix_row_entry + broadcasted_get_field(matrix_row_entry, name_pair[1]) + end + end # Note: This assumes that the entry is in a FieldMatrixBroadcasted. + elseif T <: SingleValue && name_pair[1] == name_pair[2] + # multiplication case 3, first argument + entry + else + unsupported_internal_entry_error(entry, name_pair) + end +end + +# Similar behavior to indexing an array with a slice. +function Base.getindex(dict::FieldNameDict, new_keys::FieldNameSet) + FieldNameDictType = dict_type(dict) + common_keys = intersect(keys(dict), new_keys) + return FieldNameDictType(common_keys, map(key -> dict[key], common_keys)) +end + +function Base.similar(dict::FieldNameDict) + FieldNameDictType = dict_type(dict) + entries = unrolled_map(values(dict)) do entry + entry isa UniformScaling ? entry : similar(entry) + end + return FieldNameDictType(keys(dict), entries) +end + +# Note: This assumes that the matrix has the same row and column units, since I +# cannot be multiplied by anything other than a scalar. +function Base.one(matrix::FieldMatrix) + diagonal_keys = matrix_diagonal_keys(keys(matrix)) + return FieldMatrix(diagonal_keys, map(_ -> I, diagonal_keys)) +end + +function check_block_diagonal_matrix(matrix, error_message_start = "The matrix") + off_diagonal_keys = matrix_off_diagonal_keys(keys(matrix)) + isempty(off_diagonal_keys) || error( + "$error_message_start has entries at the following off-diagonal keys: \ + $(set_string(off_diagonal_keys))", + ) +end + +function check_diagonal_matrix(matrix, error_message_start = "The matrix") + check_block_diagonal_matrix(matrix, error_message_start) + non_diagonal_entry_pairs = unrolled_filter(pairs(matrix)) do pair + !( + pair[2] isa UniformScaling || + pair[2] isa ColumnwiseBandMatrixField && + eltype(pair[2]) <: DiagonalMatrixRow || + pair[2] isa Base.AbstractBroadcasted && + eltype(pair[2]) <: DiagonalMatrixRow + ) + end + non_diagonal_entry_keys = + FieldMatrixKeys(unrolled_map(pair -> pair[1], non_diagonal_entry_pairs)) + isempty(non_diagonal_entry_keys) || error( + "$error_message_start has non-diagonal entries at the following keys: \ + $(set_string(non_diagonal_entry_keys))", + ) +end + +""" + field_vector_view(x) + +Constructs a `FieldVectorView` that contains all the top-level `Field`s in the +`FieldVector` `x`. +""" +function field_vector_view(x) + top_level_keys = FieldVectorKeys(top_level_names(x), FieldNameTree(x)) + entries = map(name -> get_field(x, name), top_level_keys) + return FieldVectorView(top_level_keys, entries) +end + +################################################################################ + +struct FieldMatrixStyle <: Base.Broadcast.BroadcastStyle end + +const FieldMatrixStyleType = + Union{FieldVectorViewBroadcasted, FieldMatrixBroadcasted} + +const FieldVectorStyleType = Union{ + Fields.FieldVector, + Base.Broadcast.Broadcasted{<:Fields.FieldVectorStyle}, +} + +Base.Broadcast.broadcastable(vector_or_matrix::FieldMatrixStyleType) = + vector_or_matrix +Base.Broadcast.broadcastable(vector::FieldVectorView) = + FieldVectorViewBroadcasted(keys(vector), values(vector)) +Base.Broadcast.broadcastable(matrix::FieldMatrix) = + FieldMatrixBroadcasted(keys(matrix), values(matrix)) + +Base.Broadcast.BroadcastStyle(::Type{<:FieldMatrixStyleType}) = + FieldMatrixStyle() +Base.Broadcast.BroadcastStyle(::FieldMatrixStyle, ::Fields.FieldVectorStyle) = + FieldMatrixStyle() + +function field_matrix_broadcast_error(f, args...) + arg_string(::FieldVectorViewBroadcasted) = "" + arg_string(::FieldMatrixBroadcasted) = "" + arg_string(::FieldVectorStyleType) = "" + arg_string(::T) where {T} = error( + "Unsupported FieldMatrixStyle broadcast argument type: $(T.name.name)", + ) + args_string = join(map(arg_string, args), ", ") + error("Unsupported FieldMatrixStyle broadcast operation: $f.($args_string)") +end + +Base.Broadcast.broadcasted( + ::FieldMatrixStyle, + f::F, # This should be restricted to a Function to avoid a method ambiguity. + args..., +) where {F <: Function} = field_matrix_broadcast_error(f, args...) + +# When a broadcast expression with + or * has more than two arguments, split it +# up into a chain of two-argument broadcast expressions. This simplifies the +# remaining methods for Base.Broadcast.broadcasted, since it allows us to assume +# that they will have at most two arguments. +Base.Broadcast.broadcasted( + ::FieldMatrixStyle, + f::Union{typeof(+), typeof(*)}, + arg1, + arg2, + arg3, + args..., +) = + unrolled_foldl((arg1, arg2, arg3, args...)) do arg1′, arg2′ + Base.Broadcast.broadcasted(f, arg1′, arg2′) + end + +# Add support for broadcast expressions of the form dict1 .= dict2. +Base.Broadcast.broadcasted( + ::FieldMatrixStyle, + ::typeof(identity), + arg::FieldMatrixStyleType, +) = arg + +function Base.Broadcast.broadcasted( + ::FieldMatrixStyle, + ::typeof(-), + vector_or_matrix::FieldMatrixStyleType, +) + FieldNameDictType = dict_type(vector_or_matrix) + entries = unrolled_map(values(vector_or_matrix)) do entry + entry isa UniformScaling ? -entry : Base.Broadcast.broadcasted(-, entry) + end + return FieldNameDictType(keys(vector_or_matrix), entries) +end + +function Base.Broadcast.broadcasted( + ::FieldMatrixStyle, + f::Union{typeof(+), typeof(-)}, + vector_or_matrix1::FieldMatrixStyleType, + vector_or_matrix2::FieldMatrixStyleType, +) + dict_type(vector_or_matrix1) == dict_type(vector_or_matrix2) || + field_matrix_broadcast_error(f, vector_or_matrix1, vector_or_matrix2) + FieldNameDictType = dict_type(vector_or_matrix1) + all_keys = union(keys(vector_or_matrix1), keys(vector_or_matrix2)) + entries = map(all_keys) do key + if key in intersect(keys(vector_or_matrix1), keys(vector_or_matrix2)) + entry1 = vector_or_matrix1[key] + entry2 = vector_or_matrix2[key] + if entry1 isa UniformScaling && entry2 isa UniformScaling + f(entry1, entry2) + elseif entry1 isa UniformScaling + Base.Broadcast.broadcasted(f, (entry1,), entry2) + elseif entry2 isa UniformScaling + Base.Broadcast.broadcasted(f, entry1, (entry2,)) + else + Base.Broadcast.broadcasted(f, entry1, entry2) + end + elseif key in keys(vector_or_matrix1) + vector_or_matrix1[key] + else + if f isa typeof(+) + vector_or_matrix2[key] + else + entry = vector_or_matrix2[key] + entry isa UniformScaling ? -entry : + Base.Broadcast.broadcasted(-, entry) + end + end + end + return FieldNameDictType(all_keys, entries) +end + +function Base.Broadcast.broadcasted( + ::FieldMatrixStyle, + ::typeof(*), + matrix::FieldMatrixBroadcasted, + vector_or_matrix::FieldMatrixStyleType, +) + FieldNameDictType = dict_type(vector_or_matrix) + product_keys = matrix_product_keys(keys(matrix), keys(vector_or_matrix)) + entries = map(product_keys) do product_key + summand_names = summand_names_for_matrix_product( + product_key, + keys(matrix), + keys(vector_or_matrix), + ) + summand_bcs = map(summand_names) do summand_name + key1, key2 = matrix_product_argument_keys(product_key, summand_name) + entry1 = matrix[key1] + entry2 = vector_or_matrix[key2] + if entry1 isa UniformScaling && entry2 isa UniformScaling + entry1 * entry2 + elseif entry1 isa UniformScaling + Base.Broadcast.broadcasted(*, entry1.λ, entry2) + elseif entry2 isa UniformScaling + Base.Broadcast.broadcasted(*, entry1, entry2.λ) + else + Base.Broadcast.broadcasted(⋅, entry1, entry2) + end + end + length(summand_bcs) == 1 ? summand_bcs[1] : + Base.Broadcast.broadcasted(+, summand_bcs...) + end + return FieldNameDictType(product_keys, entries) +end + +matrix_product_argument_keys(product_name::FieldName, summand_name) = + ((product_name, summand_name), summand_name) +matrix_product_argument_keys(product_name_pair::FieldNamePair, summand_name) = + ((product_name_pair[1], summand_name), (summand_name, product_name_pair[2])) + +function Base.Broadcast.broadcasted( + ::FieldMatrixStyle, + ::typeof(inv), + matrix::FieldMatrixBroadcasted, +) + check_diagonal_matrix( + matrix, + "inv.() cannot be computed because the matrix", + ) + entries = unrolled_map(values(matrix)) do entry + entry isa UniformScaling ? inv(entry) : + Base.Broadcast.broadcasted(inv, entry) + end + return FieldMatrixBroadcasted(keys(matrix), entries) +end + +# Convert every FieldVectorStyle object to a FieldMatrixStyle object. This makes +# it possible to directly use a FieldVector in the same broadcast expression as +# a FieldMatrix, without needing to convert it to a FieldVectorView first. +Base.Broadcast.broadcasted( + ::FieldMatrixStyle, + f::F, + arg::FieldVectorStyleType, +) where {F <: Function} = + Base.Broadcast.broadcasted(f, convert_to_field_matrix_style(arg)) +Base.Broadcast.broadcasted( + ::FieldMatrixStyle, + f::F, + arg1::FieldVectorStyleType, + arg2, +) where {F <: Function} = + Base.Broadcast.broadcasted(f, convert_to_field_matrix_style(arg1), arg2) +Base.Broadcast.broadcasted( + ::FieldMatrixStyle, + f::F, + arg1, + arg2::FieldVectorStyleType, +) where {F <: Function} = + Base.Broadcast.broadcasted(f, arg1, convert_to_field_matrix_style(arg2)) +Base.Broadcast.broadcasted( + ::FieldMatrixStyle, + f::F, + arg1::FieldVectorStyleType, + arg2::FieldVectorStyleType, +) where {F <: Function} = Base.Broadcast.broadcasted( + f, + convert_to_field_matrix_style(arg1), + convert_to_field_matrix_style(arg2), +) + +convert_to_field_matrix_style(x::Fields.FieldVector) = field_vector_view(x) +convert_to_field_matrix_style( + bc::Base.Broadcast.Broadcasted{<:Fields.FieldVectorStyle}, +) = Base.broadcast.broadcasted(FieldMatrixStyle(), bc.f, bc.args...) + +################################################################################ + +materialized_dict_type(::FieldVectorViewBroadcasted) = FieldVectorView +materialized_dict_type(::FieldMatrixBroadcasted) = FieldMatrix + +function Base.Broadcast.materialize(vector_or_matrix::FieldMatrixStyleType) + FieldNameDictType = materialized_dict_type(vector_or_matrix) + entries = unrolled_map(values(vector_or_matrix)) do entry + Base.Broadcast.materialize(entry) + end + return FieldNameDictType(keys(vector_or_matrix), entries) +end + +Base.Broadcast.materialize!( + dest::Fields.FieldVector, + vector_or_matrix::FieldMatrixStyleType, +) = Base.Broadcast.materialize!(field_vector_view(dest), vector_or_matrix) +function Base.Broadcast.materialize!( + dest::Union{FieldVectorView, FieldMatrix}, + vector_or_matrix::FieldMatrixStyleType, +) + FieldNameDictType = materialized_dict_type(vector_or_matrix) + dest isa FieldNameDictType || + error("Broadcast result and destination types are incompatible: + $FieldNameDictType vs. $(typeof(dest).name.name)") + is_subset_that_covers_set(keys(vector_or_matrix), keys(dest)) || error( + "Broadcast result and destination keys are incompatible: \ + $(set_string(keys(vector_or_matrix))) vs. $(set_string(keys(dest)))", + ) # It is not always the case that keys(vector_or_matrix) == keys(dest). + foreach(keys(vector_or_matrix)) do key + entry = vector_or_matrix[key] + if dest[key] isa UniformScaling + dest[key] == entry || error("UniformScaling is immutable") + elseif entry isa UniformScaling + dest[key] .= (entry,) + else + Base.Broadcast.materialize!(dest[key], entry) + end + end +end + +#= +For debugging, uncomment the function below and put the following lines into the +loop in materialize!: + println() + println(key) + println(summary_string(vector_or_matrix[key])) + println(dest[key]) + println() + +summary_string(entry) = summary_string(entry, 0) +summary_string(entry, indent_level) = "$(" "^indent_level)$entry" +function summary_string(field::Fields.Field, indent_level) + staggering_string = + hasproperty(axes(field), :staggering) ? + string(typeof(axes(field).staggering).name.name) : "Single Level" + return "$(" "^indent_level)Field{$(eltype(field)), $staggering_string}" +end +function summary_string(bc::Base.AbstractBroadcasted, indent_level) + func = bc isa Operators.OperatorBroadcasted ? bc.op : bc.f + arg_strings = map(arg -> summary_string(arg, indent_level + 1), bc.args) + tab = " "^indent_level + return "$(tab)Broadcasted{$func}(\n$(join(arg_strings, ",\n")),\n$tab)" +end +=# diff --git a/src/MatrixFields/field_name_set.jl b/src/MatrixFields/field_name_set.jl new file mode 100644 index 0000000000..80ca539b6a --- /dev/null +++ b/src/MatrixFields/field_name_set.jl @@ -0,0 +1,426 @@ +const FieldNamePair = Tuple{FieldName, FieldName} + +""" + FieldNameSet{T}(values, [name_tree]) + +An `AbstractSet` that contains values of type `T`, which serves as an analogue +of a `KeySet` for a `FieldNameDict`. There are two subtypes of `FieldNameSet`: +- `FieldVectorKeys`, for which `T` is set to `FieldName` +- `FieldMatrixKeys`, for which `T` is set to `Tuple{FieldName, FieldName}`; each + tuple of type `T` represents a pair of row-column indices + +Since `FieldName`s are singleton types, the result of almost any `FieldNameSet` +operation can be inferred during compilation. So, with the exception of `map`, +`foreach`, and `set_string`, functions of `FieldNameSet`s do not have any +performance cost at runtime (as long as their arguments are inferrable). + +Unlike other `AbstractSet`s, `FieldNameSet` has special behavior for overlapping +values. For example, the `FieldName`s `@name(a.b)` and `@name(a.b.c)` overlap, +so any set operation needs to first decompose `@name(a.b)` into its child values +before combining it with `@name(a.b.c)`. In order to support this (and also to +support the ability to compute set complements), `FieldNameSet` stores a +`FieldNameTree` `name_tree`, which it uses to infer child values. If `name_tree` +is not specified, it gets set to `nothing` by default, which causes some +`FieldNameSet` operations to become disabled. For binary operations like `union` +or `setdiff`, only one set needs to specify a `name_tree`; if two sets both +specify a `name_tree`, the `name_tree`s must be identical. +""" +struct FieldNameSet{ + T <: Union{FieldName, FieldNamePair}, + V <: NTuple{<:Any, T}, + N <: Union{FieldNameTree, Nothing}, +} <: AbstractSet{T} + values::V + name_tree::N + + # This needs to be an inner constructor to prevent Julia from automatically + # generating a constructor that fails Aqua.detect_unbound_args_recursively. + function FieldNameSet{T}( + values::NTuple{<:Any, T}, + name_tree::Union{FieldNameTree, Nothing} = nothing, + ) where {T} + check_values(values, name_tree) + return new{T, typeof(values), typeof(name_tree)}(values, name_tree) + end +end + +const FieldVectorKeys = FieldNameSet{FieldName} +const FieldMatrixKeys = FieldNameSet{FieldNamePair} + +function Base.show(io::IO, set::FieldNameSet{T}) where {T} + # Do not print the FieldNameTree, since the current implementation ensures + # that it will be the same across all FieldNameSets that are used together. + name_tree_str = isnothing(set.name_tree) ? "" : "; " + print(io, "$(FieldNameSet{T})($(join(set.values, ", "))$name_tree_str)") +end + +Base.length(set::FieldNameSet) = length(set.values) + +Base.iterate(set::FieldNameSet, index = 1) = iterate(set.values, index) + +Base.map(f::F, set::FieldNameSet) where {F} = unrolled_map(f, set.values) + +Base.foreach(f::F, set::FieldNameSet) where {F} = + unrolled_foreach(f, set.values) + +Base.in(value, set::FieldNameSet) = + is_value_in_set(value, set.values, set.name_tree) + +function Base.issubset(set1::FieldNameSet, set2::FieldNameSet) + name_tree = combine_name_trees(set1.name_tree, set2.name_tree) + unrolled_all(set1.values) do value + is_value_in_set(value, set2.values, name_tree) + end +end + +Base.:(==)(set1::FieldNameSet, set2::FieldNameSet) = + issubset(set1, set2) && issubset(set2, set1) + +function Base.intersect(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} + name_tree = combine_name_trees(set1.name_tree, set2.name_tree) + values1′, values2′ = set1.values, set2.values + values1, values2 = non_overlapping_values(values1′, values2′, name_tree) + result_values = unrolled_filter(values2) do value + unrolled_any(isequal(value), values1) + end + return FieldNameSet{T}(result_values, name_tree) +end + +function Base.union(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} + name_tree = combine_name_trees(set1.name_tree, set2.name_tree) + values1′, values2′ = set1.values, set2.values + values1, values2 = non_overlapping_values(values1′, values2′, name_tree) + values2_minus_values1 = unrolled_filter(values2) do value + !unrolled_any(isequal(value), values1) + end + result_values = (values1..., values2_minus_values1...) + return FieldNameSet{T}(result_values, name_tree) +end + +function Base.setdiff(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} + name_tree = combine_name_trees(set1.name_tree, set2.name_tree) + set2_complement_values = set_complement_values(T, set2.values, name_tree) + set2_complement = FieldNameSet{T}(set2_complement_values, name_tree) + return intersect(set1, set2_complement) +end + +set_string(set) = + length(set) == 2 ? join(set.values, " and ") : + join(set.values, ", ", ", and ") + +is_subset_that_covers_set(set1, set2) = + issubset(set1, set2) && isempty(setdiff(set2, set1)) + +function set_complement(set::FieldNameSet{T}) where {T} + result_values = set_complement_values(T, set.values, set.name_tree) + return FieldNameSet{T}(result_values, set.name_tree) +end + +function corresponding_matrix_keys(set::FieldVectorKeys) + result_values = unrolled_map(name -> (name, name), set.values) + return FieldMatrixKeys(result_values, set.name_tree) +end + +function cartesian_product(set1::FieldVectorKeys, set2::FieldVectorKeys) + name_tree = combine_name_trees(set1.name_tree, set2.name_tree) + result_values = unrolled_mapflatten(set1.values) do row_name + unrolled_map(col_name -> (row_name, col_name), set2.values) + end + return FieldMatrixKeys(result_values, name_tree) +end + +function matrix_row_keys(set::FieldMatrixKeys) + result_values′ = unrolled_map(name_pair -> name_pair[1], set.values) + result_values = + unique_and_non_overlapping_values(result_values′, set.name_tree) + return FieldVectorKeys(result_values, set.name_tree) +end + +function matrix_off_diagonal_keys(set::FieldMatrixKeys) + result_values = + unrolled_filter(name_pair -> name_pair[1] != name_pair[2], set.values) + return FieldMatrixKeys(result_values, set.name_tree) +end + +function matrix_diagonal_keys(set::FieldMatrixKeys) + result_values′ = unrolled_filter(set.values) do name_pair + names_are_overlapping(name_pair[1], name_pair[2]) + end + result_values = unrolled_map(result_values′) do name_pair + name_pair[1] == name_pair[2] ? name_pair : + is_child_value(name_pair[1], name_pair[2]) ? + (name_pair[1], name_pair[1]) : (name_pair[2], name_pair[2]) + end + return FieldMatrixKeys(result_values, set.name_tree) +end + +#= +There are three cases that we need to support in order to be compatible with +RecursiveApply (i.e., with rmul): +1. (_, name) * name or + (_, name) * (name, _) +2. (_, name_child) * name -> (_, name_child) * name_child or + (_, name_child) * (name, _) -> (_, name_child) * (name_child, _) + We are able to support this by extracting internal rows from FieldNameDict + entries. We can only extract an internal row from a ColumnwiseBandMatrixField + whose values contain internal values that correspond to "name_child". +3. (name, name) * name_child -> (name_child, name_child) * name_child or + (name, name) * (name_child, _) -> (name_child, name_child) * (name_child, _) + We are able to support this by extracting internal diagonal blocks from + FieldNameDict entries. We can only extract an internal diagonal block from a + LinearAlgebra.UniformScaling or a ColumnwiseBandMatrixField of SingleValues. +We only need to support diagonal matrix blocks of scalar values in the last case +because we cannot extract internal columns from FieldNameDict entries. +=# +function matrix_product_keys(set1::FieldMatrixKeys, set2::FieldNameSet) + name_tree = combine_name_trees(set1.name_tree, set2.name_tree) + result_values′ = unrolled_mapflatten(set1.values) do name_pair1 + overlapping_set2_values = unrolled_filter(set2.values) do value2 + row_name2 = eltype(set2) <: FieldName ? value2 : value2[1] + names_are_overlapping(name_pair1[2], row_name2) + end + unrolled_map(overlapping_set2_values) do value2 + row_name2 = eltype(set2) <: FieldName ? value2 : value2[1] + if is_child_name(name_pair1[2], row_name2) + # multiplication cases 1 and 2 + eltype(set2) <: FieldName ? name_pair1[1] : + (name_pair1[1], value2[2]) + elseif name_pair1[1] == name_pair1[2] + # multiplication case 3 + value2 + else + error("Cannot extract internal column from an off-diagonal key") + end + end + end + result_values = unique_and_non_overlapping_values(result_values′, name_tree) + return FieldNameSet{eltype(set2)}(result_values, name_tree) +end +function summand_names_for_matrix_product( + product_key, + set1::FieldMatrixKeys, + set2::FieldNameSet, +) + product_row_name = eltype(set2) <: FieldName ? product_key : product_key[1] + name_tree = combine_name_trees(set1.name_tree, set2.name_tree) + overlapping_set1_values = unrolled_filter(set1.values) do name_pair1 + names_are_overlapping(product_row_name, name_pair1[1]) + end + result_values = unrolled_mapflatten(overlapping_set1_values) do name_pair1 + overlapping_set2_values = unrolled_filter(set2.values) do value2 + row_name2 = eltype(set2) <: FieldName ? value2 : value2[1] + names_are_overlapping(name_pair1[2], row_name2) && ( + eltype(set2) <: FieldName || + names_are_overlapping(product_key[2], value2[2]) + ) + end + unrolled_map(overlapping_set2_values) do value2 + row_name2 = eltype(set2) <: FieldName ? value2 : value2[1] + is_child_name(product_row_name, name_pair1[1]) && ( + eltype(set2) <: FieldName || product_key[2] == value2[2] + ) || error("Invalid matrix product key $product_key") + if is_child_name(name_pair1[2], row_name2) + # multiplication cases 1 and 2 + name_pair1[2] + else + # multiplication case 3 + product_row_name == row_name2 && + name_pair1[1] == name_pair1[2] || + error("Invalid matrix product key $product_key") + row_name2 + end + end + end + return FieldVectorKeys(result_values, name_tree) +end + +################################################################################ + +# Internal functions: + +check_values(values, name_tree) = + unrolled_foreach(values) do value + (isnothing(name_tree) || is_valid_value(value, name_tree)) || error( + "Invalid FieldNameSet value: $value is incompatible with name_tree", + ) + duplicate_values = unrolled_filter(isequal(value), values) + length(duplicate_values) == 1 || error( + "Duplicate FieldNameSet values: $(length(duplicate_values)) copies \ + of $value have been passed to a FieldNameSet constructor", + ) + overlapping_values = unrolled_filter(values) do value′ + value != value′ && values_are_overlapping(value, value′) + end + if !isempty(overlapping_values) + overlapping_values_string = + length(overlapping_values) == 2 ? + join(overlapping_values, " or ") : + join(overlapping_values, ", ", ", or ") + error("Overlapping FieldNameSet values: $value cannot be in the \ + same FieldNameSet as $overlapping_values_string") + end + end + +combine_name_trees(::Nothing, ::Nothing) = nothing +combine_name_trees(name_tree1, ::Nothing) = name_tree1 +combine_name_trees(::Nothing, name_tree2) = name_tree2 +combine_name_trees(name_tree1, name_tree2) = + name_tree1 == name_tree2 ? name_tree1 : + error("Mismatched FieldNameTrees: The ability to combine different \ + FieldNameTrees has not been implemented") + +is_valid_value(name::FieldName, name_tree) = is_valid_name(name, name_tree) +is_valid_value(name_pair::FieldNamePair, name_tree) = + is_valid_name(name_pair[1], name_tree) && + is_valid_name(name_pair[2], name_tree) + +values_are_overlapping(name1::FieldName, name2::FieldName) = + names_are_overlapping(name1, name2) +values_are_overlapping(name_pair1::FieldNamePair, name_pair2::FieldNamePair) = + names_are_overlapping(name_pair1[1], name_pair2[1]) && + names_are_overlapping(name_pair1[2], name_pair2[2]) + +is_child_value(name1::FieldName, name2::FieldName) = is_child_name(name1, name2) +is_child_value(name_pair1::FieldNamePair, name_pair2::FieldNamePair) = + is_child_name(name_pair1[1], name_pair2[1]) && + is_child_name(name_pair1[2], name_pair2[2]) + +is_value_in_set(value, values, name_tree) = + if unrolled_any(isequal(value), values) + true + elseif unrolled_any(value′ -> is_child_value(value, value′), values) + isnothing(name_tree) && error( + "Cannot check if $value is in FieldNameSet without a FieldNameTree", + ) + is_valid_value(value, name_tree) + else + false + end + +function non_overlapping_values(values1, values2, name_tree) + new_values1 = unrolled_mapflatten(values1) do value + value_or_non_overlapping_children(value, values2, name_tree) + end + new_values2 = unrolled_mapflatten(values2) do value + value_or_non_overlapping_children(value, values1, name_tree) + end + if eltype(values1) <: FieldName + new_values1, new_values2 + else + # Repeat the above operation to handle complex matrix key overlaps. + new_values1′ = unrolled_mapflatten(new_values1) do value + value_or_non_overlapping_children(value, new_values2, name_tree) + end + new_values2′ = unrolled_mapflatten(new_values2) do value + value_or_non_overlapping_children(value, new_values1, name_tree) + end + return new_values1′, new_values2′ + end +end + +function unique_and_non_overlapping_values(values, name_tree) + new_values = unrolled_mapflatten(values) do value + value_or_non_overlapping_children(value, values, name_tree) + end + return unrolled_unique(new_values) +end + +function value_or_non_overlapping_children(name::FieldName, names, name_tree) + need_child_names = unrolled_any(names) do name′ + is_child_value(name′, name) && name′ != name + end + need_child_names || return (name,) + isnothing(name_tree) && + error("Cannot compute child names of $name without a FieldNameTree") + return unrolled_mapflatten(child_names(name, name_tree)) do child_name + value_or_non_overlapping_children(child_name, names, name_tree) + end +end +function value_or_non_overlapping_children( + name_pair::FieldNamePair, + name_pairs, + name_tree, +) + need_row_child_names = unrolled_any(name_pairs) do name_pair′ + is_child_value(name_pair′, name_pair) && name_pair′[1] != name_pair[1] + end + need_col_child_names = unrolled_any(name_pairs) do name_pair′ + is_child_value(name_pair′, name_pair) && name_pair′[2] != name_pair[2] + end + need_row_child_names || need_col_child_names || return (name_pair,) + isnothing(name_tree) && error( + "Cannot compute child name pairs of $name_pair without a FieldNameTree", + ) + row_name_children = + need_row_child_names ? child_names(name_pair[1], name_tree) : + (name_pair[1],) + col_name_children = + need_col_child_names ? child_names(name_pair[2], name_tree) : + (name_pair[2],) + return unrolled_mapflatten(row_name_children) do row_name_child + unrolled_mapflatten(col_name_children) do col_name_child + child_pair = (row_name_child, col_name_child) + value_or_non_overlapping_children(child_pair, name_pairs, name_tree) + end + end +end + +set_complement_values(_, _, ::Nothing) = + error("Cannot compute complement of a FieldNameSet without a FieldNameTree") +set_complement_values(::Type{<:FieldName}, names, name_tree::FieldNameTree) = + complement_values_in_subtree(names, name_tree) +set_complement_values( + ::Type{<:FieldNamePair}, + name_pairs, + name_tree::FieldNameTree, +) = complement_values_in_subtree_pair(name_pairs, (name_tree, name_tree)) + +function complement_values_in_subtree(names, subtree) + name = subtree.name + unrolled_all(name′ -> !is_child_value(name, name′), names) || return () + unrolled_any(name′ -> is_child_value(name′, name), names) || return (name,) + return unrolled_mapflatten(subtree.subtrees) do subsubtree + complement_values_in_subtree(names, subsubtree) + end +end + +function complement_values_in_subtree_pair(name_pairs, subtree_pair) + name_pair = (subtree_pair[1].name, subtree_pair[2].name) + is_name_pair_in_complement = unrolled_all(name_pairs) do name_pair′ + !is_child_value(name_pair, name_pair′) + end + is_name_pair_in_complement || return () + need_row_subsubtrees = unrolled_any(name_pairs) do name_pair′ + is_child_value(name_pair′, name_pair) && name_pair′[1] != name_pair[1] + end + need_col_subsubtrees = unrolled_any(name_pairs) do name_pair′ + is_child_value(name_pair′, name_pair) && name_pair′[2] != name_pair[2] + end + need_row_subsubtrees || need_col_subsubtrees || return (name_pair,) + row_subsubtrees = + need_row_subsubtrees ? subtree_pair[1].subtrees : (subtree_pair[1],) + col_subsubtrees = + need_col_subsubtrees ? subtree_pair[2].subtrees : (subtree_pair[2],) + return unrolled_mapflatten(row_subsubtrees) do row_subsubtree + unrolled_mapflatten(col_subsubtrees) do col_subsubtree + subsubtree_pair = (row_subsubtree, col_subsubtree) + complement_values_in_subtree_pair(name_pairs, subsubtree_pair) + end + end +end + +################################################################################ + +# This is required for type-stability as of Julia 1.9. +if hasfield(Method, :recursion_relation) + dont_limit = (args...) -> true + for m in methods(value_or_non_overlapping_children) + m.recursion_relation = dont_limit + end + for m in methods(complement_values_in_subtree) + m.recursion_relation = dont_limit + end + for m in methods(complement_values_in_subtree_pair) + m.recursion_relation = dont_limit + end +end diff --git a/src/MatrixFields/operator_matrices.jl b/src/MatrixFields/operator_matrices.jl index 2b6e99b005..a052c0cf2f 100644 --- a/src/MatrixFields/operator_matrices.jl +++ b/src/MatrixFields/operator_matrices.jl @@ -784,7 +784,7 @@ op_matrix_row_type(op::Operators.DivergenceOperator, ::Type{FT}) where {FT} = uses_extrapolate(op) ? QuaddiagonalMatrixRow{Adjoint{FT, C3{FT}}} : BidiagonalMatrixRow{Adjoint{FT, C3{FT}}} Base.@propagate_inbounds function op_matrix_interior_row( - op::Operators.DivergenceOperator, + ::Operators.DivergenceOperator, space, loc, idx, diff --git a/src/MatrixFields/single_field_solver.jl b/src/MatrixFields/single_field_solver.jl new file mode 100644 index 0000000000..bd02d67c2d --- /dev/null +++ b/src/MatrixFields/single_field_solver.jl @@ -0,0 +1,211 @@ +dual_type(::Type{A}) where {A} = typeof(Geometry.dual(A.instance)) + +inv_return_type(::Type{X}) where {X} = error( + "Cannot solve linear system because a diagonal entry in A contains the \ + non-invertible type $X", +) +inv_return_type(::Type{X}) where {X <: Union{Number, SMatrix}} = X +inv_return_type(::Type{X}) where {T, X <: Geometry.Axis2TensorOrAdj{T}} = + axis_tensor_type(T, Tuple{dual_type(axis2(X)), dual_type(axis1(X))}) + +x_eltype(A::UniformScaling, b) = x_eltype(eltype(A), eltype(b)) +x_eltype(A::ColumnwiseBandMatrixField, b) = + x_eltype(eltype(eltype(A)), eltype(b)) +x_eltype(::Type{T_A}, ::Type{T_b}) where {T_A, T_b} = + rmul_return_type(inv_return_type(T_A), T_b) + +unit_eltype(A::UniformScaling) = unit_eltype(eltype(A)) +unit_eltype(A::ColumnwiseBandMatrixField) = unit_eltype(eltype(eltype(A))) +unit_eltype(::Type{T_A}) where {T_A} = + rmul_return_type(inv_return_type(T_A), T_A) + +################################################################################ + +check_single_field_solver(::UniformScaling, _) = nothing +function check_single_field_solver(A::ColumnwiseBandMatrixField, b) + matrix_shape(A) == Square() || error( + "Cannot solve linear system because a diagonal entry in A is not a \ + square matrix", + ) + axes(A) === axes(b) || error( + "Cannot solve linear system because a diagonal entry in A is not on \ + the same space as the corresponding entry in b", + ) +end + +single_field_solver_cache(::UniformScaling, _) = nothing +function single_field_solver_cache(A::ColumnwiseBandMatrixField, b) + ud = outer_diagonals(eltype(A))[2] + cache_eltype = + ud == 0 ? Tuple{} : + Tuple{x_eltype(A, b), ntuple(_ -> unit_eltype(A), Val(ud))...} + return similar(A, cache_eltype) +end + +single_field_solve!(_, x, A::UniformScaling, b) = x .= inv(A.λ) .* b +single_field_solve!(cache, x, A::ColumnwiseBandMatrixField, b) = + single_field_solve!(ClimaComms.device(axes(A)), cache, x, A, b) + +single_field_solve!(::ClimaComms.AbstractCPUDevice, cache, x, A, b) = + _single_field_solve!(cache, x, A, b) +function single_field_solve!(::ClimaComms.CUDADevice, cache, x, A, b) + Ni, Nj, _, _, Nh = size(Fields.field_values(A)) + nthreads, nblocks = Spaces._configure_threadblock(Ni * Nj * Nh) + CUDA.@cuda threads = nthreads blocks = nblocks single_field_solve_kernel!( + cache, + x, + A, + b, + ) +end + +function single_field_solve_kernel!(cache, x, A, b) + idx = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x + Ni, Nj, _, _, Nh = size(Fields.field_values(A)) + if idx <= Ni * Nj * Nh + i, j, h = Spaces._get_idx((Ni, Nj, Nh), idx) + _single_field_solve!( + Spaces.column(cache, i, j, h), + Spaces.column(x, i, j, h), + Spaces.column(A, i, j, h), + Spaces.column(b, i, j, h), + ) + end + return nothing +end +single_field_solve_kernel!( + cache::Fields.ColumnField, + x::Fields.ColumnField, + A::Fields.ColumnField, + b::Fields.ColumnField, +) = _single_field_solve!(cache, x, A, b) + +_single_field_solve!(cache, x, A, b) = + Fields.bycolumn(axes(A)) do colidx + _single_field_solve!(cache[colidx], x[colidx], A[colidx], b[colidx]) + end +_single_field_solve!( + cache::Fields.ColumnField, + x::Fields.ColumnField, + A::Fields.ColumnField, + b::Fields.ColumnField, +) = band_matrix_solve!( + eltype(A), + unzip_tuple_field_values(Fields.field_values(cache)), + Fields.field_values(x), + unzip_tuple_field_values(Fields.field_values(A.entries)), + Fields.field_values(b), +) + +unzip_tuple_field_values(data) = + ntuple(i -> data.:($i), Val(length(propertynames(data)))) + +function band_matrix_solve!(::Type{<:DiagonalMatrixRow}, _, x, Aⱼs, b) + (A₀,) = Aⱼs + n = length(x) + @inbounds for i in 1:n + x[i] = inv(A₀[i]) ⊠ b[i] + end +end + +#= +The Thomas algorithm, as presented in + https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm#Method, +but with the following variable name changes: + - a → A₋₁ + - b → A₀ + - c → A₊₁ + - d → b + - c′ → U₊₁ + - d′ → Ux +Transforms the tri-diagonal matrix into a unit upper bi-diagonal matrix, then +solves the resulting system using back substitution. The order of +multiplications has been modified in order to handle block vectors/matrices. +=# +function band_matrix_solve!(::Type{<:TridiagonalMatrixRow}, cache, x, Aⱼs, b) + A₋₁, A₀, A₊₁ = Aⱼs + Ux, U₊₁ = cache + n = length(x) + @inbounds begin + inv_D₀ = inv(A₀[1]) + Ux[1] = inv_D₀ ⊠ b[1] + U₊₁[1] = inv_D₀ ⊠ A₊₁[1] + + for i in 2:n + inv_D₀ = inv(A₀[i] ⊟ A₋₁[i] ⊠ U₊₁[i - 1]) + Ux[i] = inv_D₀ ⊠ (b[i] ⊟ A₋₁[i] ⊠ Ux[i - 1]) + i < n && (U₊₁[i] = inv_D₀ ⊠ A₊₁[i]) # U₊₁[n] is outside the matrix. + end + + x[n] = Ux[n] + for i in (n - 1):-1:1 + x[i] = Ux[i] ⊟ U₊₁[i] ⊠ x[i + 1] + end + end +end + +#= +The PTRANS-I algorithm, as presented in + https://www.hindawi.com/journals/mpe/2015/232456/alg1, +but with the following variable name changes: + - e → A₋₂ + - c → A₋₁ + - d → A₀ + - a → A₊₁ + - b → A₊₂ + - y → b + - α → U₊₁ + - β → U₊₂ + - z → Ux + - γ → L₋₁ + - μ → D₀ +Transforms the penta-diagonal matrix into a unit upper tri-diagonal matrix, then +solves the resulting system using back substitution. The order of +multiplications has been modified in order to handle block vectors/matrices. +=# +function band_matrix_solve!(::Type{<:PentadiagonalMatrixRow}, cache, x, Aⱼs, b) + A₋₂, A₋₁, A₀, A₊₁, A₊₂ = Aⱼs + Ux, U₊₁, U₊₂ = cache + n = length(x) + @inbounds begin + inv_D₀ = inv(A₀[1]) + Ux[1] = inv_D₀ ⊠ b[1] + U₊₁[1] = inv_D₀ ⊠ A₊₁[1] + U₊₂[1] = inv_D₀ ⊠ A₊₂[1] + + inv_D₀ = inv(A₀[2] ⊟ A₋₁[2] ⊠ U₊₁[1]) + Ux[2] = inv_D₀ ⊠ (b[2] ⊟ A₋₁[2] ⊠ Ux[1]) + U₊₁[2] = inv_D₀ ⊠ (A₊₁[2] ⊟ A₋₁[2] ⊠ U₊₂[1]) + U₊₂[2] = inv_D₀ ⊠ A₊₂[2] + + for i in 3:n + L₋₁ = A₋₁[i] ⊟ A₋₂[i] ⊠ U₊₁[i - 2] + inv_D₀ = inv(A₀[i] ⊟ L₋₁ ⊠ U₊₁[i - 1] ⊟ A₋₂[i] ⊠ U₊₂[i - 2]) + Ux[i] = inv_D₀ ⊠ (b[i] ⊟ L₋₁ ⊠ Ux[i - 1] ⊟ A₋₂[i] ⊠ Ux[i - 2]) + i < n && (U₊₁[i] = inv_D₀ ⊠ (A₊₁[i] ⊟ L₋₁ ⊠ U₊₂[i - 1])) + i < n - 1 && (U₊₂[i] = inv_D₀ ⊠ A₊₂[i]) + end + + x[n] = Ux[n] + x[n - 1] = Ux[n - 1] ⊟ U₊₁[n - 1] ⊠ x[n] + for i in (n - 2):-1:1 + x[i] = Ux[i] ⊟ U₊₁[i] ⊠ x[i + 1] ⊟ U₊₂[i] ⊠ x[i + 2] + end + end +end + +#= +Each method for band_matrix_solve! above has an order of operations that is +correct when x, A, and b are block vectors/matrices (i.e., when multiplication +is not necessarily commutative). So, the following are all valid combinations of +eltype(x), eltype(A), and eltype(b): +- Number, Number, and Number +- SVector{N}, SMatrix{N, N}, and SVector{N} +- AxisVector with axis A1, Axis2TensorOrAdj with axes (A2, dual(A1)), and + AxisVector with axis A2 +- nested type (Tuple or NamedTuple), scalar type (Number, SMatrix, or + Axis2TensorOrAdj), nested type (Tuple or NamedTuple) + +We might eventually want a single general method for band_matrix_solve!, similar +to the BLAS.gbsv function. For now, though, the methods above should be enough. +=# diff --git a/src/MatrixFields/unrolled_functions.jl b/src/MatrixFields/unrolled_functions.jl new file mode 100644 index 0000000000..947a45084d --- /dev/null +++ b/src/MatrixFields/unrolled_functions.jl @@ -0,0 +1,91 @@ +@inline unrolled_zip(values1, values2) = + isempty(values1) || isempty(values2) ? () : + ( + (first(values1), first(values2)), + unrolled_zip(Base.tail(values1), Base.tail(values2))..., + ) + +@inline unrolled_map(f::F, values) where {F} = + isempty(values) ? () : + (f(first(values)), unrolled_map(f, Base.tail(values))...) + +unrolled_foldl(f::F, values) where {F} = + isempty(values) ? + error("unrolled_foldl requires init for an empty collection of values") : + _unrolled_foldl(f, first(values), Base.tail(values)) +unrolled_foldl(f::F, values, init) where {F} = _unrolled_foldl(f, init, values) +@inline _unrolled_foldl(f::F, result, values) where {F} = + isempty(values) ? result : + _unrolled_foldl(f, f(result, first(values)), Base.tail(values)) + +# The @inline annotations are needed to avoid allocations when there are a lot +# of values. + +# Using first and tail instead of [1] and [2:end] restricts us to Tuples, but it +# also results in less compilation time. + +# This is required to make the unrolled functions type-stable, as of Julia 1.9. +if hasfield(Method, :recursion_relation) + dont_limit = (args...) -> true + for m in methods(unrolled_zip) + m.recursion_relation = dont_limit + end + for m in methods(unrolled_map) + m.recursion_relation = dont_limit + end + for m in methods(_unrolled_foldl) + m.recursion_relation = dont_limit + end +end + +################################################################################ + +unrolled_foreach(f::F, values) where {F} = (unrolled_map(f, values); nothing) + +unrolled_any(f::F, values) where {F} = + unrolled_foldl(|, unrolled_map(f, values), false) + +unrolled_all(f::F, values) where {F} = + unrolled_foldl(&, unrolled_map(f, values), true) + +unrolled_filter(f::F, values) where {F} = + unrolled_foldl(values, ()) do filtered_values, value + f(value) ? (filtered_values..., value) : filtered_values + end + +unrolled_unique(values) = + unrolled_foldl(values, ()) do unique_values, value + unrolled_any(isequal(value), unique_values) ? unique_values : + (unique_values..., value) + end + +unrolled_flatten(values) = + unrolled_foldl(values, ()) do flattened_values, value + (flattened_values..., value...) + end + +# Non-standard functions: + +unrolled_mapflatten(f::F, values) where {F} = + unrolled_flatten(unrolled_map(f, values)) + +function unrolled_findonly(f::F, values) where {F} + filtered_values = unrolled_filter(f, values) + length(filtered_values) == 1 || + error("unrolled_findonly requires that exactly one value makes f true") + return first(filtered_values) +end + +# This is required to make functions defined elsewhere type-stable, as of Julia +# 1.9. Specifically, if an unrolled function is used to implement the recursion +# of another function, it needs to have its recursion limit disabled in order +# for that other function to be type-stable. +if hasfield(Method, :recursion_relation) + dont_limit = (args...) -> true + for m in methods(unrolled_any) + m.recursion_relation = dont_limit + end # for is_valid_name + for m in methods(unrolled_mapflatten) + m.recursion_relation = dont_limit + end # for complement_values_in_subtree and value_or_non_overlapping_children +end diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index b87340799f..dfbbe79e5b 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -157,9 +157,6 @@ rconvert(::Type{T}, X) where {T} = Recursively scale each element of `X` by `Y`. """ rmul(X, Y) = rmap(*, X, Y) -rmul(w::Number, X) = rmap(x -> w * x, X) -rmul(X, w::Number) = rmap(x -> x * w, X) -rmul(w1::Number, w2::Number) = w1 * w2 const ⊠ = rmul """ diff --git a/test/MatrixFields/field_matrix_solvers.jl b/test/MatrixFields/field_matrix_solvers.jl new file mode 100644 index 0000000000..f58ec23337 --- /dev/null +++ b/test/MatrixFields/field_matrix_solvers.jl @@ -0,0 +1,381 @@ +import LinearAlgebra: I, norm +import ClimaCore.Utilities: half +import ClimaCore.RecursiveApply: ⊠ + +include("matrix_field_test_utils.jl") + +# This broadcast must be wrapped in a function to be tested with @test_opt. +field_matrix_mul!(b, A, x) = @. b = A * x + +function test_field_matrix_solver(; + test_name, + alg, + A, + b, + ignore_approximation_error = false, + skip_correctness_test = false, +) + @testset "$test_name" begin + x = similar(b) + b_test = similar(b) + solver = FieldMatrixSolver(alg, A, b) + args = (solver, x, A, b) + + solve_time = @benchmark field_matrix_solve!(args...) + mul_time = @benchmark field_matrix_mul!(b_test, A, x) + + solve_time_rounded = round(solve_time; sigdigits = 2) + mul_time_rounded = round(mul_time; sigdigits = 2) + time_ratio = solve_time_rounded / mul_time_rounded + time_ratio_rounded = round(time_ratio; sigdigits = 2) + + # If possible, test that A * (inv(A) * b) == b. + if skip_correctness_test + relative_error = + norm(abs.(parent(b_test) .- parent(b))) / norm(parent(b)) + relative_error_rounded = round(relative_error; sigdigits = 2) + error_string = "Relative Error = $(relative_error_rounded * 100) %" + else + if ignore_approximation_error + @assert alg isa MatrixFields.ApproximateFactorizationSolve + b_view = MatrixFields.field_vector_view(b) + A₁, A₂ = + MatrixFields.approximate_factors(alg.name_pairs₁, A, b_view) + @. b_test = A₁ * A₂ * x + end + max_error = maximum(abs.(parent(b_test) .- parent(b))) + max_eps_error = ceil(Int, max_error / eps(typeof(max_error))) + error_string = "Maximum Error = $max_eps_error eps" + end + + @info "$test_name:\n\tSolve Time = $solve_time_rounded s, \ + Multiplication Time = $mul_time_rounded s (Ratio = \ + $time_ratio_rounded)\n\t$error_string" + + skip_correctness_test || @test max_eps_error <= 3 + + @test_opt ignored_modules = ignore_cuda FieldMatrixSolver(alg, A, b) + @test_opt ignored_modules = ignore_cuda field_matrix_solve!(args...) + @test_opt ignored_modules = ignore_cuda field_matrix_mul!(b, A, x) + + using_cuda || @test @allocated(field_matrix_solve!(args...)) == 0 + using_cuda || @test @allocated(field_matrix_mul!(b, A, x)) == 0 + end +end + +@testset "FieldMatrixSolver Unit Tests" begin + FT = Float64 + center_space, face_space = test_spaces(FT) + surface_space = Spaces.level(face_space, half) + + seed!(1) # ensures reproducibility + + ᶜvec = random_field(FT, center_space) + ᶠvec = random_field(FT, face_space) + sfc_vec = random_field(FT, surface_space) + + # Make each random square matrix diagonally dominant in order to avoid large + # large roundoff errors when computing its inverse. Scale the non-square + # matrices by the same amount as the square matrices. + λ = 10 # scale factor + ᶜᶜmat1 = random_field(DiagonalMatrixRow{FT}, center_space) ./ λ .+ (I,) + ᶠᶠmat1 = random_field(DiagonalMatrixRow{FT}, face_space) ./ λ .+ (I,) + ᶜᶠmat2 = random_field(BidiagonalMatrixRow{FT}, center_space) ./ λ + ᶠᶜmat2 = random_field(BidiagonalMatrixRow{FT}, face_space) ./ λ + ᶜᶜmat3 = random_field(TridiagonalMatrixRow{FT}, center_space) ./ λ .+ (I,) + ᶠᶠmat3 = random_field(TridiagonalMatrixRow{FT}, face_space) ./ λ .+ (I,) + ᶜᶠmat4 = random_field(QuaddiagonalMatrixRow{FT}, center_space) ./ λ + ᶠᶜmat4 = random_field(QuaddiagonalMatrixRow{FT}, face_space) ./ λ + ᶜᶜmat5 = random_field(PentadiagonalMatrixRow{FT}, center_space) ./ λ .+ (I,) + ᶠᶠmat5 = random_field(PentadiagonalMatrixRow{FT}, face_space) ./ λ .+ (I,) + + for (vector, matrix, string1, string2) in ( + (sfc_vec, I, "UniformScaling", "a single level"), + (ᶜvec, I, "UniformScaling", "cell centers"), + (ᶠvec, I, "UniformScaling", "cell faces"), + (ᶜvec, ᶜᶜmat1, "diagonal matrix", "cell centers"), + (ᶠvec, ᶠᶠmat1, "diagonal matrix", "cell faces"), + (ᶜvec, ᶜᶜmat3, "tri-diagonal matrix", "cell centers"), + (ᶠvec, ᶠᶠmat3, "tri-diagonal matrix", "cell faces"), + (ᶜvec, ᶜᶜmat5, "penta-diagonal matrix", "cell centers"), + (ᶠvec, ᶠᶠmat5, "penta-diagonal matrix", "cell faces"), + ) + test_field_matrix_solver(; + test_name = "$string1 solve on $string2", + alg = MatrixFields.BlockDiagonalSolve(), + A = MatrixFields.FieldMatrix((@name(_), @name(_)) => matrix), + b = Fields.FieldVector(; _ = vector), + ) + end + + # TODO: Add a simple test where typeof(x) != typeof(b). + + for alg in ( + MatrixFields.BlockDiagonalSolve(), + MatrixFields.BlockLowerTriangularSolve(@name(c)), + MatrixFields.SchurComplementSolve(@name(f)), + MatrixFields.ApproximateFactorizationSolve((@name(c), @name(c))), + ) + test_field_matrix_solver(; + test_name = "$(typeof(alg).name.name) for a block diagonal matrix \ + with diagonal and penta-diagonal blocks", + alg, + A = MatrixFields.FieldMatrix( + (@name(c), @name(c)) => ᶜᶜmat1, + (@name(f), @name(f)) => ᶠᶠmat5, + ), + b = Fields.FieldVector(; c = ᶜvec, f = ᶠvec), + ) + end + + test_field_matrix_solver(; + test_name = "BlockDiagonalSolve for a block diagonal matrix with \ + tri-diagonal and penta-diagonal blocks", + alg = MatrixFields.BlockDiagonalSolve(), + A = MatrixFields.FieldMatrix( + (@name(c), @name(c)) => ᶜᶜmat3, + (@name(f), @name(f)) => ᶠᶠmat5, + ), + b = Fields.FieldVector(; c = ᶜvec, f = ᶠvec), + ) + + test_field_matrix_solver(; + test_name = "BlockLowerTriangularSolve for a block lower triangular \ + matrix with tri-diagonal, bi-diagonal, and penta-diagonal \ + blocks", + alg = MatrixFields.BlockLowerTriangularSolve(@name(c)), + A = MatrixFields.FieldMatrix( + (@name(c), @name(c)) => ᶜᶜmat3, + (@name(f), @name(c)) => ᶠᶜmat2, + (@name(f), @name(f)) => ᶠᶠmat5, + ), + b = Fields.FieldVector(; c = ᶜvec, f = ᶠvec), + ) + + test_field_matrix_solver(; + test_name = "SchurComplementSolve for a block matrix with diagonal, \ + quad-diagonal, bi-diagonal, and penta-diagonal blocks", + alg = MatrixFields.SchurComplementSolve(@name(f)), + A = MatrixFields.FieldMatrix( + (@name(c), @name(c)) => ᶜᶜmat1, + (@name(c), @name(f)) => ᶜᶠmat4, + (@name(f), @name(c)) => ᶠᶜmat2, + (@name(f), @name(f)) => ᶠᶠmat5, + ), + b = Fields.FieldVector(; c = ᶜvec, f = ᶠvec), + ) + + test_field_matrix_solver(; + test_name = "ApproximateFactorizationSolve for a block matrix with \ + tri-diagonal, quad-diagonal, bi-diagonal, and \ + penta-diagonal blocks", + alg = MatrixFields.ApproximateFactorizationSolve( + (@name(c), @name(c)); + alg₂ = MatrixFields.SchurComplementSolve(@name(f)), + ), + A = MatrixFields.FieldMatrix( + (@name(c), @name(c)) => ᶜᶜmat3, + (@name(c), @name(f)) => ᶜᶠmat4, + (@name(f), @name(c)) => ᶠᶜmat2, + (@name(f), @name(f)) => ᶠᶠmat5, + ), + b = Fields.FieldVector(; c = ᶜvec, f = ᶠvec), + ignore_approximation_error = true, + ) +end + +@testset "FieldMatrixSolver ClimaAtmos-Based Tests" begin + FT = Float64 + center_space, face_space = test_spaces(FT) + surface_space = Spaces.level(face_space, half) + + seed!(1) # ensures reproducibility + + ᶜvec = random_field(FT, center_space) + ᶠvec = random_field(FT, face_space) + sfc_vec = random_field(FT, surface_space) + + # Make each random square matrix diagonally dominant in order to avoid large + # large roundoff errors when computing its inverse. Scale the non-square + # matrices by the same amount as the square matrices. + λ = 10 # scale factor + ᶜᶜmat1 = random_field(DiagonalMatrixRow{FT}, center_space) ./ λ .+ (I,) + ᶜᶠmat2 = random_field(BidiagonalMatrixRow{FT}, center_space) ./ λ + ᶠᶜmat2 = random_field(BidiagonalMatrixRow{FT}, face_space) ./ λ + ᶜᶜmat3 = random_field(TridiagonalMatrixRow{FT}, center_space) ./ λ .+ (I,) + ᶠᶠmat3 = random_field(TridiagonalMatrixRow{FT}, face_space) ./ λ .+ (I,) + + e¹² = Geometry.Covariant12Vector(1, 1) + e³ = Geometry.Covariant3Vector(1) + e₃ = Geometry.Contravariant3Vector(1) + + ρχ_unit = (; ρq_tot = 1, ρq_liq = 1, ρq_ice = 1, ρq_rai = 1, ρq_sno = 1) + ρaχ_unit = + (; ρaq_tot = 1, ρaq_liq = 1, ρaq_ice = 1, ρaq_rai = 1, ρaq_sno = 1) + + dry_center_gs_unit = (; ρ = 1, ρe_tot = 1, uₕ = e¹²) + center_gs_unit = (; dry_center_gs_unit..., ρatke = 1, ρχ = ρχ_unit) + center_sgsʲ_unit = (; ρa = 1, ρae_tot = 1, ρaχ = ρaχ_unit) + + ᶜᶜmat3_uₕ_scalar = ᶜᶜmat3 .* (e¹²,) + ᶠᶜmat2_u₃_scalar = ᶠᶜmat2 .* (e³,) + ᶜᶠmat2_scalar_u₃ = ᶜᶠmat2 .* (e₃',) + ᶜᶠmat2_uₕ_u₃ = ᶜᶠmat2 .* (e¹² * e₃',) + ᶠᶠmat3_u₃_u₃ = ᶠᶠmat3 .* (e³ * e₃',) + ᶜᶜmat3_ρχ_scalar = map(Base.Fix1(map, Base.Fix2(⊠, ρχ_unit)), ᶜᶜmat3) + ᶜᶜmat3_ρaχ_scalar = map(Base.Fix1(map, Base.Fix2(⊠, ρaχ_unit)), ᶜᶜmat3) + ᶜᶠmat2_ρχ_u₃ = map(Base.Fix1(map, Base.Fix2(⊠, ρχ_unit ⊠ e₃')), ᶜᶠmat2) + ᶜᶠmat2_ρaχ_u₃ = map(Base.Fix1(map, Base.Fix2(⊠, ρaχ_unit ⊠ e₃')), ᶜᶠmat2) + # We need to use Fix1 and Fix2 instead of defining anonymous functions in + # order for the result of map to be inferrable. + + b_dry_dycore = Fields.FieldVector(; + c = ᶜvec .* (dry_center_gs_unit,), + f = ᶠvec .* ((; u₃ = e³),), + ) + + b_moist_dycore_diagnostic_edmf = Fields.FieldVector(; + c = ᶜvec .* (center_gs_unit,), + f = ᶠvec .* ((; u₃ = e³),), + ) + + b_moist_dycore_prognostic_edmf_prognostic_surface = Fields.FieldVector(; + sfc = sfc_vec .* ((; T = 1),), + c = ᶜvec .* ((; center_gs_unit..., sgsʲs = (center_sgsʲ_unit,)),), + f = ᶠvec .* ((; u₃ = e³, sgsʲs = ((; u₃ = e³),)),), + ) + + test_field_matrix_solver(; + test_name = "similar solve to ClimaAtmos's dry dycore with implicit \ + acoustic waves", + alg = MatrixFields.SchurComplementSolve(@name(f)), + A = MatrixFields.FieldMatrix( + (@name(c.ρ), @name(c.ρ)) => I, + (@name(c.ρe_tot), @name(c.ρe_tot)) => I, + (@name(c.uₕ), @name(c.uₕ)) => I, + (@name(c.ρ), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρe_tot), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(f.u₃), @name(c.ρ)) => ᶠᶜmat2_u₃_scalar, + (@name(f.u₃), @name(c.ρe_tot)) => ᶠᶜmat2_u₃_scalar, + (@name(f.u₃), @name(f.u₃)) => ᶠᶠmat3_u₃_u₃, + ), + b = b_dry_dycore, + ) + + test_field_matrix_solver(; + test_name = "similar solve to ClimaAtmos's dry dycore with implicit \ + acoustic waves and diffusion", + alg = MatrixFields.ApproximateFactorizationSolve( + (@name(c), @name(f)), + (@name(f), @name(c)), + (@name(f), @name(f)); + alg₁ = MatrixFields.SchurComplementSolve(@name(f)), + ), + A = MatrixFields.FieldMatrix( + (@name(c.ρ), @name(c.ρ)) => I, + (@name(c.ρe_tot), @name(c.ρe_tot)) => ᶜᶜmat3, + (@name(c.uₕ), @name(c.uₕ)) => ᶜᶜmat3, + (@name(c.ρ), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρe_tot), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(f.u₃), @name(c.ρ)) => ᶠᶜmat2_u₃_scalar, + (@name(f.u₃), @name(c.ρe_tot)) => ᶠᶜmat2_u₃_scalar, + (@name(f.u₃), @name(f.u₃)) => ᶠᶠmat3_u₃_u₃, + ), + b = b_dry_dycore, + ignore_approximation_error = true, + ) + + test_field_matrix_solver(; + test_name = "similar solve to ClimaAtmos's moist dycore + diagnostic \ + EDMF with implicit acoustic waves and SGS fluxes", + alg = MatrixFields.ApproximateFactorizationSolve( + (@name(c), @name(f)), + (@name(f), @name(c)), + (@name(f), @name(f)); + alg₁ = MatrixFields.SchurComplementSolve(@name(f)), + ), + A = MatrixFields.FieldMatrix( + (@name(c.ρ), @name(c.ρ)) => I, + (@name(c.ρe_tot), @name(c.ρe_tot)) => ᶜᶜmat3, + (@name(c.ρatke), @name(c.ρatke)) => ᶜᶜmat3, + (@name(c.ρχ), @name(c.ρχ)) => ᶜᶜmat3, + (@name(c.uₕ), @name(c.uₕ)) => ᶜᶜmat3, + (@name(c.ρ), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρe_tot), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρatke), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρχ), @name(f.u₃)) => ᶜᶠmat2_ρχ_u₃, + (@name(f.u₃), @name(c.ρ)) => ᶠᶜmat2_u₃_scalar, + (@name(f.u₃), @name(c.ρe_tot)) => ᶠᶜmat2_u₃_scalar, + (@name(f.u₃), @name(f.u₃)) => ᶠᶠmat3_u₃_u₃, + ), + b = b_moist_dycore_diagnostic_edmf, + ignore_approximation_error = true, + ) + + # TODO: This unit test is currently broken. + # test_field_matrix_solver(; + # test_name = "similar solve to ClimaAtmos's moist dycore + prognostic \ + # EDMF + prognostic surface temperature with implicit \ + # acoustic waves and SGS fluxes", + # alg = MatrixFields.BlockLowerTriangularSolve( + # @name(c.sgsʲs), + # @name(f.sgsʲs); + # alg₁ = MatrixFields.SchurComplementSolve(@name(f)), + # alg₂ = MatrixFields.ApproximateFactorizationSolve( + # (@name(c), @name(f)), + # (@name(f), @name(c)), + # (@name(f), @name(f)); + # alg₁ = MatrixFields.SchurComplementSolve(@name(f)), + # ), + # ), + # A = MatrixFields.FieldMatrix( + # # GS-GS blocks: + # (@name(sfc), @name(sfc)) => I, + # (@name(c.ρ), @name(c.ρ)) => I, + # (@name(c.ρe_tot), @name(c.ρe_tot)) => ᶜᶜmat3, + # (@name(c.ρatke), @name(c.ρatke)) => ᶜᶜmat3, + # (@name(c.ρχ), @name(c.ρχ)) => ᶜᶜmat3, + # (@name(c.uₕ), @name(c.uₕ)) => ᶜᶜmat3, + # (@name(c.ρ), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + # (@name(c.ρe_tot), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + # (@name(c.ρatke), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + # (@name(c.ρχ), @name(f.u₃)) => ᶜᶠmat2_ρχ_u₃, + # (@name(f.u₃), @name(c.ρ)) => ᶠᶜmat2_u₃_scalar, + # (@name(f.u₃), @name(c.ρe_tot)) => ᶠᶜmat2_u₃_scalar, + # (@name(f.u₃), @name(f.u₃)) => ᶠᶠmat3_u₃_u₃, + # # GS-SGS blocks: + # (@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρae_tot)) => ᶜᶜmat3, + # (@name(c.ρχ.ρq_tot), @name(c.sgsʲs.:(1).ρaχ.ρaq_tot)) => ᶜᶜmat3, + # (@name(c.ρχ.ρq_liq), @name(c.sgsʲs.:(1).ρaχ.ρaq_liq)) => ᶜᶜmat3, + # (@name(c.ρχ.ρq_ice), @name(c.sgsʲs.:(1).ρaχ.ρaq_ice)) => ᶜᶜmat3, + # (@name(c.ρχ.ρq_rai), @name(c.sgsʲs.:(1).ρaχ.ρaq_rai)) => ᶜᶜmat3, + # (@name(c.ρχ.ρq_sno), @name(c.sgsʲs.:(1).ρaχ.ρaq_sno)) => ᶜᶜmat3, + # (@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3, + # (@name(c.ρatke), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3, + # (@name(c.ρχ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_ρχ_scalar, + # (@name(c.uₕ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_uₕ_scalar, + # (@name(c.ρe_tot), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_scalar_u₃, + # (@name(c.ρatke), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_scalar_u₃, + # (@name(c.ρχ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_ρχ_u₃, + # (@name(c.uₕ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_uₕ_u₃, + # (@name(f.u₃), @name(c.sgsʲs.:(1).ρa)) => ᶠᶜmat2_u₃_scalar, + # (@name(f.u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠmat3_u₃_u₃, + # # SGS-SGS blocks: + # (@name(c.sgsʲs.:(1).ρa), @name(c.sgsʲs.:(1).ρa)) => I, + # (@name(c.sgsʲs.:(1).ρae_tot), @name(c.sgsʲs.:(1).ρae_tot)) => I, + # (@name(c.sgsʲs.:(1).ρaχ), @name(c.sgsʲs.:(1).ρaχ)) => I, + # (@name(c.sgsʲs.:(1).ρa), @name(f.sgsʲs.:(1).u₃)) => + # ᶜᶠmat2_scalar_u₃, + # (@name(c.sgsʲs.:(1).ρae_tot), @name(f.sgsʲs.:(1).u₃)) => + # ᶜᶠmat2_scalar_u₃, + # (@name(c.sgsʲs.:(1).ρaχ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_ρaχ_u₃, + # (@name(f.sgsʲs.:(1).u₃), @name(c.sgsʲs.:(1).ρa)) => + # ᶠᶜmat2_u₃_scalar, + # (@name(f.sgsʲs.:(1).u₃), @name(c.sgsʲs.:(1).ρae_tot)) => + # ᶠᶜmat2_u₃_scalar, + # (@name(f.sgsʲs.:(1).u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠmat3_u₃_u₃, + # ), + # b = b_moist_dycore_prognostic_edmf_prognostic_surface, + # skip_correctness_test = true, + # ) +end diff --git a/test/MatrixFields/field_names.jl b/test/MatrixFields/field_names.jl new file mode 100644 index 0000000000..facfb16b02 --- /dev/null +++ b/test/MatrixFields/field_names.jl @@ -0,0 +1,591 @@ +include("matrix_field_test_utils.jl") + +struct Foo{T} + _value::T +end +Base.propertynames(::Foo) = (:value,) +Base.getproperty(foo::Foo, s::Symbol) = + s == :value ? getfield(foo, :_value) : nothing + +const x = (; foo = Foo(0), a = (; b = 1, c = ((; d = 2), (;), ((), nothing)))) + +@testset "FieldName Unit Tests" begin + @test_all @name() == MatrixFields.FieldName() + @test_all @name(a.c.:(1).d) == MatrixFields.FieldName(:a, :c, 1, :d) + @test_all @name(a.c.:(3).:(1)) == MatrixFields.FieldName(:a, :c, 3, 1) + + @test_throws "not a valid property name" @macroexpand @name("a") + @test_throws "not a valid property name" @macroexpand @name([a]) + @test_throws "not a valid property name" @macroexpand @name((a.c.:(3)):(1)) + @test_throws "not a valid property name" @macroexpand @name(a.c.:(3).(1)) + + @test string(@name()) == "@name()" + @test string(@name(a.c.:(1).d)) == "@name(a.c.:(1).d)" + @test string(@name(a.c.:(3).:(1))) == "@name(a.c.:(3).:(1))" + + @test_all MatrixFields.has_field(x, @name()) + @test_all MatrixFields.has_field(x, @name(foo.value)) + @test_all MatrixFields.has_field(x, @name(a.b)) + @test_all MatrixFields.has_field(x, @name(a.c.:(1).d)) + @test_all !MatrixFields.has_field(x, @name(foo.invalid_name)) + + @test_all MatrixFields.get_field(x, @name()) == x + @test_all MatrixFields.get_field(x, @name(foo.value)) == 0 + @test_all MatrixFields.get_field(x, @name(a.b)) == 1 + @test_all MatrixFields.get_field(x, @name(a.c.:(1).d)) == 2 + + @test_all MatrixFields.broadcasted_has_field(typeof(x), @name()) + @test_all MatrixFields.broadcasted_has_field(typeof(x), @name(foo._value)) + @test_all MatrixFields.broadcasted_has_field(typeof(x), @name(a.b)) + @test_all MatrixFields.broadcasted_has_field(typeof(x), @name(a.c.:(1).d)) + @test_all !MatrixFields.broadcasted_has_field( + typeof(x), + @name(foo.invalid_name), + ) + + @test_all MatrixFields.broadcasted_get_field(x, @name()) == x + @test_all MatrixFields.broadcasted_get_field(x, @name(foo._value)) == 0 + @test_all MatrixFields.broadcasted_get_field(x, @name(a.b)) == 1 + @test_all MatrixFields.broadcasted_get_field(x, @name(a.c.:(1).d)) == 2 + + @test_all MatrixFields.is_child_name(@name(a.c.:(1).d), @name(a)) + @test_all !MatrixFields.is_child_name(@name(a.c.:(1).d), @name(foo)) + + @test_all MatrixFields.names_are_overlapping(@name(a), @name(a.c.:(1).d)) + @test_all MatrixFields.names_are_overlapping(@name(a.c.:(1).d), @name(a)) + @test_all !MatrixFields.names_are_overlapping(@name(foo), @name(a.c.:(1).d)) + @test_all !MatrixFields.names_are_overlapping(@name(a.c.:(1).d), @name(foo)) + + @test_all MatrixFields.extract_internal_name(@name(a.c.:(1).d), @name(a)) == + @name(c.:(1).d) + @test_throws "is not a child name" MatrixFields.extract_internal_name( + @name(a.c.:(1).d), + @name(foo), + ) + + @test_all MatrixFields.append_internal_name(@name(a), @name(c.:(1).d)) == + @name(a.c.:(1).d) + + @test_all MatrixFields.top_level_names(x) == (@name(foo), @name(a)) + @test_all MatrixFields.top_level_names(x.foo) == (@name(value),) + @test_all MatrixFields.top_level_names(x.a) == (@name(b), @name(c)) + @test_all MatrixFields.top_level_names(x.a.c) == + (@name(1), @name(2), @name(3)) +end + +@testset "FieldNameTree Unit Tests" begin + name_tree = MatrixFields.FieldNameTree(x) + + @test_all MatrixFields.FieldNameTree(x) == name_tree + + @test_all MatrixFields.is_valid_name(@name(), name_tree) + @test_all MatrixFields.is_valid_name(@name(foo.value), name_tree) + @test_all MatrixFields.is_valid_name(@name(a.b), name_tree) + @test_all MatrixFields.is_valid_name(@name(a.c.:(1).d), name_tree) + @test_all !MatrixFields.is_valid_name(@name(foo.invalid_name), name_tree) + + @test_all MatrixFields.child_names(@name(), name_tree) == + (@name(foo), @name(a)) + @test_all MatrixFields.child_names(@name(foo), name_tree) == + (@name(foo.value),) + @test_all MatrixFields.child_names(@name(a), name_tree) == + (@name(a.b), @name(a.c)) + @test_all MatrixFields.child_names(@name(a.c), name_tree) == + (@name(a.c.:(1)), @name(a.c.:(2)), @name(a.c.:(3))) + @test_throws "does not contain any child names" MatrixFields.child_names( + @name(a.c.:(2)), + name_tree, + ) + @test_throws "does not contain the name" MatrixFields.child_names( + @name(foo.invalid_name), + name_tree, + ) +end + +@testset "FieldNameSet Unit Tests" begin + name_tree = MatrixFields.FieldNameTree(x) + vector_keys(names...) = MatrixFields.FieldVectorKeys(names, name_tree) + matrix_keys(name_pairs...) = + MatrixFields.FieldMatrixKeys(name_pairs, name_tree) + + vector_keys_no_tree(names...) = MatrixFields.FieldVectorKeys(names) + matrix_keys_no_tree(name_pairs...) = + MatrixFields.FieldMatrixKeys(name_pairs) + + @testset "FieldNameSet Construction" begin + @test_throws "Invalid FieldNameSet value" vector_keys( + @name(foo.invalid_name), + ) + @test_throws "Invalid FieldNameSet value" matrix_keys(( + @name(foo.invalid_name), + @name(a.c), + ),) + + for constructor in (vector_keys, vector_keys_no_tree) + @test_throws "Duplicate FieldNameSet values" constructor( + @name(foo), + @name(foo), + ) + @test_throws "Overlapping FieldNameSet values" constructor( + @name(foo), + @name(foo.value), + ) + end + for constructor in (matrix_keys, matrix_keys_no_tree) + @test_throws "Duplicate FieldNameSet values" constructor( + (@name(foo.value), @name(a.c)), + (@name(foo.value), @name(a.c)), + ) + @test_throws "Overlapping FieldNameSet values" constructor( + (@name(foo), @name(a.c)), + (@name(foo.value), @name(a.c)), + ) + end + end + + @testset "FieldNameSet Iteration" begin + v_set1 = vector_keys(@name(foo), @name(a.c)) + v_set1_no_tree = vector_keys_no_tree(@name(foo), @name(a.c)) + m_set1 = matrix_keys((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) + m_set1_no_tree = matrix_keys_no_tree( + (@name(foo), @name(a.c)), + (@name(a.b), @name(foo)), + ) + + @test_all map(name -> (name, name), v_set1) == + ((@name(foo), @name(foo)), (@name(a.c), @name(a.c))) + @test_all map(name_pair -> name_pair[1], m_set1) == + (@name(foo), @name(a.b)) + + @test_all isnothing(foreach(name -> (name, name), v_set1)) + @test_all isnothing(foreach(name_pair -> name_pair[1], m_set1)) + + @test string(v_set1) == + "FieldVectorKeys(@name(foo), @name(a.c); )" + @test string(v_set1_no_tree) == + "FieldVectorKeys(@name(foo), @name(a.c))" + @test string(m_set1) == "FieldMatrixKeys((@name(foo), @name(a.c)), \ + (@name(a.b), @name(foo)); )" + @test string(m_set1_no_tree) == "FieldMatrixKeys((@name(foo), \ + @name(a.c)), (@name(a.b), @name(foo)))" + + for set in (v_set1, v_set1_no_tree) + @test_all @name(foo) in set + @test_all !(@name(a.b) in set) + @test_all !(@name(invalid_name) in set) + end + for set in (m_set1, m_set1_no_tree) + @test_all (@name(foo), @name(a.c)) in set + @test_all !((@name(foo), @name(a.b)) in set) + @test_all !((@name(foo), @name(invalid_name)) in set) + end + + @test_all @name(foo.value) in v_set1 + @test_all !(@name(foo.invalid_name) in v_set1) + @test_throws "FieldNameTree" @name(foo.value) in v_set1_no_tree + @test_throws "FieldNameTree" @name(foo.invalid_name) in v_set1_no_tree + + @test_all (@name(foo.value), @name(a.c)) in m_set1 + @test_all !((@name(foo.invalid_name), @name(a.c)) in m_set1) + @test_throws "FieldNameTree" (@name(foo.value), @name(a.c)) in + m_set1_no_tree + @test_throws "FieldNameTree" (@name(foo.invalid_name), @name(a.c)) in + m_set1_no_tree + end + + @testset "FieldNameSet Operations for Addition/Subtraction" begin + v_set1 = vector_keys(@name(foo), @name(a.c)) + v_set1_no_tree = vector_keys_no_tree(@name(foo), @name(a.c)) + m_set1 = matrix_keys((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) + m_set1_no_tree = matrix_keys_no_tree( + (@name(foo), @name(a.c)), + (@name(a.b), @name(foo)), + ) + + v_set2 = vector_keys(@name(foo)) + v_set2_no_tree = vector_keys_no_tree(@name(foo)) + m_set2 = matrix_keys((@name(foo), @name(a.c))) + m_set2_no_tree = matrix_keys_no_tree((@name(foo), @name(a.c))) + + v_set3 = vector_keys( + @name(foo.value), + @name(a.c.:(1)), + @name(a.c.:(2)), + @name(a.c.:(3)), + ) + v_set3_no_tree = vector_keys_no_tree( + @name(foo.value), + @name(a.c.:(1)), + @name(a.c.:(2)), + @name(a.c.:(3)), + ) + m_set3 = matrix_keys( + (@name(foo), @name(a.c.:(1))), + (@name(foo), @name(a.c.:(2))), + (@name(foo.value), @name(a.c.:(3))), + (@name(a.b), @name(foo)), + ) + m_set3_no_tree = matrix_keys_no_tree( + (@name(foo), @name(a.c.:(1))), + (@name(foo), @name(a.c.:(2))), + (@name(foo.value), @name(a.c.:(3))), + (@name(a.b), @name(foo)), + ) + m_set3_no_tree′ = matrix_keys_no_tree( + (@name(foo.value), @name(a.c.:(1))), + (@name(foo.value), @name(a.c.:(2))), + (@name(foo.value), @name(a.c.:(3))), + (@name(a.b), @name(foo)), + ) + + for (set1, set2) in ( + (v_set1, v_set2), + (v_set1, v_set2_no_tree), + (v_set1_no_tree, v_set2), + (m_set1, m_set2), + (m_set1, m_set2_no_tree), + (m_set1_no_tree, m_set2), + ) + @test_all set1 != set2 + @test_all !issubset(set1, set2) + @test_all issubset(set2, set1) + @test_all intersect(set1, set2) == set2 + @test_all union(set1, set2) == set1 + @test_all !MatrixFields.is_subset_that_covers_set(set1, set2) + @test_all !MatrixFields.is_subset_that_covers_set(set2, set1) + end + + for (set1, set2) in + ((v_set1_no_tree, v_set2_no_tree), (m_set1_no_tree, m_set2_no_tree)) + @test_all set1 != set2 + @test_all !issubset(set1, set2) + @test_all issubset(set2, set1) + @test_all intersect(set1, set2) == set2 + @test_all union(set1, set2) == set1 + @test_all !MatrixFields.is_subset_that_covers_set(set1, set2) + @test_throws "FieldNameTree" MatrixFields.is_subset_that_covers_set( + set2, + set1, + ) + end + + for (set1, set3) in ( + (v_set1, v_set3), + (v_set1, v_set3_no_tree), + (v_set1_no_tree, v_set3), + ) + @test_all set1 != set3 + @test_all !issubset(set1, set3) + @test_all issubset(set3, set1) + @test_all intersect(set1, set3) == set3 + @test_all union(set1, set3) == set3 + @test_all !MatrixFields.is_subset_that_covers_set(set1, set3) + @test_all MatrixFields.is_subset_that_covers_set(set3, set1) + end + + for (set1, set3) in ( + (m_set1, m_set3), + (m_set1, m_set3_no_tree), + (m_set1_no_tree, m_set3), + ) + @test_all set1 != set3 + @test_all !issubset(set1, set3) + @test_all issubset(set3, set1) + @test_all intersect(set1, set3) == m_set3_no_tree′ + @test_all union(set1, set3) == m_set3_no_tree′ + @test_all !MatrixFields.is_subset_that_covers_set(set1, set3) + @test_all MatrixFields.is_subset_that_covers_set(set3, set1) + end + + for (set1, set3) in + ((v_set1_no_tree, v_set3_no_tree), (m_set1_no_tree, m_set3_no_tree)) + @test_all set1 != set3 + @test_all !issubset(set1, set3) + @test_throws "FieldNameTree" issubset(set3, set1) + @test_throws "FieldNameTree" intersect(set1, set3) == set3 + @test_throws "FieldNameTree" union(set1, set3) == set3 + @test_all !MatrixFields.is_subset_that_covers_set(set1, set3) + @test_throws "FieldNameTree" MatrixFields.is_subset_that_covers_set( + set3, + set1, + ) + end + end + + @testset "FieldNameSet Operations for Matrix Multiplication" begin + # With one exception, none of the following operations require a + # FieldNameTree. + + @test_all MatrixFields.matrix_product_keys( + matrix_keys_no_tree((@name(foo), @name(a.c))), + vector_keys_no_tree(@name(a.c)), + ) == vector_keys_no_tree(@name(foo)) + @test_all MatrixFields.matrix_product_keys( + matrix_keys_no_tree((@name(foo), @name(a.c))), + matrix_keys_no_tree((@name(a.c), @name(a.b))), + ) == matrix_keys_no_tree((@name(foo), @name(a.b))) + + @test_all MatrixFields.matrix_product_keys( + matrix_keys_no_tree((@name(foo), @name(a.c.:(1)))), + vector_keys_no_tree(@name(a.c)), + ) == vector_keys_no_tree(@name(foo)) + @test_all MatrixFields.matrix_product_keys( + matrix_keys_no_tree((@name(foo), @name(a.c.:(1)))), + matrix_keys_no_tree((@name(a.c), @name(a.b))), + ) == matrix_keys_no_tree((@name(foo), @name(a.b))) + + @test_throws "extract internal column" MatrixFields.matrix_product_keys( + matrix_keys_no_tree((@name(foo), @name(a.c))), + vector_keys_no_tree(@name(a.c.:(1))), + ) + @test_throws "extract internal column" MatrixFields.matrix_product_keys( + matrix_keys_no_tree((@name(foo), @name(a.c))), + matrix_keys_no_tree((@name(a.c.:(1)), @name(a.b))), + ) + + @test_all MatrixFields.matrix_product_keys( + matrix_keys_no_tree((@name(a.c), @name(a.c))), + vector_keys_no_tree(@name(a.c.:(1))), + ) == vector_keys_no_tree(@name(a.c.:(1))) + @test_all MatrixFields.matrix_product_keys( + matrix_keys_no_tree((@name(a.c), @name(a.c))), + matrix_keys_no_tree((@name(a.c.:(1)), @name(a.b))), + ) == matrix_keys_no_tree((@name(a.c.:(1)), @name(a.b))) + + @test_all MatrixFields.matrix_product_keys( + matrix_keys_no_tree( + (@name(a.c.:(1)), @name(foo)), + (@name(a.c), @name(a.c)), + ), + vector_keys_no_tree(@name(foo), @name(a.c.:(1))), + ) == vector_keys_no_tree(@name(a.c.:(1))) + @test_all MatrixFields.matrix_product_keys( + matrix_keys_no_tree( + (@name(a.c.:(1)), @name(foo)), + (@name(a.c), @name(a.c)), + ), + matrix_keys_no_tree( + (@name(foo), @name(a.b)), + (@name(a.c.:(1)), @name(a.b)), + ), + ) == matrix_keys_no_tree((@name(a.c.:(1)), @name(a.b))) + + @test_all MatrixFields.matrix_product_keys( + matrix_keys_no_tree( + (@name(foo), @name(a.c.:(1))), + (@name(foo.value), @name(foo.value)), + ), + vector_keys(@name(foo), @name(a.c)), + ) == vector_keys_no_tree(@name(foo.value)) + @test_all MatrixFields.matrix_product_keys( + matrix_keys_no_tree( + (@name(foo), @name(a.c.:(1))), + (@name(foo.value), @name(foo.value)), + ), + matrix_keys((@name(foo), @name(a.b)), (@name(a.c), @name(a.b))), + ) == matrix_keys_no_tree((@name(foo.value), @name(a.b))) + + @test_throws "FieldNameTree" MatrixFields.matrix_product_keys( + matrix_keys_no_tree( + (@name(foo), @name(a.c.:(1))), + (@name(foo.value), @name(foo.value)), + ), + vector_keys_no_tree(@name(foo), @name(a.c)), + ) == vector_keys_no_tree(@name(foo.value)) + @test_throws "FieldNameTree" MatrixFields.matrix_product_keys( + matrix_keys_no_tree( + (@name(foo), @name(a.c.:(1))), + (@name(foo.value), @name(foo.value)), + ), + matrix_keys_no_tree( + (@name(foo), @name(a.b)), + (@name(a.c), @name(a.b)), + ), + ) == matrix_keys_no_tree(( + @name(foo.value), + @name(a.b), + )) + + @test_all MatrixFields.summand_names_for_matrix_product( + @name(foo), + matrix_keys_no_tree((@name(foo), @name(a.c))), + vector_keys_no_tree(@name(a.c)), + ) == vector_keys_no_tree(@name(a.c)) + @test_all MatrixFields.summand_names_for_matrix_product( + (@name(foo), @name(a.b)), + matrix_keys_no_tree((@name(foo), @name(a.c))), + matrix_keys_no_tree((@name(a.c), @name(a.b))), + ) == vector_keys_no_tree(@name(a.c)) + + @test_all MatrixFields.summand_names_for_matrix_product( + @name(foo), + matrix_keys_no_tree((@name(foo), @name(a.c.:(1)))), + vector_keys_no_tree(@name(a.c)), + ) == vector_keys_no_tree(@name(a.c.:(1))) + @test_all MatrixFields.summand_names_for_matrix_product( + (@name(foo), @name(a.b)), + matrix_keys_no_tree((@name(foo), @name(a.c.:(1)))), + matrix_keys_no_tree((@name(a.c), @name(a.b))), + ) == vector_keys_no_tree(@name(a.c.:(1))) + + @test_all MatrixFields.summand_names_for_matrix_product( + @name(a.c.:(1)), + matrix_keys_no_tree((@name(a.c), @name(a.c))), + vector_keys_no_tree(@name(a.c.:(1))), + ) == vector_keys_no_tree(@name(a.c.:(1))) + @test_all MatrixFields.summand_names_for_matrix_product( + (@name(a.c.:(1)), @name(a.b)), + matrix_keys_no_tree((@name(a.c), @name(a.c))), + matrix_keys_no_tree((@name(a.c.:(1)), @name(a.b))), + ) == vector_keys_no_tree(@name(a.c.:(1))) + + @test_all MatrixFields.summand_names_for_matrix_product( + @name(a.c.:(1)), + matrix_keys_no_tree( + (@name(a.c.:(1)), @name(foo)), + (@name(a.c), @name(a.c)), + ), + vector_keys_no_tree(@name(foo), @name(a.c.:(1))), + ) == vector_keys_no_tree(@name(foo), @name(a.c.:(1))) + @test_all MatrixFields.summand_names_for_matrix_product( + (@name(a.c.:(1)), @name(a.b)), + matrix_keys_no_tree( + (@name(a.c.:(1)), @name(foo)), + (@name(a.c), @name(a.c)), + ), + matrix_keys_no_tree( + (@name(foo), @name(a.b)), + (@name(a.c.:(1)), @name(a.b)), + ), + ) == vector_keys_no_tree(@name(foo), @name(a.c.:(1))) + + @test_all MatrixFields.summand_names_for_matrix_product( + @name(foo.value), + matrix_keys_no_tree( + (@name(foo), @name(a.c.:(1))), + (@name(foo.value), @name(foo.value)), + ), + vector_keys_no_tree(@name(foo), @name(a.c)), + ) == vector_keys_no_tree(@name(foo.value), @name(a.c.:(1))) + @test_all MatrixFields.summand_names_for_matrix_product( + (@name(foo.value), @name(a.b)), + matrix_keys_no_tree( + (@name(foo), @name(a.c.:(1))), + (@name(foo.value), @name(foo.value)), + ), + matrix_keys_no_tree( + (@name(foo), @name(a.b)), + (@name(a.c), @name(a.b)), + ), + ) == vector_keys_no_tree(@name(foo.value), @name(a.c.:(1))) + end + + @testset "Other FieldNameSet Operations" begin + v_set1 = vector_keys(@name(foo), @name(a.c)) + v_set1_no_tree = vector_keys_no_tree(@name(foo), @name(a.c)) + m_set1 = matrix_keys((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) + m_set1_no_tree = matrix_keys_no_tree( + (@name(foo), @name(a.c)), + (@name(a.b), @name(foo)), + ) + + v_set2 = vector_keys(@name(foo.value), @name(a.c.:(1)), @name(a.c.:(3))) + v_set2_no_tree = vector_keys_no_tree( + @name(foo.value), + @name(a.c.:(1)), + @name(a.c.:(3)) + ) + m_set2 = matrix_keys( + (@name(foo), @name(foo)), + (@name(foo), @name(a.c.:(1))), + (@name(foo.value), @name(a.c.:(3))), + (@name(a.b), @name(foo.value)), + (@name(a), @name(a.c)), + ) + m_set2_no_tree = matrix_keys_no_tree( + (@name(foo), @name(foo)), + (@name(foo), @name(a.c.:(1))), + (@name(foo.value), @name(a.c.:(3))), + (@name(a.b), @name(foo.value)), + (@name(a), @name(a.c)), + ) + + @test_all MatrixFields.set_complement(v_set2) == + vector_keys(@name(a.b), @name(a.c.:(2))) + @test_throws "FieldNameTree" MatrixFields.set_complement(v_set2_no_tree) + + @test_all MatrixFields.set_complement(m_set2) == matrix_keys( + (@name(foo.value), @name(a.b)), + (@name(foo.value), @name(a.c.:(2))), + (@name(a.c), @name(foo.value)), + (@name(a), @name(a.b)), + ) + @test_throws "FieldNameTree" MatrixFields.set_complement(m_set2_no_tree) + + for (set1, set2) in ( + (v_set1, v_set2), + (v_set1, v_set2_no_tree), + (v_set1_no_tree, v_set2), + ) + @test_all setdiff(set1, set2) == vector_keys(@name(a.c.:(2))) + end + + for (set1, set2) in ( + (m_set1, m_set2), + (m_set1, m_set2_no_tree), + (m_set1_no_tree, m_set2), + ) + @test_all setdiff(set1, set2) == + matrix_keys((@name(foo.value), @name(a.c.:(2)))) + end + + for (set1, set2) in + ((v_set1_no_tree, v_set2_no_tree), (m_set1_no_tree, m_set2_no_tree)) + @test_throws "FieldNameTree" setdiff(set1, set2) + end + + # With one exception, none of the following operations require a + # FieldNameTree. + + @test_all MatrixFields.corresponding_matrix_keys(v_set1_no_tree) == + matrix_keys( + (@name(foo), @name(foo)), + (@name(a.c), @name(a.c)), + ) + + @test_all MatrixFields.cartesian_product( + v_set1_no_tree, + v_set2_no_tree, + ) == matrix_keys( + (@name(foo), @name(foo.value)), + (@name(foo), @name(a.c.:(1))), + (@name(foo), @name(a.c.:(3))), + (@name(a.c), @name(foo.value)), + (@name(a.c), @name(a.c.:(1))), + (@name(a.c), @name(a.c.:(3))), + ) + + @test_all MatrixFields.matrix_row_keys(m_set1_no_tree) == + vector_keys(@name(foo), @name(a.b)) + + @test_all MatrixFields.matrix_row_keys(m_set2) == + vector_keys(@name(foo.value), @name(a.b), @name(a.c)) + @test_throws "FieldNameTree" MatrixFields.matrix_row_keys( + m_set2_no_tree, + ) + + @test_all MatrixFields.matrix_off_diagonal_keys(m_set2_no_tree) == + matrix_keys( + (@name(foo), @name(a.c.:(1))), + (@name(foo.value), @name(a.c.:(3))), + (@name(a.b), @name(foo.value)), + (@name(a), @name(a.c)), + ) + + @test_all MatrixFields.matrix_diagonal_keys(m_set2_no_tree) == + matrix_keys( + (@name(foo), @name(foo)), + (@name(a.c), @name(a.c)), + ) + end +end diff --git a/test/MatrixFields/matrix_field_test_utils.jl b/test/MatrixFields/matrix_field_test_utils.jl index 6c7e5a2119..e71ed5b999 100644 --- a/test/MatrixFields/matrix_field_test_utils.jl +++ b/test/MatrixFields/matrix_field_test_utils.jl @@ -34,9 +34,9 @@ macro benchmark(expression) end end -const ignore_cuda = (AnyFrameModule(CUDA),) - -is_using_cuda() = ClimaComms.device() isa ClimaComms.CUDADevice +const comms_device = ClimaComms.device() +const using_cuda = comms_device isa ClimaComms.CUDADevice +const ignore_cuda = using_cuda ? (AnyFrameModule(CUDA),) : () # Test the allocating and non-allocating versions of a field broadcast against # a reference non-allocating implementation. Ensure that they are performant, @@ -53,7 +53,7 @@ function test_field_broadcast(; test_broken_with_cuda = false, ) where {F1, F2, F3} @testset "$test_name" begin - if test_broken_with_cuda && is_using_cuda() + if test_broken_with_cuda && using_cuda @test_throws CUDA.InvalidIRError get_result() @warn "$test_name:\n\tCUDA.InvalidIRError" return @@ -94,14 +94,13 @@ function test_field_broadcast(; # the allocations they incur. @test_opt ignored_modules = ignore_cuda get_result() @test_opt ignored_modules = ignore_cuda set_result!(result) - @test is_using_cuda() || (@allocated set_result!(result)) == 0 + using_cuda || @test (@allocated set_result!(result)) == 0 if !isnothing(ref_set_result!) # Test ref_set_result! for type instabilities and allocations to # ensure that the performance comparison is fair. @test_opt ignored_modules = ignore_cuda ref_set_result!(ref_result) - @test is_using_cuda() || - (@allocated ref_set_result!(ref_result)) == 0 + using_cuda || @test (@allocated ref_set_result!(ref_result)) == 0 end end end @@ -124,7 +123,7 @@ function test_field_broadcast_against_array_reference(; test_broken_with_cuda = false, ) where {F1, F2, F3} @testset "$test_name" begin - if test_broken_with_cuda && is_using_cuda() + if test_broken_with_cuda && using_cuda @test_throws CUDA.InvalidIRError get_result() @warn "$test_name:\n\tCUDA.InvalidIRError" return @@ -177,12 +176,12 @@ function test_field_broadcast_against_array_reference(; # the allocations they incur. @test_opt ignored_modules = ignore_cuda get_result() @test_opt ignored_modules = ignore_cuda set_result!(result) - @test is_using_cuda() || (@allocated set_result!(result)) == 0 + using_cuda || @test (@allocated set_result!(result)) == 0 # Test ref_set_result! for type instabilities and allocations to ensure # that the performance comparison is fair. @test_opt ignored_modules = ignore_cuda call_ref_set_result!() - @test is_using_cuda() || (@allocated call_ref_set_result!()) == 0 + using_cuda || @test (@allocated call_ref_set_result!()) == 0 end end @@ -192,7 +191,7 @@ function test_spaces(::Type{FT}) where {FT} velem = 20 # This should be big enough to test high-bandwidth matrices. helem = npoly = 1 # These should be small enough for the tests to be fast. - comms_ctx = ClimaComms.SingletonCommsContext() + comms_ctx = ClimaComms.SingletonCommsContext(comms_device) hdomain = Domains.SphereDomain(FT(10)) hmesh = Meshes.EquiangularCubedSphere(hdomain, helem) htopology = Topologies.Topology2D(comms_ctx, hmesh) @@ -208,7 +207,7 @@ function test_spaces(::Type{FT}) where {FT} vspace = Spaces.CenterFiniteDifferenceSpace(vtopology) sfc_coord = Fields.coordinate_field(hspace) hypsography = - is_using_cuda() ? Hypsography.Flat() : + using_cuda ? Hypsography.Flat() : Hypsography.LinearAdaption( @. cosd(sfc_coord.lat) + cosd(sfc_coord.long) + 1 ) # TODO: FD operators don't currently work with hypsography on GPUs. diff --git a/test/runtests.jl b/test/runtests.jl index 43b43f52a0..fc6b304141 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -78,9 +78,11 @@ if !Sys.iswindows() @safetestset "MatrixFields - rmul_with_projection" begin @time include("MatrixFields/rmul_with_projection.jl") end @safetestset "MatrixFields - field2arrays" begin @time include("MatrixFields/field2arrays.jl") end @safetestset "MatrixFields - matrix multiplication at boundaries" begin @time include("MatrixFields/matrix_multiplication_at_boundaries.jl") end + @safetestset "MatrixFields - field names" begin @time include("MatrixFields/field_names.jl") end # now part of buildkite # @safetestset "MatrixFields - matrix field broadcasting" begin @time include("MatrixFields/matrix_field_broadcasting.jl") end # @safetestset "MatrixFields - operator matrices" begin @time include("MatrixFields/operator_matrices.jl") end + # @safetestset "MatrixFields - field matrix solvers" begin @time include("MatrixFields/field_matrix_solvers.jl") end @safetestset "Hypsography - 2d" begin @time include("Hypsography/2d.jl") end @safetestset "Hypsography - 3d sphere" begin @time include("Hypsography/3dsphere.jl") end