Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tripy] Eliminate need for skip_num_stack_entries argument in convert_to_tensors #333

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

slyubomirsky
Copy link
Collaborator

@slyubomirsky slyubomirsky commented Nov 3, 2024

Addresses issue #310. The only use of skip_num_stack_entries was for slice_helper and addressing this issue in a systematic manner would likely require building in many hacks and assumptions, so the approach here is just to manually override the stack information in that one function.

@slyubomirsky slyubomirsky added the tripy Pull request for the tripy project label Nov 3, 2024
# This check supports cases like slice_helper, where the decorated function
# is used *inside* the override and hence the wrapped call would come *before*
# the call from the registry.
REGISTRY_STACK_DEPTH = WRAPPER_STACK_DEPTH
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need to find WRAPPER_STACK_DEPTH first?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would only be needed in our test cases, unless we want to add those to the function registry.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would slightly prefer adding hacks into the tests rather than into Tripy proper.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll change the tests, in that case.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, it seems my assumption that it's for overrides is incorrect. convert_to_tensors is also used on full, full_like, iota, quantize, dequantize, expand, reshape, and resize, though these appear in a function registry via export.public_api, so the check still succeeds.

Copy link
Collaborator Author

@slyubomirsky slyubomirsky Nov 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code actually proves to be necessary in the following case: If tensor_from_shape_like is used by the decorator, that may call concatenate, which is in the function registry due to the use of export.public_api. Without the earlier loop, the inner call to concatenate is selected first and the stack entry is chosen. This arises in various test cases that use tensor_initializers, e.g., test_unsqueeze. I will see if that case can be excluded cleanly.

Edit: That can be rectified by excluding tripy.frontend.utils from the stack entry matching as well. This seems like a bit of a hack now because fundamentally, what we do want to find is the call to the decorated function and from there find the first call to user code, which may require skipping to the next registered function (i.e., a front-end API used by the user). This becomes muddled if multiple registered functions are used during a single call, like with the case of concatenate. Compare the most recent commit.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think skipping stack entries is fundamentally broken. Unless the function signature is the same, it's difficult to reliably determine column information. Could we just update the slice implementation to add column information without the help of this decorator and then disallow skipping stack entries here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I'll keep the old behavior of checking for the call to the decorator.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated, probably fewer hacks needed overall

@@ -84,11 +85,23 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_n

# Find the first caller of this function that is NOT the function registry.
# Also save the last dispatch target we see.

# Start from the registry. It will always be present except for tests,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One more place where it could be missing is methods of classes. If that becomes an issue, then maybe we could update the function registry to automatically discover and wrap methods too.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would there be an example of a class method for which we would use the decorator? I thought we were only using it for overrides of magic methods.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it also handles conversion to shapes now, there could be cases in the future where we want to use it on a method. Could you add a note in the docstring of the decorator mentioning the assumptions it makes about the call stack?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hopefully it won't be necessary if we remove these hacks per the above discussion, will note whatever assumptions ultimately remain.

…nore registered functions that come from the decorator
@slyubomirsky slyubomirsky force-pushed the remove-skip-num-stack-entries branch 2 times, most recently from 7e8a1e3 to 6db25c7 Compare November 6, 2024 22:37
…old logic, only overriding the special case in slice
Comment on lines 267 to 272
# Look for the call to __getitem__. We need to go one stack frame beyond to get to the *user* call of __getitem__.
frame_index = -1
for idx, source_info in enumerate(arg.stack_info):
if source_info._dispatch_target == "__getitem__":
frame_index = idx + 1
break
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can lift this out of the loop right? The frame for __getitem__ should be the same for all the arguments to slice.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I had overlooked that. That's correct.

# Look for the stack frame index to __getitem__. We need to go one stack frame beyond to get to the *user* call of __getitem__.
# It will be the same for all the slice params
frame_index = -1
if slice_params:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the case where there are no slice_params?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll change it to an assert. I don't think it's possible for it to be empty but it's best to avoid an indexing error in case it somehow happens.

# the original call to user code.
@frontend_utils.convert_to_tensors(skip_num_stack_entries=1)
@frontend_utils.convert_to_tensors()
def slice_helper(tensor, *slice_params: TensorLike):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if it would be better to do the conversion to tensors and the stack info in the caller? The reason skip_num_stack_entries is unreliable is that the signature of the API can differ from that of the decorated function, which breaks our AST parsing. By still having the helper, it seems like we would have the same problem.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you run into a case where that's actually happened? The error messages in the test cases generate correctly.

We could indeed get rid of the decorator altogether and handle everything directly in __getitem__. The reason for having the helper is that the decorator normally would not be able to handle slice inputs and we would have to preprocess them. In any case, we do need to know which inputs are slices and which ones are ints to determine things like distinguishing between a slice of size 1 (returns a list of length 1) and a single index (returns a single entry).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, the code that would be parsed would be the user's subscripted tensor, i.e., the subscript AST node, whose signature would not match that of our override for __getitem__ regardless. We could pull in the portions of the code from the decorator that convert values into tensors, but it would not be much different in principle from what the decorator is already doing.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tripy Pull request for the tripy project
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants