Lookaside for torch.ops.higher_order.autograd_function_apply
#1256
base: main
Conversation
thunder/torch/__init__.py (Outdated)
Can we keep this code? I think it does a good job at separation of concerns. The job of "jit_ext" is to make things Thunder friendly.
Then it feels like we should have a better utility for making a callable friendly to Thunder before doing what this PR is trying to do.
You have already done the complicated part of making callables friendly to Thunder. What I mean is that it's now possible to remove a bit of the complexity of registering a grad rule by reusing the functions that were removed:
diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index d98836a2..9d390e7d 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -864,9 +864,18 @@ def _general_jit_torch_ops_higher_order_autograd_function_apply(fwd, bwd, *fwd_a
bwd_trace._siginfo = SigInfo.from_name_and_args(f"bwd_{sym_id}", saved_values + grads)
@wraps(bwd_trace.python_callable())
- def bwd_impl_callable(*args, **kwargs):
+ def bwd_impl_callable(ctx, *args, **kwargs):
return thunder.core.trace_interpreter.interpret_trace(bwd_trace, *args, **kwargs)
+ @wraps(augmented_fwd_trace.python_callable())
+ def fwd_impl_callable(ctx, *args, **kwargs):
+ return thunder.core.trace_interpreter.interpret_trace(augmented_fwd_trace, *args, **kwargs)
+
+ from thunder.torch import autograd_function_apply
+ wrapped_fwd = wrap_const(fwd_impl_callable)
+ wrapped_bwd = wrap_const(bwd_impl_callable)
+ return interpreter_needs_wrap(autograd_function_apply)(wrapped_fwd, wrapped_bwd, *fwd_args, **fwd_kwargs)
+
@wraps(core_of_fwd)
def grad_transform(*args, **kwargs):
from thunder.core.transforms import get_grad, put_grads
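For readers outside this thread, here is a minimal standalone sketch of the fwd/bwd calling convention the diff above relies on. The function names are made up for illustration, and the argument order of bwd (saved values first, then output grads) is an assumption that mirrors the `saved_values + grads` siginfo built in the diff; it is not code from this PR.

```python
# Hypothetical illustration only: a fwd/bwd pair shaped like the callables the
# diff above hands to autograd_function_apply. The ctx argument is a placeholder
# kept solely to match the expected signature.
import torch

def fwd_impl(ctx, x):
    # Forward: compute the output and collect what bwd will need.
    out = torch.sin(x)
    saved_values = (x,)
    return out, saved_values

def bwd_impl(ctx, x, grad_out):
    # Backward: recompute cos(x) from the saved input and chain the incoming grad.
    return grad_out * torch.cos(x)

# Quick numeric check of the pair (no tracing involved):
x = torch.randn(4)
out, saved = fwd_impl(None, x)
grad_x = bwd_impl(None, *saved, torch.ones_like(out))
```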
I want the lookaside's scope to be limited to the preprocessing of PyTorch code. If the removed code is reused in the updated lookaside, we'll achieve that.
thunder/core/jit_ext.py (Outdated)
from thunder.core import utils
from thunder.core.baseutils import sequencify
from thunder.core.pytree import tree_flatten, tree_map
from thunder.core.transforms import VJPDual, augmented_forward_impls, backward_impls
These imports are unused.
thunder/core/jit_ext.py (Outdated)
if p in producer_map:
    prod_bsym = producer_map[p]
    tensor_to_prod_bsym[variableify(p)] = prod_bsym
prod_bsym_to_tensor = {v: unvariableify(k) for k, v in tensor_to_prod_bsym.items()}
This variable is unused.
thunder/core/jit_ext.py (Outdated)
return aug_fwd_trace
aug_fwd_result = aug_fwd_trace.output
output, saved_values = unwrap(aug_fwd_result)
wrapped_output = wrap(output, provenance=aug_fwd_provenance)
This variable is unused.
thunder/core/jit_ext.py (Outdated)
@wraps(core_of_fwd)
def grad_transform(*args, **kwargs):
This part can be removed if we reuse the existing functions and use interpreter_needs_wrap(autograd_function_apply) here.
thunder/torch/__init__.py (Outdated)
Let's not remove this code; instead, let's use it inside the lookaside.
What does this PR do?
As per #1248, the support of torch.ops.higher_order.autograd_function_apply would be a bit more flexible by tracing into both fwd and bwd.
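For context on where this higher-order op comes from: when a user-defined torch.autograd.Function is traced under torch.compile, Dynamo typically represents the .apply(...) call as torch.ops.higher_order.autograd_function_apply with separate fwd and bwd callables, which is what the lookaside here targets. Below is a minimal, hypothetical example of such a Function; the class and function names are made up and are not taken from this PR or its tests.

```python
# A small custom autograd.Function whose forward/backward pair is the kind of
# fwd/bwd that autograd_function_apply receives when the function is traced.
import torch

class Sin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sin(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * torch.cos(x)

def f(x):
    return Sin.apply(x)

x = torch.randn(4, requires_grad=True)
f(x).sum().backward()  # eager check that the fwd/bwd pair is consistent
```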