diff --git a/Project.toml b/Project.toml index f6ab937a6..3636961ce 100644 --- a/Project.toml +++ b/Project.toml @@ -17,6 +17,12 @@ NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +[weakdeps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + +[extensions] +CudaExt = "CUDA" + [compat] ClimaComms = "0.4, 0.5" Colors = "0.12" diff --git a/ext/CudaExt.jl b/ext/CudaExt.jl new file mode 100644 index 000000000..333c5358f --- /dev/null +++ b/ext/CudaExt.jl @@ -0,0 +1,60 @@ +module CudaExt + +import CUDA +import ClimaComms: SingletonCommsContext, CUDADevice +import ClimaTimeSteppers: compute_T_lim_T_exp! + +@inline function compute_T_lim_T_exp!(T_lim, T_exp, U, p, t, T_lim!, T_exp!, ::SingletonCommsContext{CUDADevice}) + # TODO: we should benchmark these two options to + # see if one is preferrable over the other + if Base.Threads.nthreads() > 1 + compute_T_lim_T_exp_spawn!(T_lim, T_exp, U, p, t, T_lim!, T_exp!) + else + compute_T_lim_T_exp_streams!(T_lim, T_exp, U, p, t, T_lim!, T_exp!) + end +end + +@inline function compute_T_lim_T_exp_streams!(T_lim, T_exp, U, p, t, T_lim!, T_exp!) + event = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING) + CUDA.record(event, CUDA.stream()) # record event on main stream + + stream1 = CUDA.CuStream() # make a stream + local event1 + CUDA.stream!(stream1) do # work to be done by stream1 + CUDA.wait(event, stream1) # make stream1 wait on event (host continues) + T_lim!(T_lim, U, p, t) + event1 = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING) + end + CUDA.record(event1, stream1) # record event1 on stream1 + + stream2 = CUDA.CuStream() # make a stream + local event2 + CUDA.stream!(stream2) do # work to be done by stream2 + CUDA.wait(event, stream2) # make stream2 wait on event (host continues) + T_exp!(T_exp, U, p, t) + event2 = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING) + end + CUDA.record(event2, stream2) # record event2 on stream2 + + CUDA.wait(event1, CUDA.stream()) # make main stream wait on event1 + CUDA.wait(event2, CUDA.stream()) # make main stream wait on event2 +end + +@inline function compute_T_lim_T_exp_spawn!(T_lim, T_exp, U, p, t, T_lim!, T_exp!) + + CUDA.synchronize() + CUDA.@sync begin + Base.Threads.@spawn begin + T_lim!(T_lim, U, p, t) + CUDA.synchronize() + nothing + end + Base.Threads.@spawn begin + T_exp!(T_exp, U, p, t) + CUDA.synchronize() + nothing + end + end +end + +end diff --git a/src/solvers/compute_T_exp_T_lim.jl b/src/solvers/compute_T_exp_T_lim.jl index d2758b147..d4a6b218a 100644 --- a/src/solvers/compute_T_exp_T_lim.jl +++ b/src/solvers/compute_T_exp_T_lim.jl @@ -11,3 +11,25 @@ T_lim!(T_lim, U, p, t) T_exp!(T_exp, U, p, t) end + +@inline function compute_T_lim_T_exp!( + T_lim, + T_exp, + U, + p, + t, + T_lim!, + T_exp!, + ::ClimaComms.SingletonCommsContext{ClimaComms.CPUMultiThreaded}, +) + Base.@sync begin + Base.Threads.@spawn begin + T_lim!(T_lim, U, p, t) + nothing + end + Base.Threads.@spawn begin + T_exp!(T_exp, U, p, t) + nothing + end + end +end