
@torch.compile some tutorials #2984

Open
wants to merge 6 commits into base: main

Conversation

@guilhermeleobas guilhermeleobas commented Jul 25, 2024

Description

This PR attempts to compile three tutorials:

  • neural_tangent_kernels
  • ensembling
  • per_sample_grads

To compile with fullgraph=True, one needs a PyTorch build that includes the changes from pytorch/pytorch#129091.
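
For reference, the pattern being applied is roughly the following. This is a minimal sketch, not the tutorial's exact code: the model, shapes, and the empirical_ntk name below are placeholders, and fullgraph=True assumes a PyTorch build with the linked change.

import torch
from torch.func import functional_call, jacrev, vmap

# Placeholder model; the tutorial builds a small CNN instead.
net = torch.nn.Linear(4, 3)
params = dict(net.named_parameters())

def fnet_single(params, x):
    # Run the model functionally on a single sample.
    return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)

@torch.compile(fullgraph=True)
def empirical_ntk(params, x1, x2):
    # Per-sample Jacobians of the outputs w.r.t. the parameters.
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    j1 = [j.flatten(2) for j in jac1.values()]
    j2 = [j.flatten(2) for j in jac2.values()]
    # Contract over the flattened parameter dimension.
    return torch.stack(
        [torch.einsum("Naf,Mbf->NMab", a, b) for a, b in zip(j1, j2)]
    ).sum(0)

x1, x2 = torch.randn(5, 4), torch.randn(7, 4)
print(empirical_ntk(params, x1, x2).shape)  # torch.Size([5, 7, 3, 3])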

Performance gain

This is not a rigorous benchmark: the inputs to the models are small, so the timings are noisy. Still, there are clear gains on ensembling.

neural_tangent_kernels
[ empirical_ntk_jacobian_contraction ]
                      |   cuda
1 threads: -------------------
      @torch.compile  |   18.9
      make_fx         |  149.0
      eager           |   19.1
2 threads: -------------------
      @torch.compile  |   19.1
      make_fx         |  146.6
      eager           |   19.2
4 threads: -------------------
      @torch.compile  |   18.9
      make_fx         |  148.2
      eager           |   19.2
8 threads: -------------------
      @torch.compile  |   19.1
      make_fx         |  148.6
      eager           |   19.2
16 threads: ------------------
      @torch.compile  |   19.2
      make_fx         |  149.1
      eager           |   19.1

Times are in milliseconds (ms).

[ empirical_ntk_jacobian_contraction ]
                      |   cuda
1 threads: -------------------
      @torch.compile  |   17.6
      make_fx         |  145.1
      eager           |   17.9
2 threads: -------------------
      @torch.compile  |   17.8
      make_fx         |  142.7
      eager           |   18.0
4 threads: -------------------
      @torch.compile  |   17.7
      make_fx         |  142.3
      eager           |   18.1
8 threads: -------------------
      @torch.compile  |   17.8
      make_fx         |  144.6
      eager           |   17.8
16 threads: ------------------
      @torch.compile  |   17.7
      make_fx         |  144.8
      eager           |   18.0

Times are in milliseconds (ms).

[-- empirical_ntk_ntk_vps ---]
                      |   cuda
1 threads: -------------------
      @torch.compile  |   62.3
      make_fx         |  123.0
      eager           |   62.9
2 threads: -------------------
      @torch.compile  |   62.4
      make_fx         |  123.6
      eager           |   63.1
4 threads: -------------------
      @torch.compile  |   62.5
      make_fx         |  122.9
      eager           |   63.1
8 threads: -------------------
      @torch.compile  |   62.5
      make_fx         |  123.8
      eager           |   63.2
16 threads: ------------------
      @torch.compile  |   62.5
      make_fx         |  123.8
      eager           |   63.1

Times are in milliseconds (ms).

ensembling
[---- compute_predictions1 ----]
                      |    cuda
1 threads: ---------------------
      @torch.compile  |    149.9
      make_fx         |  14963.3
      eager           |    303.4
2 threads: ---------------------
      @torch.compile  |    151.4
      make_fx         |  14664.8
      eager           |    326.5
4 threads: ---------------------
      @torch.compile  |    152.4
      make_fx         |  14680.3
      eager           |    327.3
8 threads: ---------------------
      @torch.compile  |    164.1
      make_fx         |  14694.9
      eager           |    332.6
16 threads: --------------------
      @torch.compile  |    151.9
      make_fx         |  14633.4
      eager           |    317.7

Times are in microseconds (us).

[---- compute_predictions2 ----]
                      |    cuda
1 threads: ---------------------
      @torch.compile  |    147.0
      make_fx         |  14995.9
      eager           |    299.6
2 threads: ---------------------
      @torch.compile  |    149.0
      make_fx         |  14984.4
      eager           |    293.9
4 threads: ---------------------
      @torch.compile  |    147.1
      make_fx         |  15036.3
      eager           |    327.0
8 threads: ---------------------
      @torch.compile  |    150.4
      make_fx         |  14985.6
      eager           |    330.8
16 threads: --------------------
      @torch.compile  |    151.3
      make_fx         |  15123.6
      eager           |    301.2

Times are in microseconds (us).

per_sample_grads
[--- vmap_ft_compute_grad --]
                      |  cuda
1 threads: ------------------
      @torch.compile  |   6.5
      make_fx         |  48.1
      eager           |   7.1
2 threads: ------------------
      @torch.compile  |   6.0
      make_fx         |  47.1
      eager           |   6.8
4 threads: ------------------
      @torch.compile  |   6.0
      make_fx         |  47.1
      eager           |   6.4
8 threads: ------------------
      @torch.compile  |   6.0
      make_fx         |  47.1
      eager           |   6.8
16 threads: -----------------
      @torch.compile  |   6.0
      make_fx         |  48.1
      eager           |   6.5

Times are in milliseconds (ms).
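
The tables above look like torch.utils.benchmark output. Below is a hedged sketch of how such numbers could be collected; the harness actually used for this PR is not shown here, and the variant callables and their arguments are placeholders.

import torch
import torch.utils.benchmark as benchmark

def time_variants(label, variants, thread_counts=(1, 2, 4, 8, 16)):
    # `variants` maps a row name (e.g. "@torch.compile", "eager") to a
    # zero-argument callable; each one is timed at every thread count.
    results = []
    for num_threads in thread_counts:
        for name, fn in variants.items():
            results.append(
                benchmark.Timer(
                    stmt="fn()",
                    globals={"fn": fn},
                    num_threads=num_threads,
                    label=label,
                    sub_label=name,
                    description="cuda",
                ).blocked_autorange(min_run_time=1)
            )
    return results

# Example usage (placeholder callables):
# rows = time_variants(
#     "empirical_ntk_jacobian_contraction",
#     {"@torch.compile": lambda: compiled_fn(params, x1, x2),
#      "eager": lambda: fn(params, x1, x2)},
# )
# benchmark.Compare(rows).print()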

cc @williamwen42 @msaroufim @anijain2305

pytorch-bot bot commented Jul 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/2984

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 344861d with merge base f1c0b8a:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@guilhermeleobas guilhermeleobas marked this pull request as ready for review July 25, 2024 20:45
@svekars svekars added the torch.compile Torch compile and other relevant tutorials label Jul 29, 2024
Member

@williamwen42 williamwen42 left a comment

LGTM from a technical perspective, although is it fine to import profile_utils.py in a tutorial? @svekars

@@ -159,10 +162,16 @@ def compute_loss(params, buffers, sample, target):

ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))

@torch.compile
Member

Move this function to below the text "Finally, let's use..."?
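
For context, the per-sample-gradients pattern this hunk decorates looks roughly like the following. This is a sketch with a placeholder linear model, not the tutorial's exact code.

import torch
from torch.func import functional_call, grad, vmap

model = torch.nn.Linear(10, 2)
params = dict(model.named_parameters())
buffers = dict(model.named_buffers())

def compute_loss(params, buffers, sample, target):
    # Treat a single sample as a batch of one.
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    predictions = functional_call(model, (params, buffers), (batch,))
    return torch.nn.functional.cross_entropy(predictions, targets)

ft_compute_grad = grad(compute_loss)
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))

@torch.compile
def compute_sample_grads(params, buffers, data, targets):
    return ft_compute_sample_grad(params, buffers, data, targets)

data, targets = torch.randn(64, 10), torch.randint(0, 2, (64,))
per_sample_grads = compute_sample_grads(params, buffers, data, targets)
print(per_sample_grads["weight"].shape)  # torch.Size([64, 2, 10])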

@guilhermeleobas
Author

LGTM from a technical perspective, although is it fine to import profile_utils.py in a tutorial? @svekars

I can remove it, or hide the import behind an environment flag.

@svekars
Contributor

svekars commented Aug 2, 2024

How will this run in Google Colab?

@guilhermeleobas
Author

How will this run in Google Colab?

It doesn't have to. I've removed the calls to the profiling function; it was there just to check whether compiling the model yields any speedup.

@williamwen42
Member

I see you removed profile_utils.py from the tutorial - do we still need to add it in this PR then?

@guilhermeleobas
Author

I see you removed profile_utils.py from the tutorial - do we still need to add it in this PR then?

No, I've removed it as well.

@zou3519 zou3519 self-requested a review August 5, 2024 15:40
Comment on lines +28 to +29
from torch._dynamo import config
config.inline_inbuilt_nn_modules = 1
Contributor

Remove. Also, don't merge this PR until the next release (2.5).

@@ -125,7 +127,11 @@ def fmodel(params, buffers, x):

from torch import vmap

predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
@torch.compile
Contributor

Main comment is: all of these tutorials should have a separate section at the end that says "let's try to use torch.compile, and here are the speedups".
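
For context, the ensembling pattern this hunk targets looks roughly like the following. This is a sketch with placeholder models and shapes, run on CPU for simplicity (the benchmarks above are on CUDA).

import copy
import torch
from torch import vmap
from torch.func import functional_call, stack_module_state

# A small ensemble of identically structured models (placeholders).
models = [torch.nn.Linear(8, 3) for _ in range(10)]
params, buffers = stack_module_state(models)

# A stateless "skeleton" module, used only for its structure.
base_model = copy.deepcopy(models[0]).to("meta")

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

@torch.compile
def compute_predictions(params, buffers, minibatches):
    # vmap over the stacked model dimension and the per-model minibatches.
    return vmap(fmodel)(params, buffers, minibatches)

minibatches = torch.randn(10, 64, 8)
print(compute_predictions(params, buffers, minibatches).shape)  # (10, 64, 3)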

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the stale Stale PRs label Oct 12, 2024
@svekars svekars added stale Stale PRs and removed stale Stale PRs labels Oct 23, 2024