Skip to content

Commit

Permalink
[release/2.4] Skipped some inductor tests for no hipcc rocm environme…
Browse files Browse the repository at this point in the history
…nts (#1679)

Skipped some tests for wheels builds with a check for ROCM_HOME.
  • Loading branch information
iupaikov-amd authored Nov 13, 2024
1 parent 634b544 commit 31e58f8
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
4 changes: 4 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
subtest,
TEST_WITH_ASAN,
TEST_WITH_ROCM,
HAS_HIPCC,
)
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
Expand Down Expand Up @@ -751,6 +752,7 @@ def fn(a, b):
)

@skipCUDAIf(not SM80OrLater, "Requires sm80")
@skipCUDAIf(TEST_WITH_ROCM and not HAS_HIPCC, "ROCm requires hipcc compiler")
def test_eager_aoti_cache_hit(self):
ns = "aten"
op_name = "abs"
Expand Down Expand Up @@ -803,6 +805,7 @@ def test_eager_aoti_cache_hit(self):
self.assertEqual(ref_value, res_value)

@skipCUDAIf(not SM80OrLater, "Requires sm80")
@skipCUDAIf(TEST_WITH_ROCM and not HAS_HIPCC, "ROCm requires hipcc compiler")
def test_aoti_compile_with_persistent_cache(self):
def fn(a):
return torch.abs(a)
Expand Down Expand Up @@ -6661,6 +6664,7 @@ def fn(x):

self.common(fn, [torch.randn(64, 64)])

@unittest.skipIf(TEST_WITH_ROCM and not HAS_HIPCC, "ROCm requires hipcc compiler")
def test_new_cpp_build_logical(self):
from torch._inductor.codecache import validate_new_cpp_commands

Expand Down
4 changes: 3 additions & 1 deletion torch/testing/_internal/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
from torch.testing._comparison import not_close_error_metas
from torch.testing._internal.common_dtype import get_all_dtypes
from torch.utils._import_utils import _check_module_exists
from torch.utils.cpp_extension import ROCM_HOME
import torch.utils._pytree as pytree

try:
Expand All @@ -102,9 +103,10 @@
except ImportError:
has_pytest = False


NAVI_ARCH = ("gfx1030", "gfx1100", "gfx1101")

HAS_HIPCC = torch.version.hip is not None and ROCM_HOME is not None and shutil.which('hipcc') is not None

def freeze_rng_state(*args, **kwargs):
return torch.testing._utils.freeze_rng_state(*args, **kwargs)

Expand Down

0 comments on commit 31e58f8

Please sign in to comment.