From b7a4f332687a3a2182c8c47cc1cc6a272baf79e6 Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Wed, 15 Nov 2023 13:58:42 -0800 Subject: [PATCH] Add support for copyto! and fill! for DataF (Used for Point spaces) --- src/DataLayouts/cuda.jl | 12 ++++++++++++ test/DataLayouts/cuda.jl | 4 ++++ 2 files changed, 16 insertions(+) diff --git a/src/DataLayouts/cuda.jl b/src/DataLayouts/cuda.jl index a24fdccf9f..a3a4be35d1 100644 --- a/src/DataLayouts/cuda.jl +++ b/src/DataLayouts/cuda.jl @@ -144,3 +144,15 @@ function Base.fill!(dest::VF{S, A}, val) where {S, A <: CUDA.CuArray} end return dest end + +function Base.copyto!( + dest::DataF{S}, + bc::Union{DataF{S, A}, Base.Broadcast.Broadcasted{DataFStyle{A}}}, +) where {S, A <: CUDA.CuArray} + CUDA.@cuda threads = (1, 1) blocks = (1, 1) knl_copyto!(dest, bc) + return dest +end +function Base.fill!(dest::DataF{S, A}, val) where {S, A <: CUDA.CuArray} + CUDA.@cuda threads = (1, 1) blocks = (1, 1) knl_fill!(dest, val) + return dest +end diff --git a/test/DataLayouts/cuda.jl b/test/DataLayouts/cuda.jl index 2d18577465..f5e1d1231b 100644 --- a/test/DataLayouts/cuda.jl +++ b/test/DataLayouts/cuda.jl @@ -72,4 +72,8 @@ end @test Array(parent(data)) == FT[ f == 1 ? 1 : 2 for v in 1:Nv, i in 1:4, j in 1:4, f in 1:2, h in 1:3 ] + + data = DataF{S}(CuArray{FT}) + data .= Complex(1.0, 2.0) + @test Array(parent(data)) == FT[f == 1 ? 1 : 2 for f in 1:2] end