-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Tripy] Eliminate need for skip_num_stack_entries
argument in convert_to_tensors
#333
base: main
Are you sure you want to change the base?
Changes from 1 commit
31e3452
d6a4ebb
737b7f7
cdca7ac
b839326
c405be6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,7 +65,8 @@ def empty_buffer(): | |
|
||
|
||
# Try to include correct column offsets for non-tensor arguments. | ||
def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_num_stack_entries, arg_names): | ||
def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, arg_names): | ||
from tripy import function_registry | ||
from tripy.frontend.tensor import Tensor | ||
|
||
assert isinstance(arg, Tensor), f"This function should only be called for objects that are already Tensor instances" | ||
|
@@ -84,11 +85,23 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_n | |
|
||
# Find the first caller of this function that is NOT the function registry. | ||
# Also save the last dispatch target we see. | ||
|
||
# Start from the registry. It will always be present except for tests, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One more place where it could be missing is methods of classes. If that becomes an issue, then maybe we could update the function registry to automatically discover and wrap methods too. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would there be an example of a class method for which we would use the decorator? I thought we were only using it for overrides of magic methods. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since it also handles conversion to shapes now, there could be cases in the future where we want to use it on a method. Could you add a note in the docstring of the decorator mentioning the assumptions it makes about the call stack? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hopefully it won't be necessary if we remove these hacks per the above discussion, will note whatever assumptions ultimately remain. |
||
# since the decorator is intended only for overrides of magic functions. | ||
# This check supports cases like slice_helper, where the decorated function | ||
# is used *inside* the override and hence the wrapped call would come *before* | ||
# the call from the registry. | ||
REGISTRY_STACK_DEPTH = WRAPPER_STACK_DEPTH | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we still need to find There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would only be needed in our test cases, unless we want to add those to the function registry. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would slightly prefer adding hacks into the tests rather than into Tripy proper. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll change the tests, in that case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm, it seems my assumption that it's for overrides is incorrect. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The code actually proves to be necessary in the following case: If Edit: That can be rectified by excluding There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think skipping stack entries is fundamentally broken. Unless the function signature is the same, it's difficult to reliably determine column information. Could we just update the slice implementation to add column information without the help of this decorator and then disallow skipping stack entries here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, I'll keep the old behavior of checking for the call to the decorator. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated, probably fewer hacks needed overall |
||
for idx, source_info in enumerate(arg.stack_info[WRAPPER_STACK_DEPTH:]): | ||
if source_info.module == function_registry.__name__: | ||
REGISTRY_STACK_DEPTH = WRAPPER_STACK_DEPTH + idx | ||
break | ||
|
||
dispatch_target = None | ||
for idx, source_info in enumerate(arg.stack_info[WRAPPER_STACK_DEPTH + skip_num_stack_entries :]): | ||
for idx, source_info in enumerate(arg.stack_info[REGISTRY_STACK_DEPTH:]): | ||
dispatch_target = source_info._dispatch_target or dispatch_target | ||
if source_info.module not in utils.get_module_names_to_exclude_from_stack_info(): | ||
frame_index = idx + WRAPPER_STACK_DEPTH + skip_num_stack_entries | ||
frame_index = idx + REGISTRY_STACK_DEPTH | ||
break | ||
else: | ||
# Fallback path is just to look at the user code | ||
|
@@ -131,9 +144,7 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_n | |
|
||
# NOTE: Conversion to tensors needs to be done via a decorator so that we can add stack information | ||
# for non-tensors. Without having full context of the function signature, it is otherwise difficult to do so. | ||
def convert_to_tensors( | ||
targets: Set[str] = None, skip_num_stack_entries: int = 0, preprocess_args: Optional[Callable] = None | ||
): | ||
def convert_to_tensors(targets: Set[str] = None, preprocess_args: Optional[Callable] = None): | ||
""" | ||
Decorator that converts specified arguments to Tensors or DimensionSizes. | ||
If the argument can be converted to a DimensionSize, it is. Otherwise, it is | ||
|
@@ -147,17 +158,6 @@ def convert_to_tensors( | |
targets: Names of arguments to convert to tensors. If not supplied, any arguments annotated | ||
with `TensorLike` or `ShapeLike` are converted. | ||
|
||
skip_num_stack_entries: If the decorator is used on a function that is *called by* | ||
a function that the user invokes, it will be necessary to skip stack entries | ||
in order to get the column info from the user code. The number of entries skipped | ||
should be equal to the nesting depth from a function called by user code | ||
(if the decorated function is called by the user the depth is 0; | ||
if the decorated function is called from a user function, the depth is 1; etc.). | ||
|
||
NOTE: When using this, make sure any extra arguments to the decorated function are | ||
passed as keyword arguments. Otherwise, the logic for determining column information | ||
will break. | ||
|
||
preprocess_args: A callback used to preprocess arguments before potential conversion. If provided, | ||
this is always called, regardless of whether the decorator actually needed to perform conversion. | ||
This will be called with all arguments that were passed to the decorated function and should | ||
|
@@ -233,7 +233,6 @@ def add_arg(arg): | |
name in kwargs, | ||
len(args), | ||
func.__name__, | ||
skip_num_stack_entries, | ||
[name for name, _ in all_args], | ||
) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering if it would be better to do the conversion to tensors and the stack info in the caller? The reason
skip_num_stack_entries
is unreliable is that the signature of the API can differ from that of the decorated function, which breaks our AST parsing. By still having the helper, it seems like we would have the same problem.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have you run into a case where that's actually happened? The error messages in the test cases generate correctly.
We could indeed get rid of the decorator altogether and handle everything directly in
__getitem__
. The reason for having the helper is that the decorator normally would not be able to handleslice
inputs and we would have to preprocess them. In any case, we do need to know which inputs areslice
s and which ones areint
s to determine things like distinguishing between a slice of size 1 (returns a list of length 1) and a single index (returns a single entry).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case, the code that would be parsed would be the user's subscripted tensor, i.e., the subscript AST node, whose signature would not match that of our override for
__getitem__
regardless. We could pull in the portions of the code from the decorator that convert values into tensors, but it would not be much different in principle from what the decorator is already doing.