Skip to content

Commit

Permalink
Add FieldMatrix and linear solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Oct 10, 2023
1 parent 2706600 commit b7d39ea
Show file tree
Hide file tree
Showing 17 changed files with 2,750 additions and 29 deletions.
19 changes: 18 additions & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ steps:
soft_fail: true
agents:
slurm_gpus: 1
slurm_mem: 40GB
slurm_mem: 80GB

- label: "Unit: operator matrices (CPU)"
key: unit_operator_matrices_cpu
Expand All @@ -570,6 +570,23 @@ steps:
slurm_gpus: 1
slurm_mem: 40GB

- label: "Unit: field names"
key: unit_field_names
command: "julia --color=yes --check-bounds=yes --project=test test/MatrixFields/field_names.jl"

- label: "Unit: field matrix solvers (CPU)"
key: unit_field_matrix_solvers_cpu
command: "julia --color=yes --check-bounds=yes --project=test test/MatrixFields/field_matrix_solvers.jl"
agents:
slurm_mem: 40GB

- label: "Unit: field matrix solvers (GPU)"
key: unit_field_matrix_solvers_gpu
command: "julia --color=yes --project=test test/MatrixFields/field_matrix_solvers.jl"
agents:
slurm_gpus: 1
slurm_mem: 40GB

- group: "Unit: Hypsography"
steps:

Expand Down
18 changes: 18 additions & 0 deletions docs/src/matrix_fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ MultiplyColumnwiseBandMatrixField
operator_matrix
```

# Linear Solvers

```@docs
FieldMatrixSolverAlgorithm
FieldMatrixSolver
field_matrix_solve!
BlockDiagonalSolve
BlockLowerTriangularSolve
SchurComplementSolve
ApproximateFactorizationSolve
```

## Internals

```@docs
Expand All @@ -39,6 +51,12 @@ matrix_shape
column_axes
AbstractLazyOperator
replace_lazy_operator
FieldName
@name
FieldNameTree
FieldNameSet
FieldNameDict
field_vector_view
```

## Utilities
Expand Down
36 changes: 26 additions & 10 deletions src/MatrixFields/MatrixFields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,41 +24,41 @@ for them:
`BandedMatrices.jl`
- Custom printing, e.g., `matrix_field` gets displayed as a `BandedMatrix`,
specifically, as the `BandedMatrix` that corresponds to its first column
TODO: Add comments on `operator_matrix`, `FieldMatrix`, and `FieldMatrixSolver`.
"""
module MatrixFields

import CUDA: @allowscalar
import LinearAlgebra: UniformScaling, Adjoint, AdjointAbsVec
import CUDA
import LinearAlgebra: I, UniformScaling, Adjoint, AdjointAbsVec, mul!, inv
import StaticArrays: SMatrix, SVector
import BandedMatrices: BandedMatrix, band, _BandedMatrix
import ClimaComms

import ..Utilities: PlusHalf, half
import ..RecursiveApply:
rmap, rmaptype, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv
import ..RecursiveApply: , ,
import ..DataLayouts: AbstractData
import ..Geometry
import ..Spaces
import ..Fields
import ..Operators

export
export DiagonalMatrixRow,
BidiagonalMatrixRow,
TridiagonalMatrixRow,
QuaddiagonalMatrixRow,
PentadiagonalMatrixRow
export FieldVectorKeys, FieldVectorView, FieldVectorViewBroadcasted
export FieldMatrixKeys, FieldMatrix, FieldMatrixBroadcasted
export @name, , FieldMatrixSolver, field_matrix_solve!

# Types that are teated as single values when using matrix fields.
const SingleValue =
Union{Number, Geometry.AxisTensor, Geometry.AdjointAxisTensor}

include("band_matrix_row.jl")
include("rmul_with_projection.jl")
include("matrix_shape.jl")
include("matrix_multiplication.jl")
include("lazy_operators.jl")
include("operator_matrices.jl")
include("field2arrays.jl")

const ColumnwiseBandMatrixField{V, S} = Fields.Field{
V,
Expand All @@ -71,6 +71,19 @@ const ColumnwiseBandMatrixField{V, S} = Fields.Field{
},
}

include("rmul_with_projection.jl")
include("matrix_shape.jl")
include("matrix_multiplication.jl")
include("lazy_operators.jl")
include("operator_matrices.jl")
include("field2arrays.jl")
include("unrolled_functions.jl")
include("field_name.jl")
include("field_name_set.jl")
include("field_name_dict.jl")
include("field_matrix_solver.jl")
include("single_field_solver.jl")

function Base.show(io::IO, field::ColumnwiseBandMatrixField)
print(io, eltype(field), "-valued Field")
if eltype(eltype(field)) <: Number
Expand All @@ -82,7 +95,10 @@ function Base.show(io::IO, field::ColumnwiseBandMatrixField)
end
column_field = Fields.column(field, 1, 1, 1)
io = IOContext(io, :compact => true, :limit => true)
@allowscalar Base.print_array(io, column_field2array_view(column_field))
CUDA.@allowscalar Base.print_array(
io,
column_field2array_view(column_field),
)
else
# When a BandedMatrix with non-number entries is printed, it currently
# either prints in an illegible format (e.g., if it has AxisTensor or
Expand Down
6 changes: 6 additions & 0 deletions src/MatrixFields/band_matrix_row.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,9 @@ Base.:*(value::SingleValue, row::BandMatrixRow) =

Base.:/(row::BandMatrixRow, value::Number) =
map(entry -> rdiv(entry, value), row)

inv(row::DiagonalMatrixRow) = DiagonalMatrixRow(inv(row[0]))
inv(::BandMatrixRow{ld, bw}) where {ld, bw} = error(
"The inverse of a matrix with $bw diagonals is (usually) a dense matrix, \
so it cannot be represented using BandMatrixRows",
)
4 changes: 2 additions & 2 deletions src/MatrixFields/field2arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ function column_field2array(field::Fields.FiniteDifferenceField)
last_row = matrix_d < n_cols - n_rows ? n_rows : n_cols - matrix_d

diagonal_data_view = view(diagonal_data, first_row:last_row)
@allowscalar copyto!(matrix_diagonal, diagonal_data_view)
CUDA.@allowscalar copyto!(matrix_diagonal, diagonal_data_view)
end
return matrix
else # field represents a vector
return @allowscalar Array(column_field2array_view(field))
return CUDA.@allowscalar Array(column_field2array_view(field))
end
end

Expand Down
Loading

0 comments on commit b7d39ea

Please sign in to comment.