Skip to content

Commit

Permalink
pre-commit: running and fixing...
Browse files Browse the repository at this point in the history
  • Loading branch information
github-actions[bot] committed Mar 25, 2024
1 parent 7381e4e commit 292b663
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
2 changes: 1 addition & 1 deletion thunder/clang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,7 @@ def mT_scalar_warning():
UserWarning,
)


@clangop(method_name="mT")
def matrix_transpose(a: TensorProxy) -> TensorProxy:
"""Transposes the last two dimensions of a tensor.
Expand All @@ -1191,7 +1192,6 @@ def matrix_transpose(a: TensorProxy) -> TensorProxy:
[3, 6]])
"""


if a.ndim == 0:
mT_scalar_warning()
return a
Expand Down
5 changes: 2 additions & 3 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -3956,13 +3956,12 @@ def matrix_transpose_sample_generator(op, device, dtype, requires_grad, **kwargs
for shape in cases:
yield SampleInput(make(shape))


def matrix_transpose_error_generator(op, device, dtype=torch.float32, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype)

# shape, error type, error message
cases = (
((3), RuntimeError, "tensor.mT is only supported on matrices or batches of matrices. Got 1-D tensor."),
)
cases = (((3), RuntimeError, "tensor.mT is only supported on matrices or batches of matrices. Got 1-D tensor."),)

for shape, err_type, err_msg in cases:
yield SampleInput(make(shape)), err_type, err_msg
Expand Down

0 comments on commit 292b663

Please sign in to comment.