Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BugFix: adjoint metric tensor with jax #5271

Merged
merged 5 commits into from
Feb 28, 2024
Merged

Conversation

astralcai
Copy link
Contributor

Context:
The adjoint_metric_tensor transform does not work with jax variables because jax variables are not considered trainable parameters until they become tracers.

Description of the Change:

  1. Add a custom expand transform to adjoint_metric_tensor that expands trainable parameters based on argnums.
  2. Add an optional use_argnum_in_expand argument to transform programs to determine whether or not to perform jax_argnums_to_tape_trainable on the parameters.

Benefits:
BugFix

Possible Drawbacks:
Adding yet another keyword argument to transform may make code look messy.

Related GitHub Issues:
#5197

Related Shortcut Stories:
[sc-56734]

Copy link
Contributor

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@trbromley trbromley added this to the v0.35 milestone Feb 27, 2024
@astralcai astralcai requested review from rmoyard and a team February 27, 2024 20:34
@astralcai astralcai marked this pull request as ready for review February 27, 2024 20:34
Copy link

codecov bot commented Feb 27, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

❗ No coverage uploaded for pull request base (v0.35.0-rc0@c1f3997). Click here to learn what that means.

Additional details and impacted files
@@              Coverage Diff               @@
##             v0.35.0-rc0    #5271   +/-   ##
==============================================
  Coverage               ?   99.65%           
==============================================
  Files                  ?      399           
  Lines                  ?    36617           
  Branches               ?        0           
==============================================
  Hits                   ?    36489           
  Misses                 ?      128           
  Partials               ?        0           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@rmoyard rmoyard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks good! I don't think it is necessary to argnums as arguments though

pennylane/gradients/adjoint_metric_tensor.py Outdated Show resolved Hide resolved
Copy link
Contributor

@rmoyard rmoyard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @astralcai , it looks good to me 💯

Copy link
Contributor

@timmysilv timmysilv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good!

@astralcai astralcai enabled auto-merge (squash) February 28, 2024 21:57
@astralcai astralcai merged commit d4dacf9 into v0.35.0-rc0 Feb 28, 2024
37 checks passed
@astralcai astralcai deleted the adj-metric-tensor branch February 28, 2024 22:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants