
Added simple CuArray patch #44

Closed


jacob-m-wilson-42

This is my first-ever pull request, so please feel free to point out any mistakes.

I wanted to get something up quickly. There may be a better solution, but I added another method to the derivatives source file that accepts any CuArray type. I also added a simple test (which throws scalar-indexing warnings) and added CUDA as a dependency.

Please triple-check everything before merging; this is my first commit ever!

@jacob-m-wilson-42
Author

jacob-m-wilson-42 commented Jul 29, 2023

Hmm, I didn't change much in the source, so I'm not sure why it broke so many things...

It looks like some packages might be incompatible, but I'm not experienced enough to say for certain. If nothing else, maybe the updated derivative() method will save someone a little time.

@tansongchen
Member

Hi Jacob, sorry for the late reply, and thanks for your contribution. We should definitely support CUDA data types, but since this is a generic AD library, it would be better not to hard-code very specific types like CuArray. Maybe more abstract types would work as well?

Also, someone just relaxed the input type to AbstractArray{T, 1} in PR #45. Could you try the most recent version on the main branch?

@tansongchen
Member

Hi Jacob, I just updated the main branch. The signature now looks like this:

```julia
@inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S},
    order::Int64) where {T <: Number, S <: Number}
    derivative(f, x, l, Val{order + 1}())
end
```

Here x and l may be quite different types, as long as both are AbstractVectors and S can be converted to T. I am not familiar with CuArrays, but I believe this should solve your problem. Could you try the latest branch?
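To illustrate why two independent element-type parameters help, here is a minimal sketch of a hypothetical helper, `directional_fd`, that uses the same `AbstractVector{T}` / `AbstractVector{S}` pattern. It is not the package's implementation: it swaps Taylor-mode AD for a plain central finite difference and restricts `T` to `AbstractFloat` so that `eps(T)` is defined. The point is only that one method, which never names a concrete array type, can accept a `Vector`, a view, or a range (and presumably a `CuArray` as well, though that is untested here).

```julia
# Hypothetical helper, NOT the package's `derivative`: a first-order directional
# derivative via a central finite difference. It exists only to show how a method
# typed on two independent AbstractVector element types accepts many array types.
function directional_fd(f, x::AbstractVector{T}, l::AbstractVector{S}) where {T <: AbstractFloat, S <: Number}
    d = convert.(T, l)   # direction converted to x's element type (S must convert to T)
    h = cbrt(eps(T))     # a common step size for central differences in precision T
    return (f(x .+ h .* d) - f(x .- h .* d)) / (2h)
end

f(v) = sum(exp.(v))

directional_fd(f, [0.0, 0.0], [1, 0])                      # Vector{Float64} with a Vector{Int} direction, ≈ 1.0
directional_fd(f, view([0.0, 0.0, 5.0], 1:2), [1.0, 0.0])  # SubArray input, ≈ 1.0
directional_fd(f, Float32[0, 0], 1:2)                      # Float32 input with a UnitRange direction, ≈ 3.0f0
```

Because dispatch only asks for `AbstractVector`, the direction can even have a different element type than `x` (here `Int`); it is converted to `T` before use, which mirrors the "S can be converted to T" requirement.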

@jacob-m-wilson-42
Author

Hello Songchen, sorry for the late response as well. I'll give it a shot when I get a chance and let you know how it goes! I agree that relaxing the type, as you suggested, is probably the better approach.

@jacob-m-wilson-42
Author

Songchen,

I tested the main branch with the code below:

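```julia
using CUDA, TaylorDiff  # assumes TaylorDiff.jl is the package providing `derivative`

v, direction = CuArray([0f0, 0f0]), CuArray([1.0f0, 0.0f0])
derivative(x -> sum(exp.(x)), v, direction, 2)  # second-order directional derivative along `direction`
```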

and it worked! Excellent work, and thanks so much for the addition! Sorry it took me so long to get to this.
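As a quick sanity check on that test, assuming `order = 2` requests the second-order directional derivative along `direction`, the expected value can be worked out by hand for $f(x) = \sum_i e^{x_i}$ at $x = (0, 0)$ with $l = (1, 0)$:

```math
\frac{d^2}{dt^2} f(x + t\,l)\Big|_{t=0} = \sum_i l_i^2 \, e^{x_i} = 1^2 \cdot e^{0} + 0^2 \cdot e^{0} = 1
```

So the call should return `1.0f0`, up to Float32 rounding.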
