Skip to content

Commit

Permalink
Added simple CuArray patch
Browse files Browse the repository at this point in the history
  • Loading branch information
jwilson committed Jul 29, 2023
1 parent baf60d3 commit 0c20f1f
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 0 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Songchen Tan <[email protected]>"]
version = "0.2.1"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
Expand Down
7 changes: 7 additions & 0 deletions src/derivative.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

export derivative
using CUDA

"""
derivative(f, x::T, order::Int64)
Expand All @@ -23,6 +24,12 @@ end
derivative(f, x, l, Val{order + 1}())
end

# add CUDA support
@inline function derivative(f, x::CuArray, l::CuArray,
order::Int64)
derivative(f, x, l, Val{order + 1}())
end

@inline function derivative(f, x::T, ::Val{N}) where {T <: Number, N}
t = TaylorScalar{T, N}(x, one(x))
return extract_derivative(f(t), N)
Expand Down
19 changes: 19 additions & 0 deletions test/cuda.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using CUDA

# cpu
f(x,y) = x^2 * y^2
input_cpu = [1e0,1e0]
derivative(temp -> f(temp[1],temp[2]),input_cpu,[1e0,0e0],2)

# gpu
f(x,y) = x^2 * y^2
input_gpu = CuArray([1e0,1e0])
direction_gpu = CuArray([1e0,0e0])
derivative(temp -> f(temp[1],temp[2]),input_gpu,direction_gpu,2)







0 comments on commit 0c20f1f

Please sign in to comment.