[Tripy] Eliminate need for `skip_num_stack_entries` argument in `convert_to_tensors` #333
tripy/tripy/frontend/utils.py
Outdated
# This check supports cases like slice_helper, where the decorated function
# is used *inside* the override and hence the wrapped call would come *before*
# the call from the registry.
REGISTRY_STACK_DEPTH = WRAPPER_STACK_DEPTH
Do we still need to find `WRAPPER_STACK_DEPTH` first?
It would only be needed in our test cases, unless we want to add those to the function registry.
I would slightly prefer adding hacks into the tests rather than into Tripy proper.
I'll change the tests, in that case.
Hm, it seems my assumption that it's for overrides is incorrect. `convert_to_tensors` is also used on `full`, `full_like`, `iota`, `quantize`, `dequantize`, `expand`, `reshape`, and `resize`, though these appear in a function registry via `export.public_api`, so the check still succeeds.
The code actually proves to be necessary in the following case: if `tensor_from_shape_like` is used by the decorator, it may call `concatenate`, which is in the function registry due to the use of `export.public_api`. Without the earlier loop, the inner call to `concatenate` is selected first and its stack entry is chosen. This arises in various test cases that use `tensor_initializers`, e.g., `test_unsqueeze`. I will see if that case can be excluded cleanly.
Edit: That can be rectified by excluding `tripy.frontend.utils` from the stack entry matching as well. This seems like a bit of a hack now because, fundamentally, what we want to find is the call to the decorated function and, from there, the first call to user code, which may require skipping to the next registered function (i.e., a front-end API used by the user). This becomes muddled if multiple registered functions are used during a single call, as in the case of `concatenate`. Compare the most recent commit.
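To make the failure mode above concrete, here is a minimal sketch of the matching logic on a toy call stack. The `Frame` class, the `first_user_frame` helper, and the module names are all hypothetical stand-ins for illustration; this is not Tripy's actual implementation.

```python
from dataclasses import dataclass
from typing import List, Optional, Set


@dataclass
class Frame:
    module: str    # module the frame's code lives in
    function: str  # function executing in that frame


def first_user_frame(stack: List[Frame], registry_module: str, excluded: Set[str]) -> Optional[int]:
    """Return the index of the first frame found after a registry frame that is
    neither the registry itself nor in an excluded module."""
    seen_registry = False
    for idx, frame in enumerate(stack):
        if frame.module == registry_module:
            seen_registry = True
        elif seen_registry and frame.module not in excluded:
            return idx
    return None


# Toy stack, innermost call first. All names are illustrative assumptions.
REGISTRY = "tripy.function_registry"
stack = [
    Frame("tripy.frontend.ops.concatenate", "concatenate"),
    Frame(REGISTRY, "wrapper"),                # registry entry for the inner concatenate
    Frame("tripy.frontend.utils", "wrapper"),  # the decorator that triggered concatenate
    Frame(REGISTRY, "wrapper"),                # registry entry for the API the user called
    Frame("user_script", "main"),
]

without_exclusion = first_user_frame(stack, REGISTRY, excluded=set())
with_exclusion = first_user_frame(stack, REGISTRY, excluded={"tripy.frontend.utils"})
```

In this toy stack, the search without the exclusion stops at the decorator frame that called `concatenate`, while excluding `tripy.frontend.utils` lets it continue to the user frame. It also shows why this is fragile: the result depends on which modules happen to be on the exclusion list.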
I think skipping stack entries is fundamentally broken. Unless the function signature is the same, it's difficult to reliably determine column information. Could we just update the slice implementation to add column information without the help of this decorator and then disallow skipping stack entries here?
Okay, I'll keep the old behavior of checking for the call to the decorator.
Updated, probably fewer hacks needed overall
tripy/tripy/frontend/utils.py
Outdated
@@ -84,11 +85,23 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_n
# Find the first caller of this function that is NOT the function registry.
# Also save the last dispatch target we see.
# Start from the registry. It will always be present except for tests,
One more place where it could be missing is methods of classes. If that becomes an issue, then maybe we could update the function registry to automatically discover and wrap methods too.
Would there be an example of a class method for which we would use the decorator? I thought we were only using it for overrides of magic methods.
Since it also handles conversion to shapes now, there could be cases in the future where we want to use it on a method. Could you add a note in the docstring of the decorator mentioning the assumptions it makes about the call stack?
Hopefully it won't be necessary if we remove these hacks per the above discussion; I will note whatever assumptions ultimately remain.
…nore registered functions that come from the decorator
7e8a1e3
to
6db25c7
…old logic, only overriding the special case in slice
6db25c7
to
737b7f7
# Look for the call to __getitem__. We need to go one stack frame beyond to get to the *user* call of __getitem__.
frame_index = -1
for idx, source_info in enumerate(arg.stack_info):
    if source_info._dispatch_target == "__getitem__":
        frame_index = idx + 1
        break
I think we can lift this out of the loop, right? The frame for `__getitem__` should be the same for all the arguments to slice.
Oh I had overlooked that. That's correct.
# Look for the stack frame index to __getitem__. We need to go one stack frame beyond to get to the *user* call of __getitem__.
# It will be the same for all the slice params
frame_index = -1
if slice_params:
What's the case where there are no `slice_params`?
I'll change it to an assert. I don't think it's possible for it to be empty but it's best to avoid an indexing error in case it somehow happens.
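As a rough illustration of the two suggestions above (computing the frame index once, and asserting rather than silently skipping), here is a self-contained sketch. `SourceInfo` is a simplified stand-in, and `user_frame_index` and `frame_index_for_slice` are hypothetical helper names; only the `_dispatch_target` attribute and the `idx + 1` logic come from the diff under review.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SourceInfo:
    function: str
    _dispatch_target: Optional[str] = None


def user_frame_index(stack_info: List[SourceInfo]) -> int:
    """Index of the frame just beyond the one dispatched to __getitem__, or -1 if absent."""
    for idx, source_info in enumerate(stack_info):
        if source_info._dispatch_target == "__getitem__":
            return idx + 1
    return -1


def frame_index_for_slice(slice_param_stacks: List[List[SourceInfo]]) -> int:
    # Every slice param is created during the same __getitem__ call,
    # so looking at the first param's stack is enough.
    assert slice_param_stacks, "expected at least one slice param"
    return user_frame_index(slice_param_stacks[0])


# Toy stack shared by two slice params (names are illustrative).
stack = [
    SourceInfo("slice_helper"),
    SourceInfo("wrapper", _dispatch_target="__getitem__"),
    SourceInfo("user_code"),
]
frame_index = frame_index_for_slice([stack, stack])
```

The assert turns an unexpected empty input into an immediate, descriptive failure instead of an indexing error further down.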
  # the original call to user code.
- @frontend_utils.convert_to_tensors(skip_num_stack_entries=1)
+ @frontend_utils.convert_to_tensors()
  def slice_helper(tensor, *slice_params: TensorLike):
I'm wondering if it would be better to do the conversion to tensors and the stack info in the caller? The reason `skip_num_stack_entries` is unreliable is that the signature of the API can differ from that of the decorated function, which breaks our AST parsing. By still having the helper, it seems like we would have the same problem.
Have you run into a case where that's actually happened? The error messages in the test cases generate correctly.
We could indeed get rid of the decorator altogether and handle everything directly in `__getitem__`. The reason for having the helper is that the decorator normally would not be able to handle `slice` inputs and we would have to preprocess them. In any case, we do need to know which inputs are `slice`s and which ones are `int`s in order to distinguish between a slice of size 1 (returns a list of length 1) and a single index (returns a single entry).
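For reference, the slice-versus-index distinction mentioned above can be seen with plain Python objects. This sketch uses a built-in list and a small hypothetical `Recorder` class, not Tripy tensors.

```python
values = [10, 20, 30]

single_index = values[1]  # an int index selects one element
unit_slice = values[1:2]  # a slice of size 1 still returns a list


class Recorder:
    """Returns whatever key __getitem__ receives, to show what an override sees."""

    def __getitem__(self, key):
        return key


def describe(key):
    """Label each component of a subscript key as a slice or an index."""
    keys = key if isinstance(key, tuple) else (key,)
    return ["slice" if isinstance(k, slice) else "index" for k in keys]


r = Recorder()
kinds = describe(r[1:2, 0])
```

Here `single_index` is `20` and `unit_slice` is `[20]`, and `kinds` is `["slice", "index"]`, so an override of `__getitem__` has to inspect each component before any conversion to tensors discards that information.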
In this case, the code that would be parsed would be the user's subscripted tensor, i.e., the subscript AST node, whose signature would not match that of our override for `__getitem__` regardless. We could pull in the portions of the code from the decorator that convert values into tensors, but it would not be much different in principle from what the decorator is already doing.
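To illustrate what parsing the user's subscript looks like, here is a small sketch using the standard `ast` module. It assumes Python 3.9 or newer, where a multi-dimensional subscript is stored as a `Tuple` node and `Slice` nodes carry column offsets; the source string is made up for the example and this is not the parsing code used in Tripy.

```python
import ast

source = "y = t[1:2, idx]"
tree = ast.parse(source)

# Find the Subscript node for the user's t[...] expression.
subscript = next(node for node in ast.walk(tree) if isinstance(node, ast.Subscript))

# Python 3.9+: a multi-dimensional subscript stores its parts in a Tuple node.
if isinstance(subscript.slice, ast.Tuple):
    parts = subscript.slice.elts
else:
    parts = [subscript.slice]

# Column span of each subscript component in the user's source line.
columns = [(type(p).__name__, p.col_offset, p.end_col_offset) for p in parts]
```

The subscript components map onto the user's source columns directly, without any reference to the parameter list of `__getitem__` or `slice_helper`, which is why a signature mismatch arises no matter where the conversion happens.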
Addresses issue #310. The only use of `skip_num_stack_entries` was for `slice_helper`, and addressing this issue in a systematic manner would likely require building in many hacks and assumptions, so the approach here is just to manually override the stack information in that one function.