Skip to content

Commit

Permalink
Introduce post-operation callback
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jan 9, 2025
1 parent b099c3e commit c20fe61
Show file tree
Hide file tree
Showing 33 changed files with 286 additions and 42 deletions.
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ withenv("GKSwstype" => "nul") do
tutorial in TUTORIALS
],
"Examples" => "examples.md",
"Debugging" => "debugging.md",
"Libraries" => [
joinpath("lib", "ClimaCorePlots.md"),
joinpath("lib", "ClimaCoreMakie.md"),
Expand Down
9 changes: 9 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,12 @@ InputOutput.defaultname
Remapping.interpolate_array
Remapping.interpolate
```

## DebugOnly

```@docs
DebugOnly
DebugOnly.call_post_op_callback
DebugOnly.post_op_callback
DebugOnly.example_debug_post_op_callback
```
49 changes: 49 additions & 0 deletions docs/src/debugging.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Debugging

One of the most challenging tasks that users have is: debug a large simulation
that is breaking, e.g., yielding `NaN`s somewhere. This is especially complex
for large models with many terms and implicit time-stepping with all the bells
and whistles that the CliMA ecosystem offers.

ClimaCore has a module, [`ClimaCore.DebugOnly`](@ref), which contains tools for
debugging simulations for these complicated situations.

Because so much data (for example, the solution state, and many cached fields)
is typically contained in ClimaCore data structures, we offer a hook to inspect
this data after any operation that ClimaCore performs.

## Example

```@example
import ClimaCore
using ClimaCore: DataLayouts
ClimaCore.DebugOnly.call_post_op_callback() = true
function ClimaCore.DebugOnly.post_op_callback(result, args...; kwargs...)
if any(isnan, parent(data))
println("NaNs found!")
end
end
FT = Float64;
data = DataLayouts.VIJFH{FT}(Array{FT}, zeros; Nv=5, Nij=2, Nh=2)
@. data = NaN
```

Note that, due to dispatch, `post_op_callback` will likely need a very general
method signature, and using `post_op_callback
(result::DataLayouts.VIJFH, args...; kwargs...)` above fails (on the CPU),
because `post_op_callback` ends up getting called multiple times with different
datalayouts.

!!! warn

While this debugging tool may be helpful, it's not bullet proof. NaNs can
infiltrate user data any time internals are used. For example `parent
(data) .= NaN` will not be caught by ClimaCore.DebugOnly, and errors can be
observed later than expected.

!!! note

This method is called in many places, so this is a performance-critical code
path and expensive operations performed in `post_op_callback` may
significantly slow down your code.
1 change: 1 addition & 0 deletions ext/ClimaCoreCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import CUDA
using CUDA
using CUDA: threadIdx, blockIdx, blockDim
import StaticArrays: SVector, SMatrix, SArray
import ClimaCore.DebugOnly: call_post_op_callback, post_op_callback
import ClimaCore.DataLayouts: mapreduce_cuda
import ClimaCore.DataLayouts: ToCUDA
import ClimaCore.DataLayouts: slab, column
Expand Down
18 changes: 12 additions & 6 deletions ext/cuda/data_layouts_copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ if VERSION ≥ v"1.11.0-beta"
# special-case fixes for https://github.com/JuliaLang/julia/issues/28126
# (including the GPU-variant related issue resolution efforts:
# JuliaGPU/GPUArrays.jl#454, JuliaGPU/GPUArrays.jl#464).
function Base.copyto!(dest::AbstractData, bc, ::ToCUDA)
function Base.copyto!(dest::AbstractData, bc, to::ToCUDA)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
if Nv > 0 && Nh > 0
Expand All @@ -39,10 +39,11 @@ if VERSION ≥ v"1.11.0-beta"
blocks_s = p.blocks,
)
end
call_post_op_callback() && post_op_callback(dest, dest, bc, to)
return dest
end
else
function Base.copyto!(dest::AbstractData, bc, ::ToCUDA)
function Base.copyto!(dest::AbstractData, bc, to::ToCUDA)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
if Nv > 0 && Nh > 0
Expand Down Expand Up @@ -74,6 +75,7 @@ else
)
end
end
call_post_op_callback() && post_op_callback(dest, dest, bc, to)
return dest
end
end
Expand All @@ -85,7 +87,7 @@ end
function Base.copyto!(
dest::AbstractData,
bc::Base.Broadcast.Broadcasted{Style},
::ToCUDA,
to::ToCUDA,
) where {
Style <:
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
Expand All @@ -95,13 +97,14 @@ function Base.copyto!(
)
@inbounds bc0 = bc[]
fill!(dest, bc0)
call_post_op_callback() && post_op_callback(dest, dest, bc, to)
end

# For field-vector operations
function DataLayouts.copyto_per_field!(
array::AbstractArray,
bc::Union{AbstractArray, Base.Broadcast.Broadcasted},
::ToCUDA,
to::ToCUDA,
)
bc′ = DataLayouts.to_non_extruded_broadcasted(bc)
# All field variables are treated separately, so
Expand All @@ -119,6 +122,7 @@ function DataLayouts.copyto_per_field!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(array, array, bc, to)
return array
end
function copyto_per_field_kernel!(array, bc, N)
Expand All @@ -133,7 +137,7 @@ end
function DataLayouts.copyto_per_field_scalar!(
array::AbstractArray,
bc::Base.Broadcast.Broadcasted{Style},
::ToCUDA,
to::ToCUDA,
) where {
Style <:
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
Expand All @@ -154,12 +158,13 @@ function DataLayouts.copyto_per_field_scalar!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(array, array, bc, to)
return array
end
function DataLayouts.copyto_per_field_scalar!(
array::AbstractArray,
bc::Real,
::ToCUDA,
to::ToCUDA,
)
bc′ = DataLayouts.to_non_extruded_broadcasted(bc)
# All field variables are treated separately, so
Expand All @@ -177,6 +182,7 @@ function DataLayouts.copyto_per_field_scalar!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(array, array, bc, to)
return array
end
function copyto_per_field_kernel_0D!(array, bc, N)
Expand Down
3 changes: 2 additions & 1 deletion ext/cuda/data_layouts_fill.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function knl_fill_linear!(dest, val, us)
return nothing
end

function Base.fill!(dest::AbstractData, bc, ::ToCUDA)
function Base.fill!(dest::AbstractData, bc, to::ToCUDA)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
args = (dest, bc, us)
Expand All @@ -41,5 +41,6 @@ function Base.fill!(dest::AbstractData, bc, ::ToCUDA)
)
end
end
call_post_op_callback() && post_op_callback(dest, dest, bc, to)
return dest
end
11 changes: 9 additions & 2 deletions ext/cuda/data_layouts_mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ function mapreduce_cuda(
)
pdata = parent(data)
S = eltype(data)
return DataLayouts.DataF{S}(Array(Array(f(pdata))[1, :]))
data_out = DataLayouts.DataF{S}(Array(Array(f(pdata))[1, :]))
call_post_op_callback() &&
post_op_callback(data_out, f, op, data; weighted_jacobian, opargs...)
return data_out
end

function mapreduce_cuda(
Expand Down Expand Up @@ -101,7 +104,11 @@ function mapreduce_cuda(
Val(shmemsize),
)
end
return DataLayouts.DataF{S}(Array(Array(reduce_cuda)[1, :]))
data_out = DataLayouts.DataF{S}(Array(Array(reduce_cuda)[1, :]))

call_post_op_callback() &&
post_op_callback(data_out, f, op, data; weighted_jacobian, opargs...)
return data_out
end

function mapreduce_cuda_kernel!(
Expand Down
16 changes: 11 additions & 5 deletions ext/cuda/fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,46 +13,52 @@ end

function Base.sum(
field::Union{Field, Base.Broadcast.Broadcasted{<:FieldStyle}},
::ClimaComms.CUDADevice,
dev::ClimaComms.CUDADevice,
)
context = ClimaComms.context(axes(field))
localsum = mapreduce_cuda(identity, +, field, weighting = true)
ClimaComms.allreduce!(context, parent(localsum), +)
call_post_op_callback() && post_op_callback(localsum[], field, dev)
return localsum[]
end

function Base.sum(fn, field::Field, ::ClimaComms.CUDADevice)
function Base.sum(fn, field::Field, dev::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localsum = mapreduce_cuda(fn, +, field, weighting = true)
ClimaComms.allreduce!(context, parent(localsum), +)
call_post_op_callback() && post_op_callback(localsum[], fn, field, dev)
return localsum[]
end

function Base.maximum(fn, field::Field, ::ClimaComms.CUDADevice)
function Base.maximum(fn, field::Field, dev::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localmax = mapreduce_cuda(fn, max, field)
ClimaComms.allreduce!(context, parent(localmax), max)
call_post_op_callback() && post_op_callback(localmax[], fn, field, dev)
return localmax[]
end

function Base.maximum(field::Field, ::ClimaComms.CUDADevice)
function Base.maximum(field::Field, dev::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localmax = mapreduce_cuda(identity, max, field)
ClimaComms.allreduce!(context, parent(localmax), max)
call_post_op_callback() && post_op_callback(localmax[], fn, field, dev)
return localmax[]
end

function Base.minimum(fn, field::Field, ::ClimaComms.CUDADevice)
function Base.minimum(fn, field::Field, dev::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localmin = mapreduce_cuda(fn, min, field)
ClimaComms.allreduce!(context, parent(localmin), min)
call_post_op_callback() && post_op_callback(localmin[], fn, field, dev)
return localmin[]
end

function Base.minimum(field::Field, ::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localmin = mapreduce_cuda(identity, min, field)
ClimaComms.allreduce!(context, parent(localmin), min)
call_post_op_callback() && post_op_callback(localmin[], fn, field, dev)
return localmin[]
end

Expand Down
11 changes: 8 additions & 3 deletions ext/cuda/limiters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function compute_element_bounds!(
limiter::QuasiMonotoneLimiter,
ρq,
ρ,
::ClimaComms.CUDADevice,
dev::ClimaComms.CUDADevice,
)
ρ_values = Fields.field_values(Operators.strip_space(ρ, axes(ρ)))
ρq_values = Fields.field_values(Operators.strip_space(ρq, axes(ρq)))
Expand All @@ -33,6 +33,8 @@ function compute_element_bounds!(
threads_s = nthreads,
blocks_s = nblocks,
)
call_post_op_callback() &&
post_op_callback(limiter.q_bounds, limiter, ρq, ρ, dev)
return nothing
end

Expand Down Expand Up @@ -70,7 +72,7 @@ end
function compute_neighbor_bounds_local!(
limiter::QuasiMonotoneLimiter,
ρ,
::ClimaComms.CUDADevice,
dev::ClimaComms.CUDADevice,
)
topology = Spaces.topology(axes(ρ))
us = DataLayouts.UniversalSize(Fields.field_values(ρ))
Expand All @@ -88,6 +90,8 @@ function compute_neighbor_bounds_local!(
threads_s = nthreads,
blocks_s = nblocks,
)
call_post_op_callback() &&
post_op_callback(limiter.q_bounds, limiter, ρ, dev)
end

function compute_neighbor_bounds_local_kernel!(
Expand Down Expand Up @@ -123,7 +127,7 @@ function apply_limiter!(
ρq::Fields.Field,
ρ::Fields.Field,
limiter::QuasiMonotoneLimiter,
::ClimaComms.CUDADevice,
dev::ClimaComms.CUDADevice,
)
ρq_data = Fields.field_values(ρq)
us = DataLayouts.UniversalSize(ρq_data)
Expand All @@ -147,6 +151,7 @@ function apply_limiter!(
threads_s = nthreads,
blocks_s = nblocks,
)
call_post_op_callback() && post_op_callback(ρq, ρq, ρ, limiter, dev)
return nothing
end

Expand Down
3 changes: 2 additions & 1 deletion ext/cuda/matrix_fields_multiple_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import ClimaCore.Utilities.UnrolledFunctions: unrolled_map
is_CuArray_type(::Type{T}) where {T <: CUDA.CuArray} = true

NVTX.@annotate function multiple_field_solve!(
::ClimaComms.CUDADevice,
dev::ClimaComms.CUDADevice,
cache,
x,
A,
Expand Down Expand Up @@ -48,6 +48,7 @@ NVTX.@annotate function multiple_field_solve!(
blocks_s = p.blocks,
always_inline = true,
)
call_post_op_callback() && post_op_callback(x, dev, cache, x, A, b, x1)
end

Base.@propagate_inbounds column_A(A::UniformScaling, i, j, h) = A
Expand Down
1 change: 1 addition & 0 deletions ext/cuda/matrix_fields_single_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(x, device, cache, x, A, b)
end

function single_field_solve_kernel!(device, cache, x, A, b, us)
Expand Down
1 change: 1 addition & 0 deletions ext/cuda/operators_finite_difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ function Base.copyto!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(out, out, bc)
return out
end
import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh
Expand Down
7 changes: 6 additions & 1 deletion ext/cuda/operators_integral.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import ClimaComms
using CUDA: @cuda

function column_reduce_device!(
::ClimaComms.CUDADevice,
dev::ClimaComms.CUDADevice,
f::F,
transform::T,
output,
Expand Down Expand Up @@ -40,6 +40,11 @@ function column_reduce_device!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(
output,
(dev, f, transform, output, input, init, space),
(;),
)
end

function column_accumulate_device!(
Expand Down
1 change: 1 addition & 0 deletions ext/cuda/operators_spectral_element.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ function Base.copyto!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(out, out, sbc)
return out
end

Expand Down
Loading

0 comments on commit c20fe61

Please sign in to comment.