Skip to content

Commit

Permalink
Merge pull request #1653 from CliMA/ck/fused_single_field_solve
Browse files Browse the repository at this point in the history
Parallelize single_field_solve -> multiple_field_solve
  • Loading branch information
charleskawczynski authored Apr 5, 2024
2 parents 033e5b0 + 31e69ef commit d0729e4
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 16 deletions.
5 changes: 5 additions & 0 deletions src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ Adapt.adapt_structure(to, field::Field) = Field(
# Alias for a `Field` whose space is a `Spaces.PointSpace`.
const PointField{V, S} =
Field{V, S} where {V <: AbstractData, S <: Spaces.PointSpace}

# Alias for a `Field` whose values are stored in a `DataLayouts.DataF`,
# regardless of which space the field lives on.
# TODO: do we need to make this distinction? What about inside CUDA kernels,
# where we replace the space with a PlaceholderSpace?
const PointDataField{V, S} =
Field{V, S} where {V <: DataLayouts.DataF, S <: Spaces.AbstractSpace}

# Spectral Element Field
const SpectralElementField{V, S} = Field{
V,
Expand Down
3 changes: 3 additions & 0 deletions src/MatrixFields/MatrixFields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import BandedMatrices: BandedMatrix, band, _BandedMatrix
import RecursiveArrayTools: recursive_bottom_eltype
import KrylovKit
import ClimaComms
import Adapt

import ..Utilities: PlusHalf, half
import ..RecursiveApply:
Expand Down Expand Up @@ -86,6 +87,7 @@ const ColumnwiseBandMatrixField{V, S} = Fields.Field{
S <: Union{
Spaces.FiniteDifferenceSpace,
Spaces.ExtrudedFiniteDifferenceSpace,
Operators.PlaceholderSpace, # so that this can exist inside cuda kernels
},
}

Expand All @@ -99,6 +101,7 @@ include("field_name.jl")
include("field_name_set.jl")
include("field_name_dict.jl")
include("single_field_solver.jl")
include("multiple_field_solver.jl")
include("field_matrix_solver.jl")
include("field_matrix_iterative_solver.jl")

Expand Down
10 changes: 7 additions & 3 deletions src/MatrixFields/field_matrix_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,13 @@ function check_field_matrix_solver(::BlockDiagonalSolve, _, A, b)
end

# Solve all diagonal blocks of `A` for a `BlockDiagonalSolve`.
#
# The per-name loop over `single_field_solve!` has been replaced by a single
# call to `multiple_field_solve!`, which receives the whole `cache`, `x`, `A`
# and `b` and handles every block itself.
run_field_matrix_solver!(::BlockDiagonalSolve, cache, x, A, b) =
    multiple_field_solve!(cache, x, A, b)

# This may be helpful for debugging:
# run_field_matrix_solver!(::BlockDiagonalSolve, cache, x, A, b) =
# foreach(matrix_row_keys(keys(A))) do name
# single_field_solve!(cache[name], x[name], A[name, name], b[name])
# end

"""
BlockLowerTriangularSolve(names₁...; [alg₁], [alg₂])
Expand Down
16 changes: 16 additions & 0 deletions src/MatrixFields/field_name_dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,22 @@ function Base.show(io::IO, dict::FieldNameDict)
end
end

# Return a copy of `dict` in which every `Fields.Field` entry has had its
# space replaced by an `Operators.PlaceholderSpace`. Entries that are not
# fields are passed through untouched; the keys are unchanged.
function Operators.strip_space(dict::FieldNameDict)
    stripped = unrolled_map(values(dict)) do entry
        entry isa Fields.Field || return entry
        Fields.Field(Fields.field_values(entry), Operators.PlaceholderSpace())
    end
    return FieldNameDict(keys(dict), stripped)
end

# Adapt a `FieldNameDict` by adapting each of its entries; keys are kept as-is.
function Adapt.adapt_structure(to, dict::FieldNameDict)
    adapted = unrolled_map(values(dict)) do entry
        Adapt.adapt_structure(to, entry)
    end
    return FieldNameDict(keys(dict), adapted)
end

# The keys of a `FieldNameDict` are stored in its `keys` field.
Base.keys(dict::FieldNameDict) = dict.keys

# The values of a `FieldNameDict` are stored in its `entries` field.
Base.values(dict::FieldNameDict) = dict.entries
Expand Down
121 changes: 121 additions & 0 deletions src/MatrixFields/multiple_field_solver.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# TODO: Can different A's be different matrix styles?
# if so, how can we handle fuse/parallelize?

# Entry point: look up the first solved field and dispatch on the device of
# its axes. That field is also forwarded as the trailing `x1` argument.
function multiple_field_solve!(cache, x, A, b)
    row_names = matrix_row_keys(keys(A))
    x1 = x[first(row_names)]
    device = ClimaComms.device(axes(x1))
    return multiple_field_solve!(device, cache, x, A, b, x1)
end

# TODO: fuse/parallelize
# CPU method: solve the diagonal blocks one name at a time with
# `single_field_solve!`. `x1` is accepted only so that the signature matches
# the other device methods; it is not used here.
function multiple_field_solve!(::ClimaComms.AbstractCPUDevice, cache, x, A, b, x1)
    row_names = matrix_row_keys(keys(A))
    foreach(row_names) do name
        single_field_solve!(cache[name], x[name], A[name, name], b[name])
    end
end

# CUDA method: launch one kernel that covers every (i, j, h) index of `x1`
# for every name in the block-diagonal system.
function multiple_field_solve!(::ClimaComms.CUDADevice, cache, x, A, b, x1)
    # Launch configuration is sized from `x1` and the number of names.
    Ni, Nj, _, _, Nh = size(Fields.field_values(x1))
    names = matrix_row_keys(keys(A))
    Nnames = length(names)
    nthreads, nblocks = Topologies._configure_threadblock(Ni * Nj * Nh * Nnames)

    # Only space-stripped copies are handed to the kernel.
    sscache = Operators.strip_space(cache)
    ssx = Operators.strip_space(x)
    ssA = Operators.strip_space(A)
    ssb = Operators.strip_space(b)

    # One tuple per argument kind, each ordered like `names`.
    tups = (
        map(name -> sscache[name], names),
        map(name -> ssx[name], names),
        map(name -> ssA[name, name], names),
        map(name -> ssb[name], names),
    )
    # Space-stripped counterpart of `x1`; the kernel reads its size.
    ssx1 = first(tups[2])

    device = ClimaComms.device(x[first(names)])
    CUDA.@cuda threads = nthreads blocks = nblocks multiple_field_solve_kernel!(
        device,
        tups,
        ssx1,
        Val(Nnames),
    )
end

# Map the linear thread index of a 1D launch onto an `(i, j, h, n)` index in
# the `Ni × Nj × Nh × Nnames` index space (first index varies fastest).
#
# `blockIdx`, `threadIdx` and `blockDim` only need an `x` property; `gridDim`
# is accepted but not used. Threads whose linear index falls outside the
# index space get `(-1, -1, -1, -1)`, so callers can range-check the result.
function get_ijhn(Ni, Nj, Nh, Nnames, blockIdx, threadIdx, blockDim, gridDim)
    tidx = (blockIdx.x - 1) * blockDim.x + threadIdx.x
    (i, j, h, n) = if 1 <= tidx <= prod((Ni, Nj, Nh, Nnames))
        CartesianIndices((1:Ni, 1:Nj, 1:Nh, 1:Nnames))[tidx].I
    else
        (-1, -1, -1, -1)
    end
    return (i, j, h, n)
end

# Select the part of a diagonal block that belongs to column (i, j, h).
# A `UniformScaling` carries no per-column data, so it is returned unchanged.
function column_A(A::UniformScaling, i, j, h)
    return A
end

# Any other block is sliced with `Spaces.column`.
function column_A(A, i, j, h)
    return Spaces.column(A, i, j, h)
end

# Walk the position tuple `js` in lockstep with every tuple in `tups`. At the
# position whose value equals `i`, apply `transform` to the head of each tuple
# in `tups` and pass the results, followed by `device`, to
# `_single_field_solve!`. Always returns `nothing`.
@inline function _recurse(js::Tuple, tups::Tuple, transform, device, i::Int)
    if js[1] == i
        args = map(tup -> transform(first(tup)), tups)
        _single_field_solve!(args..., device)
    end
    remaining = map(Base.tail, tups)
    return _recurse(Base.tail(js), remaining, transform, device, i)
end

# No positions left to visit.
@inline _recurse(::Tuple{}, tups::Tuple, transform, device, i::Int) = nothing

# Last position: same check as the general method, without recursing further.
@inline function _recurse(js::Tuple{Int}, tups::Tuple, transform, device, i::Int)
    js[1] == i || return nothing
    args = map(tup -> transform(first(tup)), tups)
    _single_field_solve!(args..., device)
    return nothing
end

# CUDA kernel behind `multiple_field_solve!`.
#
# `tups` is `(caches, xs, As, bs)`, each a tuple with `Nnames` entries, and
# `x1` is only used for its size. Every thread is mapped to one
# `(i, j, h, iname)` index by `get_ijhn`; threads that fall outside the index
# space do nothing.
function multiple_field_solve_kernel!(
    device::ClimaComms.CUDADevice,
    tups,
    x1,
    ::Val{Nnames},
) where {Nnames}
    @inbounds begin
        Ni, Nj, _, _, Nh = size(Fields.field_values(x1))
        (i, j, h, iname) = get_ijhn(
            Ni,
            Nj,
            Nh,
            Nnames,
            CUDA.blockIdx(),
            CUDA.threadIdx(),
            CUDA.blockDim(),
            CUDA.gridDim(),
        )
        # `get_ijhn` returns -1 in every slot for out-of-range threads.
        if 1 <= i <= Ni && 1 <= j <= Nj && 1 <= h <= Nh && 1 <= iname <= Nnames

            # (1, 2, ..., Nnames): the positions `_recurse` compares with `iname`.
            nt = ntuple(ξ -> ξ, Val(Nnames))
            _recurse(nt, tups, ξ -> column_A(ξ, i, j, h), device, iname)
            # _recurse effectively calls
            #    _single_field_solve!(
            #        Spaces.column(caches[iname], i, j, h),
            #        Spaces.column(xs[iname], i, j, h),
            #        column_A(As[iname], i, j, h),
            #        Spaces.column(bs[iname], i, j, h),
            #        device,
            #    )
        end
    end
    return nothing
end
108 changes: 98 additions & 10 deletions src/MatrixFields/single_field_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,31 @@ single_field_solve!(cache, x, A::ColumnwiseBandMatrixField, b) =
single_field_solve!(ClimaComms.device(axes(A)), cache, x, A, b)

# CPU method: forward to `_single_field_solve!`, which takes the device as its
# first argument.
single_field_solve!(::ClimaComms.AbstractCPUDevice, cache, x, A, b) =
    _single_field_solve!(ClimaComms.device(axes(A)), cache, x, A, b)

# single_field_solve!(::ClimaComms.CUDADevice, ...) is no longer exercised,
# but it may be helpful for debugging, due to its simplicity. So, let's leave
# it here for now.
#
# Launches `single_field_solve_kernel!` with enough threads to cover every
# (i, j, h) index of `A`.
function single_field_solve!(::ClimaComms.CUDADevice, cache, x, A, b)
    device = ClimaComms.device(A)
    Ni, Nj, _, _, Nh = size(Fields.field_values(A))
    nitems = Ni * Nj * Nh
    nthreads, nblocks = Topologies._configure_threadblock(nitems)
    CUDA.@cuda always_inline = true threads = nthreads blocks = nblocks single_field_solve_kernel!(
        device,
        cache,
        x,
        A,
        b,
    )
end

function single_field_solve_kernel!(cache, x, A, b)
function single_field_solve_kernel!(device, cache, x, A, b)
idx = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(A))
if idx <= Ni * Nj * Nh
i, j, h = Topologies._get_idx((Ni, Nj, Nh), idx)
_single_field_solve!(
device,
Spaces.column(cache, i, j, h),
Spaces.column(x, i, j, h),
Spaces.column(A, i, j, h),
Expand All @@ -80,22 +87,103 @@ single_field_solve_kernel!(
b::Fields.ColumnField,
) = _single_field_solve!(cache, x, A, b)

# CPU (GPU has already called Spaces.column on arg)
# Visit every column of `axes(A)` and solve it with `_single_field_solve_col!`.
_single_field_solve!(device::ClimaComms.AbstractCPUDevice, cache, x, A, b) =
    Fields.bycolumn(axes(A)) do colidx
        _single_field_solve_col!(
            ClimaComms.device(axes(A)),
            cache[colidx],
            x[colidx],
            A[colidx],
            b[colidx],
        )
    end

# Solve a single column on the CPU.
#
# - `A::UniformScaling`: `x` is overwritten with `inv(A.λ) .* b`.
# - `A::Fields.ColumnField`: the band-matrix solver is called on the underlying
#   field values.
# - anything else is an error.
function _single_field_solve_col!(
    ::ClimaComms.AbstractCPUDevice,
    cache::Fields.ColumnField,
    x::Fields.ColumnField,
    A,
    b::Fields.ColumnField,
)
    if A isa UniformScaling
        x .= inv(A.λ) .* b
        return x
    end
    A isa Fields.ColumnField || error("uncaught case")
    cache_data = unzip_tuple_field_values(Fields.field_values(cache))
    x_data = Fields.field_values(x)
    A_data = unzip_tuple_field_values(Fields.field_values(A.entries))
    b_data = Fields.field_values(b)
    return band_matrix_solve!(eltype(A), cache_data, x_data, A_data, b_data)
end

# called by TuplesOfNTuples.jl's `inner_dispatch`:
# which requires a particular argument order:
# the device comes last here and is moved to the front for the real solve.
function _single_field_solve!(
    cache::Fields.Field,
    x::Fields.Field,
    A::Union{Fields.Field, UniformScaling},
    b::Fields.Field,
    dev::ClimaComms.CUDADevice,
)
    return _single_field_solve!(dev, cache, x, A, b)
end

# Device-last variant for CPU devices: forwards to the per-column solver.
function _single_field_solve!(
    cache::Fields.Field,
    x::Fields.Field,
    A::Union{Fields.Field, UniformScaling},
    b::Fields.Field,
    dev::ClimaComms.AbstractCPUDevice,
)
    return _single_field_solve_col!(dev, cache, x, A, b)
end

# CUDA, single column, band-matrix block: unpack the field values of each
# argument and call `band_matrix_solve!` with the matrix element type.
function _single_field_solve!(
    ::ClimaComms.CUDADevice,
    cache::Fields.ColumnField,
    x::Fields.ColumnField,
    A::Fields.ColumnField,
    b::Fields.ColumnField,
)
    band_matrix_solve!(
        eltype(A),
        unzip_tuple_field_values(Fields.field_values(cache)),
        Fields.field_values(x),
        unzip_tuple_field_values(Fields.field_values(A.entries)),
        Fields.field_values(b),
    )
end

# CUDA, single column, `UniformScaling` block: write `inv(A.λ) * b` into `x`
# one entry at a time. `cache` is not used.
function _single_field_solve!(
    ::ClimaComms.CUDADevice,
    cache::Fields.ColumnField,
    x::Fields.ColumnField,
    A::UniformScaling,
    b::Fields.ColumnField,
)
    x_data = Fields.field_values(x)
    b_data = Fields.field_values(b)
    n = length(x_data)
    @inbounds for i in 1:n
        # Same scaling as the CPU branch (`x .= inv(A.λ) .* b`).
        x_data[i] = inv(A.λ) * b_data[i]
    end
end

# CUDA, point-data fields, `UniformScaling` block: there is a single value, so
# write `inv(A.λ) * b` into `x` with zero-index access. `cache` is not used.
function _single_field_solve!(
    ::ClimaComms.CUDADevice,
    cache::Fields.PointDataField,
    x::Fields.PointDataField,
    A::UniformScaling,
    b::Fields.PointDataField,
)
    x_data = Fields.field_values(x)
    b_data = Fields.field_values(b)
    @inbounds begin
        x_data[] = inv(A.λ) * b_data[]
    end
end

# Split `data` into a tuple holding its properties, accessed by position:
# entry `i` of the result is property number `i` of `data`.
function unzip_tuple_field_values(data)
    return ntuple(Val(length(propertynames(data)))) do i
        getproperty(data, i)
    end
end
Expand Down
8 changes: 5 additions & 3 deletions test/MatrixFields/matrix_field_test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ macro benchmark(expression)
end
end

# Device used by the MatrixFields tests; printed so it shows up in test logs.
# NOTE(review): these globals are no longer `const`, presumably so the device
# can be overridden while debugging (see the commented-out line below).
# Consider restoring `const` once that is no longer needed.
comms_device = ClimaComms.device()
# comms_device = ClimaComms.CPUSingleThreaded()
@show comms_device
using_cuda = comms_device isa ClimaComms.CUDADevice
ignore_cuda = using_cuda ? (AnyFrameModule(CUDA),) : ()

# Test the allocating and non-allocating versions of a field broadcast against
# a reference non-allocating implementation. Ensure that they are performant,
Expand Down

0 comments on commit d0729e4

Please sign in to comment.