Hello everyone! I hope this isn't a simple mistake on my part, but it appears that TaylorDiff doesn't accept GPU arrays. I haven't worked with CuArrays for a year or two, but CuArray types now seem to carry additional type parameters (such as the buffer type) that keep them from looking like a vanilla array. Please see the MWE below.
```julia
using Flux
using TaylorDiff

# works
testinput = rand(3)
derivative_direction = [1e0, 0e0, 0e0]
model = Dense(3 => 1, sin)
model(testinput) # works
derivative(modelinput -> model(modelinput)[1], testinput, derivative_direction, 2) # works

# doesn't work
testinput_gpu = gpu(rand(3))
derivative_direction_gpu = gpu([1e0, 0e0, 0e0])
model_gpu = gpu(model)
model_gpu(testinput_gpu) # works, output is still on the GPU
derivative(modelinput -> model_gpu(modelinput)[1], testinput_gpu, derivative_direction_gpu, 2) # breaks here
```
And the resulting error is:
```
ERROR: MethodError: no method matching derivative(::var"#27#28", ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Int64)
Closest candidates are:
  derivative(::Any, ::Vector{T}, ::Vector{T}, ::Int64) where T<:Number at C:\Users\bachs\.julia\packages\TaylorDiff\zNnz2\src\derivative.jl:21
  derivative(::Any, ::T, ::Int64) where T<:Number at C:\Users\bachs\.julia\packages\TaylorDiff\zNnz2\src\derivative.jl:17
  derivative(::Any, ::T, ::Val{N}) where {T<:Number, N} at C:\Users\bachs\.julia\packages\TaylorDiff\zNnz2\src\derivative.jl:26
  ...
```
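I think the CuArray type is causing the issue: a `CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}` is an `AbstractVector`, but it is not a `Vector` (i.e. `Array{T,1}`), so it doesn't match the `derivative(::Any, ::Vector{T}, ::Vector{T}, ::Int64)` method in the candidate list.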
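To double-check this dispatch explanation without a GPU, here is a minimal plain-Julia sketch. `strict_method` and `generic_method` are hypothetical stand-ins, not TaylorDiff's actual definitions, and an array `view` plays the role of the CuArray, since both are `AbstractVector`s that are not `Vector`s:

```julia
# Hypothetical stand-ins (NOT TaylorDiff's definitions): identical bodies,
# only the argument type constraints differ.
strict_method(x::Vector{T}, v::Vector{T}) where {T<:Number} = sum(x .* v)
generic_method(x::AbstractVector{T}, v::AbstractVector{T}) where {T<:Number} = sum(x .* v)

x = [1.0, 2.0, 3.0, 4.0]
v = [1.0, 0.0, 0.0, 0.0]

# A view is an AbstractVector but not a Vector (Array{T,1}), like a CuArray.
xv = view(x, 1:3)
vv = view(v, 1:3)

println(xv isa Vector)          # false
println(xv isa AbstractVector)  # true

# The narrow Vector{T} signature rejects the view with a MethodError...
try
    strict_method(xv, vv)
catch err
    println(err isa MethodError)  # true
end

# ...while the AbstractVector{T} signature accepts it.
println(generic_method(xv, vv))   # 1.0
```

If this is really the cause, one possible fix would be widening the vector method's signature from `Vector{T}` to `AbstractVector{T}`, although I haven't checked whether the rest of TaylorDiff's implementation then works on GPU arrays.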
If anyone could help, I would greatly appreciate it! I'm very busy, but I'm willing to contribute to a fix for this issue.
Thanks in advance!