From c20fe61f915cb0893f5c3fb60811114e7d2a7ccd Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Fri, 3 Jan 2025 16:07:45 -0500 Subject: [PATCH] Introduce post-operation callback --- docs/make.jl | 1 + docs/src/api.md | 9 +++ docs/src/debugging.md | 49 +++++++++++++++ ext/ClimaCoreCUDAExt.jl | 1 + ext/cuda/data_layouts_copyto.jl | 18 ++++-- ext/cuda/data_layouts_fill.jl | 3 +- ext/cuda/data_layouts_mapreduce.jl | 11 +++- ext/cuda/fields.jl | 16 +++-- ext/cuda/limiters.jl | 11 +++- .../matrix_fields_multiple_field_solve.jl | 3 +- ext/cuda/matrix_fields_single_field_solve.jl | 1 + ext/cuda/operators_finite_difference.jl | 1 + ext/cuda/operators_integral.jl | 7 ++- ext/cuda/operators_spectral_element.jl | 1 + ext/cuda/operators_thomas_algorithm.jl | 3 +- ext/cuda/remapping_distributed.jl | 13 +++- ext/cuda/remapping_interpolate_array.jl | 8 +++ ext/cuda/topologies_dss.jl | 62 ++++++++++++++++--- src/ClimaCore.jl | 1 + src/DataLayouts/DataLayouts.jl | 1 + src/DataLayouts/copyto.jl | 9 ++- src/DataLayouts/fill.jl | 2 + src/DebugOnly/DebugOnly.jl | 49 +++++++++++++++ src/Fields/Fields.jl | 1 + src/Fields/mapreduce.jl | 22 ++++--- src/Limiters/Limiters.jl | 1 + src/Limiters/quasimonotone.jl | 15 ++++- src/Operators/Operators.jl | 1 + src/Operators/finitedifference.jl | 4 ++ src/Operators/spectralelement.jl | 1 + src/Operators/thomas_algorithm.jl | 1 + src/Spaces/Spaces.jl | 1 + src/Spaces/dss.jl | 1 + 33 files changed, 286 insertions(+), 42 deletions(-) create mode 100644 docs/src/debugging.md create mode 100644 src/DebugOnly/DebugOnly.jl diff --git a/docs/make.jl b/docs/make.jl index 66bed38583..44417a2e18 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -86,6 +86,7 @@ withenv("GKSwstype" => "nul") do tutorial in TUTORIALS ], "Examples" => "examples.md", + "Debugging" => "debugging.md", "Libraries" => [ joinpath("lib", "ClimaCorePlots.md"), joinpath("lib", "ClimaCoreMakie.md"), diff --git a/docs/src/api.md b/docs/src/api.md index 1c93d0dd3b..c20022a44a 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -401,3 +401,12 @@ InputOutput.defaultname Remapping.interpolate_array Remapping.interpolate ``` + +## DebugOnly + +```@docs +DebugOnly +DebugOnly.call_post_op_callback +DebugOnly.post_op_callback +DebugOnly.example_debug_post_op_callback +``` diff --git a/docs/src/debugging.md b/docs/src/debugging.md new file mode 100644 index 0000000000..dd571aceac --- /dev/null +++ b/docs/src/debugging.md @@ -0,0 +1,49 @@ +# Debugging + +One of the most challenging tasks that users have is: debug a large simulation +that is breaking, e.g., yielding `NaN`s somewhere. This is especially complex +for large models with many terms and implicit time-stepping with all the bells +and whistles that the CliMA ecosystem offers. + +ClimaCore has a module, [`ClimaCore.DebugOnly`](@ref), which contains tools for +debugging simulations for these complicated situations. + +Because so much data (for example, the solution state, and many cached fields) +is typically contained in ClimaCore data structures, we offer a hook to inspect +this data after any operation that ClimaCore performs. + +## Example + +```@example +import ClimaCore +using ClimaCore: DataLayouts +ClimaCore.DebugOnly.call_post_op_callback() = true +function ClimaCore.DebugOnly.post_op_callback(result, args...; kwargs...) + if any(isnan, parent(data)) + println("NaNs found!") + end +end + +FT = Float64; +data = DataLayouts.VIJFH{FT}(Array{FT}, zeros; Nv=5, Nij=2, Nh=2) +@. data = NaN +``` + +Note that, due to dispatch, `post_op_callback` will likely need a very general +method signature, and using `post_op_callback +(result::DataLayouts.VIJFH, args...; kwargs...)` above fails (on the CPU), +because `post_op_callback` ends up getting called multiple times with different +datalayouts. + +!!! warn + + While this debugging tool may be helpful, it's not bullet proof. NaNs can + infiltrate user data any time internals are used. For example `parent + (data) .= NaN` will not be caught by ClimaCore.DebugOnly, and errors can be + observed later than expected. + +!!! note + + This method is called in many places, so this is a performance-critical code + path and expensive operations performed in `post_op_callback` may + significantly slow down your code. diff --git a/ext/ClimaCoreCUDAExt.jl b/ext/ClimaCoreCUDAExt.jl index 8952a6b24a..63bbe39be9 100644 --- a/ext/ClimaCoreCUDAExt.jl +++ b/ext/ClimaCoreCUDAExt.jl @@ -9,6 +9,7 @@ import CUDA using CUDA using CUDA: threadIdx, blockIdx, blockDim import StaticArrays: SVector, SMatrix, SArray +import ClimaCore.DebugOnly: call_post_op_callback, post_op_callback import ClimaCore.DataLayouts: mapreduce_cuda import ClimaCore.DataLayouts: ToCUDA import ClimaCore.DataLayouts: slab, column diff --git a/ext/cuda/data_layouts_copyto.jl b/ext/cuda/data_layouts_copyto.jl index eb07e2bdc9..ecba16b1ad 100644 --- a/ext/cuda/data_layouts_copyto.jl +++ b/ext/cuda/data_layouts_copyto.jl @@ -24,7 +24,7 @@ if VERSION ≥ v"1.11.0-beta" # special-case fixes for https://github.com/JuliaLang/julia/issues/28126 # (including the GPU-variant related issue resolution efforts: # JuliaGPU/GPUArrays.jl#454, JuliaGPU/GPUArrays.jl#464). - function Base.copyto!(dest::AbstractData, bc, ::ToCUDA) + function Base.copyto!(dest::AbstractData, bc, to::ToCUDA) (_, _, Nv, _, Nh) = DataLayouts.universal_size(dest) us = DataLayouts.UniversalSize(dest) if Nv > 0 && Nh > 0 @@ -39,10 +39,11 @@ if VERSION ≥ v"1.11.0-beta" blocks_s = p.blocks, ) end + call_post_op_callback() && post_op_callback(dest, dest, bc, to) return dest end else - function Base.copyto!(dest::AbstractData, bc, ::ToCUDA) + function Base.copyto!(dest::AbstractData, bc, to::ToCUDA) (_, _, Nv, _, Nh) = DataLayouts.universal_size(dest) us = DataLayouts.UniversalSize(dest) if Nv > 0 && Nh > 0 @@ -74,6 +75,7 @@ else ) end end + call_post_op_callback() && post_op_callback(dest, dest, bc, to) return dest end end @@ -85,7 +87,7 @@ end function Base.copyto!( dest::AbstractData, bc::Base.Broadcast.Broadcasted{Style}, - ::ToCUDA, + to::ToCUDA, ) where { Style <: Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}}, @@ -95,13 +97,14 @@ function Base.copyto!( ) @inbounds bc0 = bc[] fill!(dest, bc0) + call_post_op_callback() && post_op_callback(dest, dest, bc, to) end # For field-vector operations function DataLayouts.copyto_per_field!( array::AbstractArray, bc::Union{AbstractArray, Base.Broadcast.Broadcasted}, - ::ToCUDA, + to::ToCUDA, ) bc′ = DataLayouts.to_non_extruded_broadcasted(bc) # All field variables are treated separately, so @@ -119,6 +122,7 @@ function DataLayouts.copyto_per_field!( threads_s = p.threads, blocks_s = p.blocks, ) + call_post_op_callback() && post_op_callback(array, array, bc, to) return array end function copyto_per_field_kernel!(array, bc, N) @@ -133,7 +137,7 @@ end function DataLayouts.copyto_per_field_scalar!( array::AbstractArray, bc::Base.Broadcast.Broadcasted{Style}, - ::ToCUDA, + to::ToCUDA, ) where { Style <: Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}}, @@ -154,12 +158,13 @@ function DataLayouts.copyto_per_field_scalar!( threads_s = p.threads, blocks_s = p.blocks, ) + call_post_op_callback() && post_op_callback(array, array, bc, to) return array end function DataLayouts.copyto_per_field_scalar!( array::AbstractArray, bc::Real, - ::ToCUDA, + to::ToCUDA, ) bc′ = DataLayouts.to_non_extruded_broadcasted(bc) # All field variables are treated separately, so @@ -177,6 +182,7 @@ function DataLayouts.copyto_per_field_scalar!( threads_s = p.threads, blocks_s = p.blocks, ) + call_post_op_callback() && post_op_callback(array, array, bc, to) return array end function copyto_per_field_kernel_0D!(array, bc, N) diff --git a/ext/cuda/data_layouts_fill.jl b/ext/cuda/data_layouts_fill.jl index 4e2764789a..92abf97088 100644 --- a/ext/cuda/data_layouts_fill.jl +++ b/ext/cuda/data_layouts_fill.jl @@ -14,7 +14,7 @@ function knl_fill_linear!(dest, val, us) return nothing end -function Base.fill!(dest::AbstractData, bc, ::ToCUDA) +function Base.fill!(dest::AbstractData, bc, to::ToCUDA) (_, _, Nv, _, Nh) = DataLayouts.universal_size(dest) us = DataLayouts.UniversalSize(dest) args = (dest, bc, us) @@ -41,5 +41,6 @@ function Base.fill!(dest::AbstractData, bc, ::ToCUDA) ) end end + call_post_op_callback() && post_op_callback(dest, dest, bc, to) return dest end diff --git a/ext/cuda/data_layouts_mapreduce.jl b/ext/cuda/data_layouts_mapreduce.jl index 6bae9f0c28..551b8d451d 100644 --- a/ext/cuda/data_layouts_mapreduce.jl +++ b/ext/cuda/data_layouts_mapreduce.jl @@ -15,7 +15,10 @@ function mapreduce_cuda( ) pdata = parent(data) S = eltype(data) - return DataLayouts.DataF{S}(Array(Array(f(pdata))[1, :])) + data_out = DataLayouts.DataF{S}(Array(Array(f(pdata))[1, :])) + call_post_op_callback() && + post_op_callback(data_out, f, op, data; weighted_jacobian, opargs...) + return data_out end function mapreduce_cuda( @@ -101,7 +104,11 @@ function mapreduce_cuda( Val(shmemsize), ) end - return DataLayouts.DataF{S}(Array(Array(reduce_cuda)[1, :])) + data_out = DataLayouts.DataF{S}(Array(Array(reduce_cuda)[1, :])) + + call_post_op_callback() && + post_op_callback(data_out, f, op, data; weighted_jacobian, opargs...) + return data_out end function mapreduce_cuda_kernel!( diff --git a/ext/cuda/fields.jl b/ext/cuda/fields.jl index aa9d7037a3..abfa3bbf04 100644 --- a/ext/cuda/fields.jl +++ b/ext/cuda/fields.jl @@ -13,39 +13,44 @@ end function Base.sum( field::Union{Field, Base.Broadcast.Broadcasted{<:FieldStyle}}, - ::ClimaComms.CUDADevice, + dev::ClimaComms.CUDADevice, ) context = ClimaComms.context(axes(field)) localsum = mapreduce_cuda(identity, +, field, weighting = true) ClimaComms.allreduce!(context, parent(localsum), +) + call_post_op_callback() && post_op_callback(localsum[], field, dev) return localsum[] end -function Base.sum(fn, field::Field, ::ClimaComms.CUDADevice) +function Base.sum(fn, field::Field, dev::ClimaComms.CUDADevice) context = ClimaComms.context(axes(field)) localsum = mapreduce_cuda(fn, +, field, weighting = true) ClimaComms.allreduce!(context, parent(localsum), +) + call_post_op_callback() && post_op_callback(localsum[], fn, field, dev) return localsum[] end -function Base.maximum(fn, field::Field, ::ClimaComms.CUDADevice) +function Base.maximum(fn, field::Field, dev::ClimaComms.CUDADevice) context = ClimaComms.context(axes(field)) localmax = mapreduce_cuda(fn, max, field) ClimaComms.allreduce!(context, parent(localmax), max) + call_post_op_callback() && post_op_callback(localmax[], fn, field, dev) return localmax[] end -function Base.maximum(field::Field, ::ClimaComms.CUDADevice) +function Base.maximum(field::Field, dev::ClimaComms.CUDADevice) context = ClimaComms.context(axes(field)) localmax = mapreduce_cuda(identity, max, field) ClimaComms.allreduce!(context, parent(localmax), max) + call_post_op_callback() && post_op_callback(localmax[], fn, field, dev) return localmax[] end -function Base.minimum(fn, field::Field, ::ClimaComms.CUDADevice) +function Base.minimum(fn, field::Field, dev::ClimaComms.CUDADevice) context = ClimaComms.context(axes(field)) localmin = mapreduce_cuda(fn, min, field) ClimaComms.allreduce!(context, parent(localmin), min) + call_post_op_callback() && post_op_callback(localmin[], fn, field, dev) return localmin[] end @@ -53,6 +58,7 @@ function Base.minimum(field::Field, ::ClimaComms.CUDADevice) context = ClimaComms.context(axes(field)) localmin = mapreduce_cuda(identity, min, field) ClimaComms.allreduce!(context, parent(localmin), min) + call_post_op_callback() && post_op_callback(localmin[], fn, field, dev) return localmin[] end diff --git a/ext/cuda/limiters.jl b/ext/cuda/limiters.jl index a7dd6e393a..33f33e1841 100644 --- a/ext/cuda/limiters.jl +++ b/ext/cuda/limiters.jl @@ -19,7 +19,7 @@ function compute_element_bounds!( limiter::QuasiMonotoneLimiter, ρq, ρ, - ::ClimaComms.CUDADevice, + dev::ClimaComms.CUDADevice, ) ρ_values = Fields.field_values(Operators.strip_space(ρ, axes(ρ))) ρq_values = Fields.field_values(Operators.strip_space(ρq, axes(ρq))) @@ -33,6 +33,8 @@ function compute_element_bounds!( threads_s = nthreads, blocks_s = nblocks, ) + call_post_op_callback() && + post_op_callback(limiter.q_bounds, limiter, ρq, ρ, dev) return nothing end @@ -70,7 +72,7 @@ end function compute_neighbor_bounds_local!( limiter::QuasiMonotoneLimiter, ρ, - ::ClimaComms.CUDADevice, + dev::ClimaComms.CUDADevice, ) topology = Spaces.topology(axes(ρ)) us = DataLayouts.UniversalSize(Fields.field_values(ρ)) @@ -88,6 +90,8 @@ function compute_neighbor_bounds_local!( threads_s = nthreads, blocks_s = nblocks, ) + call_post_op_callback() && + post_op_callback(limiter.q_bounds, limiter, ρ, dev) end function compute_neighbor_bounds_local_kernel!( @@ -123,7 +127,7 @@ function apply_limiter!( ρq::Fields.Field, ρ::Fields.Field, limiter::QuasiMonotoneLimiter, - ::ClimaComms.CUDADevice, + dev::ClimaComms.CUDADevice, ) ρq_data = Fields.field_values(ρq) us = DataLayouts.UniversalSize(ρq_data) @@ -147,6 +151,7 @@ function apply_limiter!( threads_s = nthreads, blocks_s = nblocks, ) + call_post_op_callback() && post_op_callback(ρq, ρq, ρ, limiter, dev) return nothing end diff --git a/ext/cuda/matrix_fields_multiple_field_solve.jl b/ext/cuda/matrix_fields_multiple_field_solve.jl index 3955aabaa7..76c0b6e3eb 100644 --- a/ext/cuda/matrix_fields_multiple_field_solve.jl +++ b/ext/cuda/matrix_fields_multiple_field_solve.jl @@ -11,7 +11,7 @@ import ClimaCore.Utilities.UnrolledFunctions: unrolled_map is_CuArray_type(::Type{T}) where {T <: CUDA.CuArray} = true NVTX.@annotate function multiple_field_solve!( - ::ClimaComms.CUDADevice, + dev::ClimaComms.CUDADevice, cache, x, A, @@ -48,6 +48,7 @@ NVTX.@annotate function multiple_field_solve!( blocks_s = p.blocks, always_inline = true, ) + call_post_op_callback() && post_op_callback(x, dev, cache, x, A, b, x1) end Base.@propagate_inbounds column_A(A::UniformScaling, i, j, h) = A diff --git a/ext/cuda/matrix_fields_single_field_solve.jl b/ext/cuda/matrix_fields_single_field_solve.jl index b486ef9041..40cbee3ebe 100644 --- a/ext/cuda/matrix_fields_single_field_solve.jl +++ b/ext/cuda/matrix_fields_single_field_solve.jl @@ -27,6 +27,7 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b) threads_s = p.threads, blocks_s = p.blocks, ) + call_post_op_callback() && post_op_callback(x, device, cache, x, A, b) end function single_field_solve_kernel!(device, cache, x, A, b, us) diff --git a/ext/cuda/operators_finite_difference.jl b/ext/cuda/operators_finite_difference.jl index c93b4e0797..870de35083 100644 --- a/ext/cuda/operators_finite_difference.jl +++ b/ext/cuda/operators_finite_difference.jl @@ -34,6 +34,7 @@ function Base.copyto!( threads_s = p.threads, blocks_s = p.blocks, ) + call_post_op_callback() && post_op_callback(out, out, bc) return out end import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh diff --git a/ext/cuda/operators_integral.jl b/ext/cuda/operators_integral.jl index 7b344b511e..db1bccd727 100644 --- a/ext/cuda/operators_integral.jl +++ b/ext/cuda/operators_integral.jl @@ -10,7 +10,7 @@ import ClimaComms using CUDA: @cuda function column_reduce_device!( - ::ClimaComms.CUDADevice, + dev::ClimaComms.CUDADevice, f::F, transform::T, output, @@ -40,6 +40,11 @@ function column_reduce_device!( threads_s = p.threads, blocks_s = p.blocks, ) + call_post_op_callback() && post_op_callback( + output, + (dev, f, transform, output, input, init, space), + (;), + ) end function column_accumulate_device!( diff --git a/ext/cuda/operators_spectral_element.jl b/ext/cuda/operators_spectral_element.jl index 658b3440e6..d22ce173a7 100644 --- a/ext/cuda/operators_spectral_element.jl +++ b/ext/cuda/operators_spectral_element.jl @@ -49,6 +49,7 @@ function Base.copyto!( threads_s = p.threads, blocks_s = p.blocks, ) + call_post_op_callback() && post_op_callback(out, out, sbc) return out end diff --git a/ext/cuda/operators_thomas_algorithm.jl b/ext/cuda/operators_thomas_algorithm.jl index d25e26fd3d..236f72b5b1 100644 --- a/ext/cuda/operators_thomas_algorithm.jl +++ b/ext/cuda/operators_thomas_algorithm.jl @@ -4,7 +4,7 @@ import ClimaCore.Operators: column_thomas_solve!, thomas_algorithm_kernel!, thomas_algorithm! import CUDA using CUDA: @cuda -function column_thomas_solve!(::ClimaComms.CUDADevice, A, b) +function column_thomas_solve!(dev::ClimaComms.CUDADevice, A, b) us = UniversalSize(Fields.field_values(A)) args = (A, b, us) Ni, Nj, _, _, Nh = size(Fields.field_values(A)) @@ -18,6 +18,7 @@ function column_thomas_solve!(::ClimaComms.CUDADevice, A, b) threads_s = p.threads, blocks_s = p.blocks, ) + call_post_op_callback() && post_op_callback(b, dev, A, b) end function thomas_algorithm_kernel!( diff --git a/ext/cuda/remapping_distributed.jl b/ext/cuda/remapping_distributed.jl index eb88896fa2..fde7b4a13c 100644 --- a/ext/cuda/remapping_distributed.jl +++ b/ext/cuda/remapping_distributed.jl @@ -13,7 +13,7 @@ function _set_interpolated_values_device!( interpolation_matrix, vert_interpolation_weights::AbstractArray, vert_bounding_indices::AbstractArray, - ::ClimaComms.CUDADevice, + dev::ClimaComms.CUDADevice, ) # FIXME: Avoid allocation of tuple field_values = tuple(map(f -> Fields.field_values(f), fields)...) @@ -39,6 +39,17 @@ function _set_interpolated_values_device!( threads_s = (nthreads), blocks_s = (nblocks), ) + call_post_op_callback() && post_op_callback( + out, + out, + fields, + scratch_field_values, + local_horiz_indices, + interpolation_matrix, + vert_interpolation_weights, + vert_bounding_indices, + dev, + ) end # GPU, 3D case diff --git a/ext/cuda/remapping_interpolate_array.jl b/ext/cuda/remapping_interpolate_array.jl index d96862c679..622efc70fb 100644 --- a/ext/cuda/remapping_interpolate_array.jl +++ b/ext/cuda/remapping_interpolate_array.jl @@ -26,6 +26,14 @@ function interpolate_slab!( threads_s = (nthreads), blocks_s = (nblocks), ) + call_post_op_callback() && post_op_callback( + output_array, + output_array, + field, + slab_indices, + weights, + device, + ) output_array .= Array(output_cuarray) end diff --git a/ext/cuda/topologies_dss.jl b/ext/cuda/topologies_dss.jl index f70cd48956..5662e68d5e 100644 --- a/ext/cuda/topologies_dss.jl +++ b/ext/cuda/topologies_dss.jl @@ -17,7 +17,7 @@ _configure_threadblock(nitems) = _configure_threadblock(_max_threads_cuda(), nitems) function Topologies.dss_load_perimeter_data!( - ::ClimaComms.CUDADevice, + dev::ClimaComms.CUDADevice, dss_buffer::Topologies.DSSBuffer, data::DSSDataTypes, perimeter::Topologies.Perimeter2D, @@ -33,6 +33,8 @@ function Topologies.dss_load_perimeter_data!( threads_s = p.threads, blocks_s = p.blocks, ) + call_post_op_callback() && + post_op_callback(perimeter_data, dev, dss_buffer, data, perimeter) return nothing end @@ -57,7 +59,7 @@ function dss_load_perimeter_data_kernel!( end function Topologies.dss_unload_perimeter_data!( - ::ClimaComms.CUDADevice, + dev::ClimaComms.CUDADevice, data::DSSDataTypes, dss_buffer::Topologies.DSSBuffer, perimeter, @@ -73,6 +75,8 @@ function Topologies.dss_unload_perimeter_data!( threads_s = p.threads, blocks_s = p.blocks, ) + call_post_op_callback() && + post_op_callback(data, dev, data, dss_buffer, perimeter) return nothing end @@ -97,7 +101,7 @@ function dss_unload_perimeter_data_kernel!( end function Topologies.dss_local!( - ::ClimaComms.CUDADevice, + dev::ClimaComms.CUDADevice, perimeter_data::DSSPerimeterTypes, perimeter::Topologies.Perimeter2D, topology::Topologies.Topology2D, @@ -123,6 +127,13 @@ function Topologies.dss_local!( threads_s = p.threads, blocks_s = p.blocks, ) + call_post_op_callback() && post_op_callback( + perimeter_data, + dev, + perimeter_data, + perimeter, + topology, + ) end return nothing end @@ -213,6 +224,16 @@ function Topologies.dss_transform!( threads_s = p.threads, blocks_s = p.blocks, ) + call_post_op_callback() && post_op_callback( + perimeter_data, + device, + perimeter_data, + data, + perimeter, + local_geometry, + weight, + localelems, + ) end return nothing end @@ -276,6 +297,15 @@ function Topologies.dss_untransform!( threads_s = p.threads, blocks_s = p.blocks, ) + call_post_op_callback() && post_op_callback( + data, + device, + perimeter_data, + data, + local_geometry, + perimeter, + localelems, + ) end return nothing end @@ -309,7 +339,7 @@ end # TODO: Function stubs, code to be implemented, needed only for distributed GPU runs function Topologies.dss_local_ghost!( - ::ClimaComms.CUDADevice, + dev::ClimaComms.CUDADevice, perimeter_data::DSSPerimeterTypes, perimeter::Topologies.Perimeter2D, topology::Topologies.AbstractTopology, @@ -333,6 +363,13 @@ function Topologies.dss_local_ghost!( threads_s = p.threads, blocks_s = p.blocks, ) + call_post_op_callback() && post_op_callback( + perimeter_data, + dev, + perimeter_data, + perimeter, + topology, + ) end return nothing end @@ -374,7 +411,7 @@ function dss_local_ghost_kernel!( end function Topologies.fill_send_buffer!( - ::ClimaComms.CUDADevice, + dev::ClimaComms.CUDADevice, dss_buffer::Topologies.DSSBuffer; synchronize = true, ) @@ -396,6 +433,8 @@ function Topologies.fill_send_buffer!( if synchronize CUDA.synchronize(; blocking = true) # CUDA MPI uses a separate stream. This will synchronize across streams end + call_post_op_callback() && + post_op_callback(send_data, dev, dss_buffer; synchronize) end return nothing end @@ -422,7 +461,7 @@ function fill_send_buffer_kernel!( end function Topologies.load_from_recv_buffer!( - ::ClimaComms.CUDADevice, + dev::ClimaComms.CUDADevice, dss_buffer::Topologies.DSSBuffer, ) (; perimeter_data, recv_buf_idx, recv_data) = dss_buffer @@ -440,6 +479,8 @@ function Topologies.load_from_recv_buffer!( threads_s = p.threads, blocks_s = p.blocks, ) + call_post_op_callback() && + post_op_callback(perimeter_data, dev, dss_buffer) end return nothing end @@ -474,7 +515,7 @@ end function Topologies.dss_ghost!( - ::ClimaComms.CUDADevice, + dev::ClimaComms.CUDADevice, perimeter_data::DSSPerimeterTypes, perimeter::Topologies.Perimeter2D, topology::Topologies.Topology2D, @@ -499,6 +540,13 @@ function Topologies.dss_ghost!( threads_s = p.threads, blocks_s = p.blocks, ) + call_post_op_callback() && post_op_callback( + perimeter_data, + dev, + perimeter_data, + perimeter, + topology, + ) end return nothing end diff --git a/src/ClimaCore.jl b/src/ClimaCore.jl index 2ee050abb2..42e189bd77 100644 --- a/src/ClimaCore.jl +++ b/src/ClimaCore.jl @@ -4,6 +4,7 @@ using PkgVersion const VERSION = PkgVersion.@Version import ClimaComms +include("DebugOnly/DebugOnly.jl") include("interface.jl") include("devices.jl") include("Utilities/Utilities.jl") diff --git a/src/DataLayouts/DataLayouts.jl b/src/DataLayouts/DataLayouts.jl index d5314d1da1..4e74d44a08 100644 --- a/src/DataLayouts/DataLayouts.jl +++ b/src/DataLayouts/DataLayouts.jl @@ -69,6 +69,7 @@ import ClimaComms import MultiBroadcastFusion as MBF import Adapt +import ..DebugOnly: call_post_op_callback, post_op_callback import ..slab, ..slab_args, ..column, ..column_args, ..level export slab, column, diff --git a/src/DataLayouts/copyto.jl b/src/DataLayouts/copyto.jl index 527083f1a8..3a2b9242ae 100644 --- a/src/DataLayouts/copyto.jl +++ b/src/DataLayouts/copyto.jl @@ -13,7 +13,9 @@ if VERSION ≥ v"1.11.0-beta" dest::AbstractData{S}, bc::Union{AbstractData, Base.Broadcast.Broadcasted}, ) where {S} - return Base.copyto!(dest, bc, device_dispatch(parent(dest))) + Base.copyto!(dest, bc, device_dispatch(parent(dest))) + call_post_op_callback() && post_op_callback(dest, dest, bc) + dest end else function Base.copyto!( @@ -33,6 +35,7 @@ else else Base.copyto!(dest, bc, device_dispatch(parent(dest))) end + call_post_op_callback() && post_op_callback(dest, dest, bc) return dest end end @@ -40,6 +43,7 @@ end # Specialize on non-Broadcasted objects function Base.copyto!(dest::D, src::D) where {D <: AbstractData} copyto!(parent(dest), parent(src)) + call_post_op_callback() && post_op_callback(dest, dest, src) return dest end @@ -48,7 +52,7 @@ end function Base.copyto!( dest::AbstractData, bc::Base.Broadcast.Broadcasted{Style}, - ::AbstractDispatchToDevice, + to::AbstractDispatchToDevice, ) where { Style <: Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}}, @@ -58,6 +62,7 @@ function Base.copyto!( ) @inbounds bc0 = bc[] fill!(dest, bc0) + call_post_op_callback() && post_op_callback(dest, dest, bc, to) end ##### diff --git a/src/DataLayouts/fill.jl b/src/DataLayouts/fill.jl index fff2606cd1..5c729fe577 100644 --- a/src/DataLayouts/fill.jl +++ b/src/DataLayouts/fill.jl @@ -9,6 +9,8 @@ function Base.fill!(dest::AbstractData, val) else Base.fill!(dest, val, dev) end + call_post_op_callback() && post_op_callback(dest, dest, val) + dest end function Base.fill!(data::Union{IJFH, IJHF}, val, ::ToCPU) diff --git a/src/DebugOnly/DebugOnly.jl b/src/DebugOnly/DebugOnly.jl new file mode 100644 index 0000000000..5c2349b3d6 --- /dev/null +++ b/src/DebugOnly/DebugOnly.jl @@ -0,0 +1,49 @@ +""" + DebugOnly + +A module for debugging tools. Note that any tools in here are subject to sudden +changes without warning. So, please, do _not_ use any of these tools in +production as support for them are not guaranteed. +""" +module DebugOnly + +""" + post_op_callback(result, args...; kwargs...) + +A callback that is called, if `ClimaCore.DataLayouts.call_post_op_callback() = +true`, on the result of every data operation. + +There is purposely no implementation-- this is a debugging tool, and users may +want to check different things. + +Note that, since this method is called in so many places, this is a +performance-critical code path and expensive operations performed in +`post_op_callback` may significantly slow down your code. +""" +function post_op_callback end + +""" + call_post_op_callback() + +Returns a Bool. Meant to be overloaded so that +`ClimaCore.DataLayouts.post_op_callback(result, args...; kwargs...)` is called +after every data operation. +""" +call_post_op_callback() = false + +# TODO: define a convenience macro to inject `post_op_hook` + +""" + example_debug_post_op_callback(result, args...; kwargs...) + +An example `post_op_callback` method, that checks for `NaN`s and `Inf`s. +""" +function example_debug_post_op_callback(result, args...; kwargs...) + if any(isnan, parent(result)) + error("NaNs found!") + elseif any(isinf, parent(result)) + error("Inf found!") + end +end + +end diff --git a/src/Fields/Fields.jl b/src/Fields/Fields.jl index be35a74835..9e687f1bd8 100644 --- a/src/Fields/Fields.jl +++ b/src/Fields/Fields.jl @@ -3,6 +3,7 @@ module Fields import ClimaComms import MultiBroadcastFusion as MBF import ..slab, ..slab_args, ..column, ..column_args, ..level +import ..DebugOnly: call_post_op_callback, post_op_callback import ..DataLayouts: DataLayouts, AbstractData, diff --git a/src/Fields/mapreduce.jl b/src/Fields/mapreduce.jl index 4571cde60c..caba82935f 100644 --- a/src/Fields/mapreduce.jl +++ b/src/Fields/mapreduce.jl @@ -8,17 +8,21 @@ context. See [`sum`](@ref) for the integral over the full domain. """ -local_sum( +function local_sum( field::Union{Field, Base.Broadcast.Broadcasted{<:FieldStyle}}, - ::ClimaComms.AbstractCPUDevice, -) = Base.reduce( - RecursiveApply.radd, - Base.Broadcast.broadcasted( - RecursiveApply.rmul, - Spaces.weighted_jacobian(axes(field)), - todata(field), - ), + dev::ClimaComms.AbstractCPUDevice, ) + result = Base.reduce( + RecursiveApply.radd, + Base.Broadcast.broadcasted( + RecursiveApply.rmul, + Spaces.weighted_jacobian(axes(field)), + todata(field), + ), + ) + call_post_op_callback() && post_op_callback(result, field, dev) + result +end local_sum(field::Union{Field, Base.Broadcast.Broadcasted{<:FieldStyle}}) = local_sum(field, ClimaComms.device(axes(field))) """ diff --git a/src/Limiters/Limiters.jl b/src/Limiters/Limiters.jl index 55b8deb36c..bd9116cf0d 100644 --- a/src/Limiters/Limiters.jl +++ b/src/Limiters/Limiters.jl @@ -2,6 +2,7 @@ module Limiters import ..DataLayouts, ..Topologies, ..Spaces, ..Fields import ..RecursiveApply: rdiv, rmin, rmax +import ..DebugOnly: call_post_op_callback, post_op_callback import ClimaCore: slab export AbstractLimiter diff --git a/src/Limiters/quasimonotone.jl b/src/Limiters/quasimonotone.jl index b038cd91de..ad19350937 100644 --- a/src/Limiters/quasimonotone.jl +++ b/src/Limiters/quasimonotone.jl @@ -113,7 +113,7 @@ function compute_element_bounds!( limiter::QuasiMonotoneLimiter, ρq, ρ, - ::ClimaComms.AbstractCPUDevice, + dev::ClimaComms.AbstractCPUDevice, ) ρ_data = Fields.field_values(ρ) ρq_data = Fields.field_values(ρq) @@ -144,6 +144,8 @@ function compute_element_bounds!( slab_q_bounds[slab_index(2)] = q_max end end + call_post_op_callback() && + post_op_callback(limiter.q_bounds, limiter, ρq, ρ, dev) return nothing end @@ -161,7 +163,7 @@ compute_neighbor_bounds_local!(limiter::QuasiMonotoneLimiter, ρ) = function compute_neighbor_bounds_local!( limiter::QuasiMonotoneLimiter, ρ, - ::ClimaComms.AbstractCPUDevice, + dev::ClimaComms.AbstractCPUDevice, ) topology = Spaces.topology(axes(ρ)) q_bounds = limiter.q_bounds @@ -182,6 +184,8 @@ function compute_neighbor_bounds_local!( slab_q_bounds_nbr[slab_index(2)] = q_max end end + call_post_op_callback() && + post_op_callback(limiter.q_bounds_nbr, limiter, ρ, dev) return nothing end @@ -218,6 +222,8 @@ function compute_neighbor_bounds_ghost!( end end end + call_post_op_callback() && + post_op_callback(limiter.q_bounds_nbr, limiter, topology) return nothing end @@ -253,6 +259,8 @@ function compute_bounds!( ClimaComms.finish(limiter.ghost_buffer.graph_context) compute_neighbor_bounds_ghost!(limiter, Spaces.topology(axes(ρq))) end + call_post_op_callback() && + post_op_callback(limiter.q_bounds, limiter, ρq, ρ) end @@ -276,7 +284,7 @@ function apply_limiter!( ρq::Fields.Field, ρ::Fields.Field, limiter::QuasiMonotoneLimiter, - ::ClimaComms.AbstractCPUDevice, + dev::ClimaComms.AbstractCPUDevice, ) (; q_bounds_nbr, rtol) = limiter @@ -302,6 +310,7 @@ function apply_limiter!( converged || @warn "Limiter failed to converge with rtol = $rtol, `max_rel_err`=$max_rel_err" + call_post_op_callback() && post_op_callback(ρq, ρq, ρ, limiter, dev) return ρq end diff --git a/src/Operators/Operators.jl b/src/Operators/Operators.jl index 3f79c8fb49..aa0621936d 100644 --- a/src/Operators/Operators.jl +++ b/src/Operators/Operators.jl @@ -8,6 +8,7 @@ import Base.Broadcast: Broadcasted import ..slab, ..slab_args, ..column, ..column_args import ClimaComms +import ..DebugOnly: call_post_op_callback, post_op_callback import ..DataLayouts: DataLayouts, Data2D, DataSlab2D import ..DataLayouts: vindex import ..Geometry: Geometry, Covariant12Vector, Contravariant12Vector, ⊗ diff --git a/src/Operators/finitedifference.jl b/src/Operators/finitedifference.jl index 1ed79f430a..1b909e1389 100644 --- a/src/Operators/finitedifference.jl +++ b/src/Operators/finitedifference.jl @@ -3779,6 +3779,8 @@ function _serial_copyto!(field_out::Field, bc, Ni::Int, Nj::Int, Nh::Int) @inbounds for h in 1:Nh, j in 1:Nj, i in 1:Ni apply_stencil!(space, field_out, bcs, (i, j, h), bounds) end + call_post_op_callback() && + post_op_callback(field_out, field_out, bc, Ni, Nj, Nh) return field_out end @@ -3793,6 +3795,8 @@ function _threaded_copyto!(field_out::Field, bc, Ni::Int, Nj::Int, Nh::Int) end end end + call_post_op_callback() && + post_op_callback(field_out, field_out, bc, Ni, Nj, Nh) return field_out end diff --git a/src/Operators/spectralelement.jl b/src/Operators/spectralelement.jl index 39b3eef714..ceefa850a6 100644 --- a/src/Operators/spectralelement.jl +++ b/src/Operators/spectralelement.jl @@ -168,6 +168,7 @@ function Base.copyto!( Base.@_inline_meta @inbounds copyto_slab!(out, sbc, slabidx) end + call_post_op_callback() && post_op_callback(out, out, sbc) return out end diff --git a/src/Operators/thomas_algorithm.jl b/src/Operators/thomas_algorithm.jl index 3ea18cb1ff..79414dcd64 100644 --- a/src/Operators/thomas_algorithm.jl +++ b/src/Operators/thomas_algorithm.jl @@ -27,6 +27,7 @@ function thomas_algorithm!( Fields.bycolumn(axes(A)) do colidx thomas_algorithm!(A[colidx], b[colidx]) end + call_post_op_callback() && post_op_callback(b, A, b) end function thomas_algorithm!( diff --git a/src/Spaces/Spaces.jl b/src/Spaces/Spaces.jl index 730df96f71..9c955b5266 100644 --- a/src/Spaces/Spaces.jl +++ b/src/Spaces/Spaces.jl @@ -19,6 +19,7 @@ using Adapt import ..slab, ..column, ..level import ..Utilities: PlusHalf, half +import ..DebugOnly: call_post_op_callback, post_op_callback import ..DataLayouts, ..Geometry, ..Domains, ..Meshes, ..Topologies, ..Grids, ..Quadratures diff --git a/src/Spaces/dss.jl b/src/Spaces/dss.jl index 21420f8c1d..53b99cf9dd 100644 --- a/src/Spaces/dss.jl +++ b/src/Spaces/dss.jl @@ -97,6 +97,7 @@ function weighted_dss!( weighted_dss_start!(data, space, dss_buffer) weighted_dss_internal!(data, space, dss_buffer) weighted_dss_ghost!(data, space, dss_buffer) + call_post_op_callback() && post_op_callback(data, data, space, dss_buffer) end