-
Notifications
You must be signed in to change notification settings - Fork 15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support non-standard scalars in test_scalar #61
base: main
Are you sure you want to change the base?
Conversation
Not supported in Julia 1.0
Not supported by Julia v1
will review tomorrow |
# check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im | ||
@testset "$f at $z, with tangent $Δz" for (i, Δz) in enumerate(Δzs) | ||
frule_test(f, (z, Δz); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) | ||
if !isa(Δz, Real) && i == 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this be:
if !isa(Δz, Real) && i == 1 | |
if !isa(Δz, Real) && length(Δzs) == 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case, no. i == 1
when when the given tangent vector is purely real, even if it isn't a Real
. And this test checks that using an actually Real
tangent vector gives the same result.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think i would like to look at this again after the comments are addressed.
I am not sure i properly undestand what is going on for the
if !isa(Δz, Real) && i == 1
branches,
and I think I would be better able to, if we have split this into two methods, one for real and one not for real.
src/testers.jl
Outdated
vΩ, Ω_from_vec = to_vec(Ω) | ||
# orthonormal cotangent vectors | ||
vΩ_basis = Diagonal(ones(eltype(vΩ), length(vΩ))) | ||
ΔΩs = [Ω_from_vec(vΩ_basis[:, i]) for i in axes(vΩ_basis, 2)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should move this out into a helper function basis_vectors
@@ -112,61 +112,55 @@ at input point `z` to confirm that there are correct `frule` and `rrule`s provid | |||
|
|||
`fkwargs` are passed to `f` as keyword arguments. | |||
All keyword arguments except for `fdm` and `fkwargs` are passed to `isapprox`. | |||
|
|||
To use this tester for a scalar type `MyNumber <: AbstractNumber`, | |||
`FiniteDifferences.to_vec(::MyNumber)` must be implemented. | |||
""" | |||
function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we simplify this code by defining a seperate method for:
function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...) | |
function test_scalar(f, z::Real; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...) |
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would simplify the frule
test because we wouldn't need the basis, but if the output is non-real we still need the basis on the output for the rrule
test. Adding a separate method would require us to maintain that code in two places.
function FiniteDifferences.to_vec(q::Quaternion) | ||
function Quaternion_from_vec(q_vec) | ||
return Quaternion(q_vec[1], q_vec[2], q_vec[3], q_vec[4]) | ||
end | ||
return [q.s, q.v1, q.v2, q.v3], Quaternion_from_vec | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should move this to be defined in the package itself.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean define this in Quaternions or FiniteDifferences?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or ChainRulesTestUtils?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FiniteDifferences
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know if it makes sense to make Quaternions an optional dependency for FiniteDifferences. Since I am only defining this for the purpose of testing, I'm comfortable with being type-piratical but just in the test suite where it can't pollute the methods table for other users. Thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @oxinabox what do you think?
return quatfun(q), quatfun_pullback | ||
end | ||
|
||
q = quatrand() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we define rand_tangent(:: Quaternion)
in this package also?
@willtebbutt do you have plans around further advancing rand_tangent
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do not currently -- not sure that there's much to do beyond integrating it in with ChainRulesTestUtils
in some way or another and continuing to add new methods where necessary.
|
||
Get a set of basis (co)tangent vectors for `x`. | ||
|
||
This function assumes that the (co)tangent vectors are of the same type as `x` and requires |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It'd be nice to give basis vectors of the same type as the result of rand_tangent
instead, but I'm not certain how to do that.
Is this ready for rereview? |
Not yet, I'll try to finish it up this week. |
test_scalar
currently is veryReal
andComplex
focused. This PR generalizestest_scalar
to work the same for any scalar for whichFiniteDifferences.to_vec
(and a handful of base functions) are implemented.We test it with
Quaternions.Quaternion
. We'd ideally test against a more minimal number, but it turns out one needs to implement quite a few base methods to get a new number to work correctly.