-
Notifications
You must be signed in to change notification settings - Fork 617
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
BugFix: adjoint metric tensor with jax #5271
Conversation
Hello. You may have forgotten to update the changelog!
|
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## v0.35.0-rc0 #5271 +/- ##
==============================================
Coverage ? 99.65%
==============================================
Files ? 399
Lines ? 36617
Branches ? 0
==============================================
Hits ? 36489
Misses ? 128
Partials ? 0 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks good! I don't think it is necessary to argnums as arguments though
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @astralcai , it looks good to me 💯
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks good!
Context:
The
adjoint_metric_tensor
transform does not work with jax variables because jax variables are not considered trainable parameters until they become tracers.Description of the Change:
adjoint_metric_tensor
that expands trainable parameters based on argnums.use_argnum_in_expand
argument to transform programs to determine whether or not to performjax_argnums_to_tape_trainable
on the parameters.Benefits:
BugFix
Possible Drawbacks:
Adding yet another keyword argument to
transform
may make code look messy.Related GitHub Issues:
#5197
Related Shortcut Stories:
[sc-56734]