Skip to content

Commit

Permalink
Added scalar warning.
Browse files Browse the repository at this point in the history
  • Loading branch information
apaz-cli committed Mar 24, 2024
1 parent e72dc5c commit 7381e4e
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions thunder/clang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from types import EllipsisType, NoneType
import copy
import time
import warnings

from thunder.core.baseutils import run_once
from thunder.core.compile_data import using_symbolic_values
from thunder.clang.langctx import register_method
from thunder.core.langctxs import langctx, Languages
Expand Down Expand Up @@ -1160,6 +1162,13 @@ def compute_broadcast_shape(*_shapes):
return tuple(common_shape)


@run_once
def mT_scalar_warning():
warnings.warn(
"Tensor.mT is deprecated on 0-D tensors. This function is the identity in these cases.",
UserWarning,
)

@clangop(method_name="mT")
def matrix_transpose(a: TensorProxy) -> TensorProxy:
"""Transposes the last two dimensions of a tensor.
Expand All @@ -1184,6 +1193,7 @@ def matrix_transpose(a: TensorProxy) -> TensorProxy:


if a.ndim == 0:
mT_scalar_warning()
return a
elif a.ndim == 1:
raise RuntimeError(f"tensor.mT is only supported on matrices or batches of matrices. Got 1-D tensor.")
Expand Down

0 comments on commit 7381e4e

Please sign in to comment.