Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tripy] Eliminate need for skip_num_stack_entries argument in convert_to_tensors #333

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions tripy/tripy/frontend/trace/ops/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,6 @@ def clamp_bound(bound: Union[int, Tensor]) -> Union[int, Tensor]:
return out


# Because the helper is called inside another function, we need to skip one entry in the call stack to find
# the original call to user code.
@frontend_utils.convert_to_tensors(skip_num_stack_entries=1)
@frontend_utils.convert_to_tensors()
def slice_helper(tensor, *slice_params: TensorLike):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if it would be better to do the conversion to tensors and the stack info in the caller? The reason skip_num_stack_entries is unreliable is that the signature of the API can differ from that of the decorated function, which breaks our AST parsing. By still having the helper, it seems like we would have the same problem.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you run into a case where that's actually happened? The error messages in the test cases generate correctly.

We could indeed get rid of the decorator altogether and handle everything directly in __getitem__. The reason for having the helper is that the decorator normally would not be able to handle slice inputs and we would have to preprocess them. In any case, we do need to know which inputs are slices and which ones are ints to determine things like distinguishing between a slice of size 1 (returns a list of length 1) and a single index (returns a single entry).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, the code that would be parsed would be the user's subscripted tensor, i.e., the subscript AST node, whose signature would not match that of our override for __getitem__ regardless. We could pull in the portions of the code from the decorator that convert values into tensors, but it would not be much different in principle from what the decorator is already doing.

return Slice.build(inputs=[tensor, *slice_params])
35 changes: 17 additions & 18 deletions tripy/tripy/frontend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def empty_buffer():


# Try to include correct column offsets for non-tensor arguments.
def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_num_stack_entries, arg_names):
def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, arg_names):
from tripy import function_registry
from tripy.frontend.tensor import Tensor

assert isinstance(arg, Tensor), f"This function should only be called for objects that are already Tensor instances"
Expand All @@ -84,11 +85,23 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_n

# Find the first caller of this function that is NOT the function registry.
# Also save the last dispatch target we see.

# Start from the registry. It will always be present except for tests,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One more place where it could be missing is methods of classes. If that becomes an issue, then maybe we could update the function registry to automatically discover and wrap methods too.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would there be an example of a class method for which we would use the decorator? I thought we were only using it for overrides of magic methods.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it also handles conversion to shapes now, there could be cases in the future where we want to use it on a method. Could you add a note in the docstring of the decorator mentioning the assumptions it makes about the call stack?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hopefully it won't be necessary if we remove these hacks per the above discussion, will note whatever assumptions ultimately remain.

# since the decorator is intended only for overrides of magic functions.
# This check supports cases like slice_helper, where the decorated function
# is used *inside* the override and hence the wrapped call would come *before*
# the call from the registry.
REGISTRY_STACK_DEPTH = WRAPPER_STACK_DEPTH
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need to find WRAPPER_STACK_DEPTH first?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would only be needed in our test cases, unless we want to add those to the function registry.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would slightly prefer adding hacks into the tests rather than into Tripy proper.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll change the tests, in that case.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, it seems my assumption that it's for overrides is incorrect. convert_to_tensors is also used on full, full_like, iota, quantize, dequantize, expand, reshape, and resize, though these appear in a function registry via export.public_api, so the check still succeeds.

Copy link
Collaborator Author

@slyubomirsky slyubomirsky Nov 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code actually proves to be necessary in the following case: If tensor_from_shape_like is used by the decorator, that may call concatenate, which is in the function registry due to the use of export.public_api. Without the earlier loop, the inner call to concatenate is selected first and the stack entry is chosen. This arises in various test cases that use tensor_initializers, e.g., test_unsqueeze. I will see if that case can be excluded cleanly.

Edit: That can be rectified by excluding tripy.frontend.utils from the stack entry matching as well. This seems like a bit of a hack now because fundamentally, what we do want to find is the call to the decorated function and from there find the first call to user code, which may require skipping to the next registered function (i.e., a front-end API used by the user). This becomes muddled if multiple registered functions are used during a single call, like with the case of concatenate. Compare the most recent commit.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think skipping stack entries is fundamentally broken. Unless the function signature is the same, it's difficult to reliably determine column information. Could we just update the slice implementation to add column information without the help of this decorator and then disallow skipping stack entries here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I'll keep the old behavior of checking for the call to the decorator.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated, probably fewer hacks needed overall

for idx, source_info in enumerate(arg.stack_info[WRAPPER_STACK_DEPTH:]):
if source_info.module == function_registry.__name__:
REGISTRY_STACK_DEPTH = WRAPPER_STACK_DEPTH + idx
break

dispatch_target = None
for idx, source_info in enumerate(arg.stack_info[WRAPPER_STACK_DEPTH + skip_num_stack_entries :]):
for idx, source_info in enumerate(arg.stack_info[REGISTRY_STACK_DEPTH:]):
dispatch_target = source_info._dispatch_target or dispatch_target
if source_info.module not in utils.get_module_names_to_exclude_from_stack_info():
frame_index = idx + WRAPPER_STACK_DEPTH + skip_num_stack_entries
frame_index = idx + REGISTRY_STACK_DEPTH
break
else:
# Fallback path is just to look at the user code
Expand Down Expand Up @@ -131,9 +144,7 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_n

# NOTE: Conversion to tensors needs to be done via a decorator so that we can add stack information
# for non-tensors. Without having full context of the function signature, it is otherwise difficult to do so.
def convert_to_tensors(
targets: Set[str] = None, skip_num_stack_entries: int = 0, preprocess_args: Optional[Callable] = None
):
def convert_to_tensors(targets: Set[str] = None, preprocess_args: Optional[Callable] = None):
"""
Decorator that converts specified arguments to Tensors or DimensionSizes.
If the argument can be converted to a DimensionSize, it is. Otherwise, it is
Expand All @@ -147,17 +158,6 @@ def convert_to_tensors(
targets: Names of arguments to convert to tensors. If not supplied, any arguments annotated
with `TensorLike` or `ShapeLike` are converted.

skip_num_stack_entries: If the decorator is used on a function that is *called by*
a function that the user invokes, it will be necessary to skip stack entries
in order to get the column info from the user code. The number of entries skipped
should be equal to the nesting depth from a function called by user code
(if the decorated function is called by the user the depth is 0;
if the decorated function is called from a user function, the depth is 1; etc.).

NOTE: When using this, make sure any extra arguments to the decorated function are
passed as keyword arguments. Otherwise, the logic for determining column information
will break.

preprocess_args: A callback used to preprocess arguments before potential conversion. If provided,
this is always called, regardless of whether the decorator actually needed to perform conversion.
This will be called with all arguments that were passed to the decorated function and should
Expand Down Expand Up @@ -233,7 +233,6 @@ def add_arg(arg):
name in kwargs,
len(args),
func.__name__,
skip_num_stack_entries,
[name for name, _ in all_args],
)

Expand Down