diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 6f6fd1987ebe1..28d50d2eb7139 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -85,6 +85,7 @@ subtest, TEST_WITH_ASAN, TEST_WITH_ROCM, + HAS_HIPCC, ) from torch.utils import _pytree as pytree from torch.utils._python_dispatch import TorchDispatchMode @@ -751,6 +752,7 @@ def fn(a, b): ) @skipCUDAIf(not SM80OrLater, "Requires sm80") + @skipCUDAIf(TEST_WITH_ROCM and not HAS_HIPCC, "ROCm requires hipcc compiler") def test_eager_aoti_cache_hit(self): ns = "aten" op_name = "abs" @@ -803,6 +805,7 @@ def test_eager_aoti_cache_hit(self): self.assertEqual(ref_value, res_value) @skipCUDAIf(not SM80OrLater, "Requires sm80") + @skipCUDAIf(TEST_WITH_ROCM and not HAS_HIPCC, "ROCm requires hipcc compiler") def test_aoti_compile_with_persistent_cache(self): def fn(a): return torch.abs(a) @@ -6661,6 +6664,7 @@ def fn(x): self.common(fn, [torch.randn(64, 64)]) + @unittest.skipIf(TEST_WITH_ROCM and not HAS_HIPCC, "ROCm requires hipcc compiler") def test_new_cpp_build_logical(self): from torch._inductor.codecache import validate_new_cpp_commands diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index fd050c1a4e3a4..9d80e9ef5a139 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -94,6 +94,7 @@ from torch.testing._comparison import not_close_error_metas from torch.testing._internal.common_dtype import get_all_dtypes from torch.utils._import_utils import _check_module_exists +from torch.utils.cpp_extension import ROCM_HOME import torch.utils._pytree as pytree try: @@ -102,9 +103,10 @@ except ImportError: has_pytest = False - NAVI_ARCH = ("gfx1030", "gfx1100", "gfx1101") +HAS_HIPCC = torch.version.hip is not None and ROCM_HOME is not None and shutil.which('hipcc') is not None + def freeze_rng_state(*args, **kwargs): return torch.testing._utils.freeze_rng_state(*args, **kwargs)