torch.compile some tutorials #2984
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/2984
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure
As of commit 344861d with merge base f1c0b8a: NEW FAILURE - the following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
LGTM from a technical perspective, although is it fine to import profile_utils.py in a tutorial? @svekars
@@ -159,10 +162,16 @@ def compute_loss(params, buffers, sample, target):

ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))

@torch.compile
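For context, a minimal, self-contained sketch of the compiled per-sample-gradient pipeline this hunk touches, following the torch.func pattern the tutorial uses (the Linear model and the tensor shapes here are illustrative assumptions, not the tutorial's actual values):

```python
import torch
from torch.func import functional_call, grad, vmap

# Hypothetical stand-in for the tutorial's model; sizes are made up.
model = torch.nn.Linear(10, 3)
params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

def compute_loss(params, buffers, sample, target):
    # Run the module statelessly with the given params/buffers on one sample.
    prediction = functional_call(model, (params, buffers), (sample.unsqueeze(0),))
    return torch.nn.functional.cross_entropy(prediction, target.unsqueeze(0))

ft_compute_grad = grad(compute_loss)
# vmap over the batch dims of sample/target while sharing params/buffers,
# then compile the whole vmapped function, as the diff above does.
ft_compute_sample_grad = torch.compile(vmap(ft_compute_grad, in_dims=(None, None, 0, 0)))

samples = torch.randn(64, 10)
targets = torch.randint(0, 3, (64,))
per_sample_grads = ft_compute_sample_grad(params, buffers, samples, targets)
```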
Move this function to below the text "Finally, let's use..."?
I can remove it, or hide the import behind an environment flag.
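If it helps, a sketch of the environment-flag approach mentioned here (the PROFILE variable name and the profile_utils helper name are assumptions for illustration, not the PR's actual code):

```python
import os

# Hypothetical gate: only pull in the profiling helpers when the user
# opts in, so the tutorial stays runnable without profile_utils.py.
if os.environ.get("PROFILE", "0") == "1":
    from profile_utils import profile_function  # assumed helper name
else:
    def profile_function(fn, *args, **kwargs):
        # No-op fallback: just call the function without profiling.
        return fn(*args, **kwargs)
```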
Force-pushed from 54771dc to d8e6e12
How will this run in Google Colab?
It doesn't have to. I've removed the calls to the profiling function; it was there just to check that we actually get a speedup from compiling the model.
I see you removed
No, I've removed it as well.
from torch._dynamo import config
# Make TorchDynamo trace into built-in nn.Module instances.
config.inline_inbuilt_nn_modules = True
Remove. Also, don't merge this PR until the next release (2.5).
@@ -125,7 +127,11 @@ def fmodel(params, buffers, x):

from torch import vmap

predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
@torch.compile
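As a point of reference, here is a minimal sketch of the pattern this hunk is compiling, following the ensembling tutorial's stack_module_state approach (the ensemble size, model, and batch shapes are illustrative assumptions):

```python
import copy
import torch
from torch import vmap
from torch.func import stack_module_state, functional_call

# Hypothetical ensemble of small models; the real tutorial uses an MLP.
models = [torch.nn.Linear(10, 2) for _ in range(5)]
params, buffers = stack_module_state(models)

# A "stateless" skeleton on the meta device; weights are supplied per call.
base_model = copy.deepcopy(models[0]).to("meta")

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

# vmap over the stacked model dimension, then compile the result.
compiled_fmodel = torch.compile(vmap(fmodel))

minibatches = torch.randn(5, 64, 10)  # one minibatch per ensemble member
predictions = compiled_fmodel(params, buffers, minibatches)
```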
Main comment is: all of these tutorials should have a separate section at the end that says "let's try to use torch.compile, and here are the speedups".
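A rough sketch of what such a closing section could measure, using torch.utils.benchmark to compare eager against compiled (the model and input sizes are placeholders, not numbers from this PR):

```python
import torch
import torch.utils.benchmark as benchmark

model = torch.nn.Linear(1024, 1024)  # placeholder workload
x = torch.randn(64, 1024)

compiled_model = torch.compile(model)
compiled_model(x)  # warm up: the first call pays the compilation cost

def time_fn(fn):
    # Mean seconds per call over 100 timed runs.
    return benchmark.Timer(stmt="fn(x)", globals={"fn": fn, "x": x}).timeit(100).mean

eager_s, compiled_s = time_fn(model), time_fn(compiled_model)
print(f"eager: {eager_s * 1e6:.1f} us | compiled: {compiled_s * 1e6:.1f} us "
      f"| speedup: {eager_s / compiled_s:.2f}x")
```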
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as stale.
Description
This PR attempts to compile three tutorials:
- neural_tangent_kernels
- ensembling
- per_sample_grads
To compile with fullgraph=True, one needs PyTorch with the changes from pytorch/pytorch#129091.
Performance gain
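For readers unfamiliar with the flag: fullgraph=True asks torch.compile to capture the function as a single graph and to raise an error instead of falling back when a graph break occurs. A trivial sketch:

```python
import torch

def fn(x):
    return torch.sin(x) + torch.cos(x)

# With fullgraph=True, any graph break raises an error rather than
# silently splitting the program into multiple compiled graphs.
compiled_fn = torch.compile(fn, fullgraph=True)
out = compiled_fn(torch.randn(8))
```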
This is not a scientific benchmark, as the inputs to the models are small, so the timings can suffer from noise. But we did see some gains on ensembling.
cc @williamwen42 @msaroufim @anijain2305