
Lookaside for torch.ops.higher_order.autograd_function_apply #1256

Open · wants to merge 6 commits into base: main
Conversation

@crcrpar (Collaborator) commented Oct 3, 2024

What does this PR do?

As per #1248, support for torch.ops.higher_order.autograd_function_apply would become more flexible by tracing into both fwd and bwd.
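To illustrate the idea, here is a minimal pure-Python sketch of what "tracing into both fwd and bwd" means. This is not Thunder's actual API; all names (traced, autograd_function_apply_lookaside, the ctx dict) are hypothetical stand-ins for the real machinery.

```python
# Toy "interpreter" trace: instead of treating apply(fwd, bwd, *args)
# as an opaque call, we step into both fwd and bwd and record the
# primitive operations they perform.

trace = []  # ops recorded while interpreting

def traced(op_name, fn):
    """Wrap a primitive so calls to it are recorded in the trace."""
    def wrapper(*args):
        trace.append((op_name, args))
        return fn(*args)
    return wrapper

mul = traced("mul", lambda a, b: a * b)

def fwd(ctx, x):
    # ctx carries saved values over to bwd, mirroring autograd.Function
    ctx["saved"] = x
    return mul(x, x)  # y = x * x

def bwd(ctx, grad_out):
    x = ctx["saved"]
    return mul(mul(2.0, x), grad_out)  # dy/dx = 2 * x * grad_out

def autograd_function_apply_lookaside(fwd, bwd, *args):
    """Trace into fwd eagerly, then into bwd with a unit gradient."""
    ctx = {}
    out = fwd(ctx, *args)
    grad_in = bwd(ctx, 1.0)
    return out, grad_in

out, grad = autograd_function_apply_lookaside(fwd, bwd, 3.0)
print(out, grad)  # 9.0 6.0
```

Because both callables run under the recording wrapper, all three mul calls land in the trace, which is what allows a downstream transform to see and rewrite the backward computation.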

thunder/core/jit_ext.py — outdated review thread (resolved)
@crcrpar crcrpar force-pushed the crpa/lookaside_autograd-function-apply branch from b4647ed to 71db6cd Compare October 3, 2024 14:42
Collaborator

Can we keep this code? I think it does a good job at separation of concerns. The job of "jit_ext" is to make things Thunder friendly.

Collaborator (Author)

Then it feels like we should have a better utility for making a callable Thunder-friendly before doing what this PR is trying to do.

Collaborator

You have already done the complicated part of making callables friendly to Thunder. What I mean is that it's now possible to remove a bit of complexity of registering a grad rule by using the functions that were removed:

diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index d98836a2..9d390e7d 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -864,9 +864,18 @@ def _general_jit_torch_ops_higher_order_autograd_function_apply(fwd, bwd, *fwd_a
     bwd_trace._siginfo = SigInfo.from_name_and_args(f"bwd_{sym_id}", saved_values + grads)
 
     @wraps(bwd_trace.python_callable())
-    def bwd_impl_callable(*args, **kwargs):
+    def bwd_impl_callable(ctx, *args, **kwargs):
         return thunder.core.trace_interpreter.interpret_trace(bwd_trace, *args, **kwargs)
 
+    @wraps(augmented_fwd_trace.python_callable())
+    def fwd_impl_callable(ctx, *args, **kwargs):
+        return thunder.core.trace_interpreter.interpret_trace(augmented_fwd_trace, *args, **kwargs)
+
+    from thunder.torch import autograd_function_apply
+    wrapped_fwd = wrap_const(fwd_impl_callable)
+    wrapped_bwd = wrap_const(bwd_impl_callable)
+    return interpreter_needs_wrap(autograd_function_apply)(wrapped_fwd, wrapped_bwd, *fwd_args, **fwd_kwargs)
+
     @wraps(core_of_fwd)
     def grad_transform(*args, **kwargs):
         from thunder.core.transforms import get_grad, put_grads

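The pattern in the diff above — turning recorded traces back into plain callables via interpret_trace and handing them to a single registration point — can be sketched in pure Python. The names below (interpret_trace, make_callable, the trace encoding) are simplified stand-ins, not Thunder's actual signatures.

```python
def interpret_trace(trace_ops, *args):
    """Replay a recorded list of (fn, arg_indices) ops on fresh inputs.

    The environment starts with the inputs; each op appends its result,
    and the final value is the trace's output.
    """
    env = list(args)
    for fn, idxs in trace_ops:
        env.append(fn(*(env[i] for i in idxs)))
    return env[-1]

# A recorded "trace" for f(x, y) = (x + y) * x
fwd_trace = [
    (lambda a, b: a + b, (0, 1)),  # t2 = x + y
    (lambda a, b: a * b, (2, 0)),  # t3 = t2 * x
]

def make_callable(trace_ops):
    """Analogue of wrapping a trace's python_callable() around
    interpret_trace; ctx mimics the autograd_function_apply signature."""
    def impl(ctx, *args):
        return interpret_trace(trace_ops, *args)
    return impl

fwd_impl = make_callable(fwd_trace)
print(fwd_impl({}, 3.0, 4.0))  # (3 + 4) * 3 = 21.0
```

The point of the suggestion is that once fwd and bwd exist as such trace-backed callables, a generic registration path (interpreter_needs_wrap(autograd_function_apply) in the diff) can consume them directly, and the hand-written grad_transform below becomes unnecessary.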
thunder/core/jit_ext.py — outdated review thread (resolved)
crcrpar added a commit that referenced this pull request Nov 13, 2024
Signed-off-by: Masaki Kozuki <[email protected]>
@crcrpar crcrpar force-pushed the crpa/lookaside_autograd-function-apply branch from 627845d to f349308 Compare November 13, 2024 11:04
crcrpar added a commit that referenced this pull request Nov 13, 2024
Signed-off-by: Masaki Kozuki <[email protected]>
@crcrpar crcrpar force-pushed the crpa/lookaside_autograd-function-apply branch 2 times, most recently from 9c51ae2 to 94c3409 Compare November 13, 2024 11:29
crcrpar added a commit that referenced this pull request Nov 13, 2024
Signed-off-by: Masaki Kozuki <[email protected]>
@crcrpar crcrpar force-pushed the crpa/lookaside_autograd-function-apply branch from c5621cd to dd702f5 Compare November 13, 2024 12:24
crcrpar added a commit that referenced this pull request Nov 14, 2024
Signed-off-by: Masaki Kozuki <[email protected]>
@crcrpar crcrpar force-pushed the crpa/lookaside_autograd-function-apply branch from dd702f5 to 1b85a21 Compare November 14, 2024 13:04
@IvanYashchuk (Collaborator) left a comment

I want the lookasides' scope to be limited to the preprocessing of PyTorch code only. If the removed code is reused in the updated lookaside, we'll achieve that.

from thunder.core import utils
from thunder.core.baseutils import sequencify
from thunder.core.pytree import tree_flatten, tree_map
from thunder.core.transforms import VJPDual, augmented_forward_impls, backward_impls
Collaborator

These imports are unused.

if p in producer_map:
prod_bsym = producer_map[p]
tensor_to_prod_bsym[variableify(p)] = prod_bsym
prod_bsym_to_tensor = {v: unvariableify(k) for k, v in tensor_to_prod_bsym.items()}
Collaborator

This variable is unused.

return aug_fwd_trace
aug_fwd_result = aug_fwd_trace.output
output, saved_values = unwrap(aug_fwd_result)
wrapped_output = wrap(output, provenance=aug_fwd_provenance)
Collaborator

This variable is unused.


Comment on lines 870 to 871
@wraps(core_of_fwd)
def grad_transform(*args, **kwargs):
Collaborator

This part can be removed if we reuse existing functions and use interpreter_needs_wrap(autograd_function_apply) here.

Collaborator

Let's not remove this code and use it inside the lookaside.

Signed-off-by: Masaki Kozuki <[email protected]>
@crcrpar crcrpar force-pushed the crpa/lookaside_autograd-function-apply branch from 1b85a21 to 7729af1 Compare November 14, 2024 14:40