diff --git a/Project.toml b/Project.toml index 30a3142..0591422 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,7 @@ authors = [ "Jake Bolewski ", "Gabriele Bozzola ", ] -version = "0.5.7" +version = "0.5.8" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/docs/src/index.md b/docs/src/index.md index e2d0f37..91f90f7 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -22,6 +22,7 @@ ClimaComms.@threaded ClimaComms.@time ClimaComms.@elapsed ClimaComms.@sync +ClimaComms.@cuda_sync ``` ## Contexts diff --git a/src/devices.jl b/src/devices.jl index a994bb0..d05efbf 100644 --- a/src/devices.jl +++ b/src/devices.jl @@ -178,6 +178,25 @@ for CPU devices and CUDA.@sync expr ``` for CUDA devices. + +An example use-case of this might be: +```julia +BenchmarkTools.@benchmark begin + if ClimaComms.device() isa ClimaComms.CUDADevice + CUDA.@sync begin + launch_cuda_kernels_or_spawn_tasks!(...) + end + elseif ClimaComms.device() isa ClimaComms.CPUMultiThreading + Base.@sync begin + launch_cuda_kernels_or_spawn_tasks!(...) + end + end +end +``` + +If the CPU version of the above example does not leverage +spawned tasks (which require using `Base.@sync` or `Threads.wait` +to synchronize), then you may want to simply use [`@cuda_sync`](@ref). """ macro sync(device, expr) # https://github.com/JuliaLang/julia/issues/28979#issuecomment-1756145207 @@ -194,3 +213,32 @@ macro sync(device, expr) end end) end + +""" + @cuda_sync device expr + +Device-flexible `CUDA.@sync`. + +Lowers to +```julia +expr +``` +for CPU devices and +```julia +CUDA.@sync expr +``` +for CUDA devices. 
+""" +macro cuda_sync(device, expr) + # The whole quote is escaped, so `expr` is evaluated in the caller's scope; `$CUDADevice`, `$CUDA` and `$AbstractDevice` are interpolated at expansion time instead of being looked up by name at the call site. See https://github.com/JuliaLang/julia/issues/28979#issuecomment-1756145207 + return esc(quote + if $(device) isa $CUDADevice + $CUDA.@sync begin + $(expr) + end + else + @assert $(device) isa $AbstractDevice + $(expr) + end + end) +end diff --git a/test/hygiene.jl b/test/hygiene.jl index b1a794e..645ed59 100644 --- a/test/hygiene.jl +++ b/test/hygiene.jl @@ -18,6 +18,10 @@ function test_macro_hyhiene(dev) CC.@sync dev for i in 1:n sin.(rand(10)) end + + CC.@cuda_sync dev for i in 1:n + sin.(rand(10)) + end end dev = CC.device()