From 8693b27adcc8fd2c92111f630f5543bf814613e4 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 29 Mar 2024 10:35:32 -0400 Subject: [PATCH] [Testing] Implementation of an updated testing framework (#2187) * update unit tests and existing architecture to use updated framework * small update * updat custom integration teseting, adding support for custom sripts and classes * add docstring * add dummy smoketest * clean-up * typo fix * move examples to examples folder; add doc explaining updated framework * add testing targets * add enums for test type and candence, additional config checks, remove extra assert True, and log message for custom test failures * fix typo; update doc * Update tests/docs/testing_framework.md Co-authored-by: Rahul Tuli --------- Co-authored-by: Rahul Tuli --- pyproject.toml | 8 +- setup.py | 1 + tests/custom_test.py | 86 ++++++++++ tests/data.py | 40 +++++ tests/docs/testing_framework.md | 98 +++++++++++ tests/examples/__init__.py | 13 ++ .../generation_configs/custom_class/run.py | 21 +++ .../generation_configs/custom_class/run.yaml | 3 + .../custom_script/test_python_script.py | 20 +++ .../custom_script/test_python_script.yaml | 3 + tests/examples/test_integration_custom.py | 40 +++++ tests/sparseml/export/test_export_data_new.py | 153 ++++++++++++++++++ .../generation_configs/tiny_stories.yaml | 4 + .../transformers/test_generation_export.py | 58 +++++++ tests/testing_utils.py | 109 +++++++++++++ 15 files changed, 656 insertions(+), 1 deletion(-) create mode 100644 tests/custom_test.py create mode 100644 tests/data.py create mode 100644 tests/docs/testing_framework.md create mode 100644 tests/examples/__init__.py create mode 100644 tests/examples/generation_configs/custom_class/run.py create mode 100644 tests/examples/generation_configs/custom_class/run.yaml create mode 100644 tests/examples/generation_configs/custom_script/test_python_script.py create mode 100644 tests/examples/generation_configs/custom_script/test_python_script.yaml create mode 100644 tests/examples/test_integration_custom.py create mode 100644 tests/sparseml/export/test_export_data_new.py create mode 100644 tests/sparseml/export/transformers/generation_configs/tiny_stories.yaml create mode 100644 tests/sparseml/export/transformers/test_generation_export.py create mode 100644 tests/testing_utils.py diff --git a/pyproject.toml b/pyproject.toml index af9fccdb6f5..37850262b95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,4 +3,10 @@ line-length = 88 target-version = ['py36'] [tool.pytest.ini_options] -tmp_path_retention_policy = "none" \ No newline at end of file +tmp_path_retention_policy = "none" +markers = [ + "integration: integration tests", + "unit: unit tests", + "custom: custom integration tests", + "smoke: smoke tests" +] diff --git a/setup.py b/setup.py index a48b754881b..4af84f957b0 100644 --- a/setup.py +++ b/setup.py @@ -119,6 +119,7 @@ "tensorboard>=1.0,<2.9", "tensorboardX>=1.0", "evaluate>=0.4.1", + "parameterized", ] _docs_deps = [ diff --git a/tests/custom_test.py b/tests/custom_test.py new file mode 100644 index 00000000000..2fe4cdd9368 --- /dev/null +++ b/tests/custom_test.py @@ -0,0 +1,86 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import runpy +import unittest +from typing import Optional + +from tests.data import CustomTestConfig + + +_LOGGER = logging.getLogger(__name__) + + +class CustomTestCase(unittest.TestCase): + """ + CustomTestCase class. All custom test classes written should inherit from this + class. They will be subsequently tested in the test_custom_class function defined + within the CustomIntegrationTest. + """ + + ... + + +# TODO: consider breaking this up into two classes, similar to non-custom +# integration tests. Could then make use of parameterize_class instead +class CustomIntegrationTest(unittest.TestCase): + """ + Base Class for all custom integration tests. + """ + + custom_scripts_directory: str = None + custom_class_directory: str = None + + def test_custom_scripts(self, config: Optional[CustomTestConfig] = None): + """ + This test case will run all custom python scripts that reside in the directory + defined by custom_scripts_directory. For each custom python script, there + should be a corresponding yaml file which consists of the values defined by + the dataclass CustomTestConfig, including the field script_path which is + populated with the name of the python script. The test will fail if any + of the defined assertions in the script fail + + :param config: config defined by the CustomTestConfig dataclass + + """ + if config is None: + self.skipTest("No custom scripts found. Testing test") + script_path = f"{self.custom_scripts_directory}/{config.script_path}" + runpy.run_path(script_path) + + def test_custom_class(self, config: Optional[CustomTestConfig] = None): + """ + This test case will run all custom test classes that reside in the directory + defined by custom_class_directory. For each custom test class, there + should be a corresponding yaml file which consists of the values defined by + the dataclass CustomTestConfig, including the field script_path which is + populated with the name of the python script. The test will fail if any + of the defined tests in the custom class fail. + + :param config: config defined by the CustomTestConfig dataclass + + """ + if config is None: + self.skipTest("No custom class found. Testing test") + loader = unittest.TestLoader() + tests = loader.discover(self.custom_class_directory, pattern=config.script_path) + testRunner = unittest.runner.TextTestRunner() + output = testRunner.run(tests) + for out in output.errors: + raise Exception(output[-1]) + + for out in output.failures: + _LOGGER.error(out[-1]) + assert False diff --git a/tests/data.py b/tests/data.py new file mode 100644 index 00000000000..f8e284d90c9 --- /dev/null +++ b/tests/data.py @@ -0,0 +1,40 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from enum import Enum + + +# TODO: maybe test type as decorators? +class TestType(Enum): + SANITY = "sanity" + REGRESSION = "regression" + SMOKE = "smoke" + + +class Cadence(Enum): + COMMIT = "commit" + WEEKLY = "weekly" + NIGHTLY = "nightly" + + +@dataclass +class TestConfig: + test_type: TestType + cadence: Cadence + + +@dataclass +class CustomTestConfig(TestConfig): + script_path: str diff --git a/tests/docs/testing_framework.md b/tests/docs/testing_framework.md new file mode 100644 index 00000000000..8700bfd426f --- /dev/null +++ b/tests/docs/testing_framework.md @@ -0,0 +1,98 @@ +# An Updated Testing Framework + +Below is a summary of the testing framework proposed for sparseml. + +## Existing Tests + +### Integration Tests + +Existing integration tests are rewritten such that all values relevant to the particular +test case are read from a config file, as opposed to hardcoded values in the test case +itself or through overloaded pytest fixtures. Each config file should include one +combination of relevant parameters that are needed to be tested for that particular +integration test. Each config file must at least have the values defined by the +`TestConfig` dataclass found under `tests/data`. These values include the `cadence` +(weekly, commit, or nightly) and the `test_type` (sanity, smoke, or regression) for the +particular test case. While the `test_type` is currently using a config value, we can +expand it to use pytest markers instead. An example of this updated approach can be +found in the export test case, `test_generation_export.py` + +### Unit Tests + +Unit tests are not changed significantly however, can be adapted to use the +`unittest.TestCase` base class. While this is not necessary to be used, it does +seem like `unittest` provides overall greater readability compared to normal pytest +tests. There is also the case where we can use both pytest and unittest for our test +cases. This is not uncommon and also what transformers currently does. An example of +an updated test can be in the `test_export_data_new.py` test file. A note about using +`unittest` is that it requires us to install the `parameterized` package for +decorating test cases. + +## Custom Testing + +For the purpose of custom integration testing, two new workflows are now enabled + +1. **Custom Script Testing**: Users can test their custom python script which is not +required to follow any specific structure. All asserts in the script will be validated +2. **Custom Testing Class**: For slightly more structure, users can write their own +testing class. The only requirement is that this testing class inherits from the base +class `CustomTestCase` which can be found under `tests/custom_test`. + +To enable custom integration testing for any of the cases above, a test class must be +written which inherits from `CustomIntegrationTest` under tests/custom_test. Within this +class, two paths can be defined: `custom_scripts_directory` which points to the +directory containing all the custom scripts which are to be tested and +`custom_class_directory` which points to the directory containing all the custom test +classes. + +Similar to the non-custom integration testing, each custom integration test script or +test class must include a .yaml file which includes 3 values +(defined by the `CustomTestConfig` dataclass found under `tests/data`): +`test_type` which indicates if the test is a sanity, smoke or regression test, +`cadence`: which dictates how often the test case runs (on commit, weekly, or nightly), +and the `script_path` which lists the name of the custom testing script or custom test +class within the directory. + +An example of an implementation of the `CustomIntegrationTest` can be found under +`tests/examples` + +## Additional markers and decorators + +- New markers are added in to mark tests as `unit`, `integration`, `smoke`, and `custom` +tests allowing us to run a subset of the tests when needed +- Two new decorators are added in to check for package and compute requirements. If +the requirements are not met, the test is skipped. Currently, `requires_torch` and +`requires_gpu` are added in and can be found under `testing_utils.py` + +## Testing Targets + +### Unit Testing Targets: +- A unit test should be written for every utils, helper, or static function. + - Test cases should be written for all datatype combinations that the function takes as input + - Can have `smoke` tests but focus should be on `sanity` + +### Integration Testing Targets: +- An integration test should be written for every cli pathway that is exposed through `setup.py` + - All cli-arg combinations should be tested through a `smoke` check + (all may be overkill but ideally we're covering beyond the few important combinations) + - All **important** cli-arg combinations should be covered through either a `sanity` + check or a `regression` check + - A small model should be tested through a `sanity` check + - All other larger models should be tested through `regression` test types + +- An integration test should be written for every major/critical module + - All arg combinations should be tested through a `smoke` check + (all may be overkill but ideally we're covering beyond the few important combinations) + - All **important** arg combinations should be covered through either a `sanity` + check or a `regression` check + - A small model should be tested through a `sanity` check + - All other larger models should be tested through `regression` test types + +## End-to-end Testing Targets: +- Tests cascading repositories (sparseml --> vLLM) but will become more prominent as our +docker containers are furhter solidified. Goal would be to emulate common flows users +may follow + +## Cadence +- Ideally, large models and `regression` tests should be tested on a nightly cadence while +unit tests and `sanity` test should be tested on a per commit basis \ No newline at end of file diff --git a/tests/examples/__init__.py b/tests/examples/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/examples/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/examples/generation_configs/custom_class/run.py b/tests/examples/generation_configs/custom_class/run.py new file mode 100644 index 00000000000..3b7a612f413 --- /dev/null +++ b/tests/examples/generation_configs/custom_class/run.py @@ -0,0 +1,21 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests.custom_test import CustomTestCase + + +# Example custom class for testing +class MyTests(CustomTestCase): + def test_something_else(self): + assert 1 == 1 diff --git a/tests/examples/generation_configs/custom_class/run.yaml b/tests/examples/generation_configs/custom_class/run.yaml new file mode 100644 index 00000000000..b53e8b2ec09 --- /dev/null +++ b/tests/examples/generation_configs/custom_class/run.yaml @@ -0,0 +1,3 @@ +cadence: "commit" +test_type: "sanity" +script_path: "run.py" diff --git a/tests/examples/generation_configs/custom_script/test_python_script.py b/tests/examples/generation_configs/custom_script/test_python_script.py new file mode 100644 index 00000000000..284365da6f5 --- /dev/null +++ b/tests/examples/generation_configs/custom_script/test_python_script.py @@ -0,0 +1,20 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example custom script for testing +def do_something(): + assert 1 == 1 + + +do_something() diff --git a/tests/examples/generation_configs/custom_script/test_python_script.yaml b/tests/examples/generation_configs/custom_script/test_python_script.yaml new file mode 100644 index 00000000000..c4e976c8e39 --- /dev/null +++ b/tests/examples/generation_configs/custom_script/test_python_script.yaml @@ -0,0 +1,3 @@ +cadence: "commit" +test_type: "sanity" +script_path: "test_python_script.py" diff --git a/tests/examples/test_integration_custom.py b/tests/examples/test_integration_custom.py new file mode 100644 index 00000000000..9986227b647 --- /dev/null +++ b/tests/examples/test_integration_custom.py @@ -0,0 +1,40 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import pytest + +from parameterized import parameterized +from tests.custom_test import CustomIntegrationTest +from tests.data import CustomTestConfig +from tests.testing_utils import parse_params + + +@pytest.mark.custom +class TestExampleIntegrationCustom(CustomIntegrationTest): + """ + Integration test class which uses the base CustomIntegrationTest class. + """ + + custom_scripts_directory = "tests/examples/generation_configs/custom_script" + custom_class_directory = "tests/examples/generation_configs/custom_class" + + @parameterized.expand(parse_params(custom_scripts_directory, type="custom")) + def test_custom_scripts(self, config: Optional[CustomTestConfig] = None): + super().test_custom_scripts(config) + + @parameterized.expand(parse_params(custom_class_directory, type="custom")) + def test_custom_class(self, config: Optional[CustomTestConfig] = None): + super().test_custom_class(config) diff --git a/tests/sparseml/export/test_export_data_new.py b/tests/sparseml/export/test_export_data_new.py new file mode 100644 index 00000000000..bcad74ff4fa --- /dev/null +++ b/tests/sparseml/export/test_export_data_new.py @@ -0,0 +1,153 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import tarfile +import unittest +from enum import Enum +from pathlib import Path + +import pytest + +from parameterized import parameterized +from sparseml.export.export_data import create_data_samples, export_data_sample +from tests.sparseml.export.utils import get_dummy_dataset +from tests.testing_utils import requires_torch + + +# NOTE: These tests are equivalent to the tests in test_export_data, updated to use +# the new framework + + +@requires_torch +@pytest.mark.unit +class ExportDataTransformersUnitTest(unittest.TestCase): + def setUp(self): + import torch + from torch.utils.data import DataLoader + + class Identity(torch.nn.Module): + def __init__(self): + super().__init__() + self.dummy_param = torch.nn.Parameter(torch.empty(0)) + self.device = self.dummy_param.device + + def forward(self, input_ids, attention_mask): + return dict(input_ids=input_ids, attention_mask=attention_mask) + + self.identity_model = Identity() + self.data_loader = DataLoader(get_dummy_dataset("transformers"), batch_size=1) + + @parameterized.expand( + [[0, True], [0, False], [1, True], [1, False], [5, True], [5, False]] + ) + def test_create_data_samples(self, num_samples, model): + import torch + + model = self.identity_model.to("cpu") if model else None + + inputs, outputs, labels = create_data_samples( + data_loader=self.data_loader, num_samples=num_samples, model=model + ) + target_input = next(iter(self.data_loader)) + target_output = target_input + + self.assertEqual(len(inputs), num_samples) + for input in inputs: + for key, value in input.items(): + assert torch.equal(value.unsqueeze(0), target_input[key]) + assert labels == [] + if model is not None: + assert len(outputs) == num_samples + for output in outputs: + for key, value in output.items(): + assert torch.equal(value, target_output[key][0]) + + def tearDown(self): + pass + + +@requires_torch +@pytest.mark.unit +class ExportGenericDataUnitTest(unittest.TestCase): + def setUp(self): + import torch + + class LabelNames(Enum): + basename = "sample-dummies" + filename = "dummy" + + num_samples = 5 + batch_size = 3 + self.samples = [ + torch.randn(batch_size, 3, 224, 224) for _ in range(num_samples) + ] + self.names = LabelNames + self.tmp_path = Path("tmp") + self.tmp_path.mkdir(exist_ok=True) + + @parameterized.expand([[True], [False]]) + def test_export_data_sample(self, as_tar): + export_data_sample( + samples=self.samples, + names=self.names, + target_path=self.tmp_path, + as_tar=as_tar, + ) + + dir_name = self.names.basename.value + dir_name_tar = self.names.basename.value + ".tar.gz" + + if as_tar: + with tarfile.open(os.path.join(self.tmp_path, dir_name_tar)) as tar: + tar.extractall(path=self.tmp_path) + + assert ( + set(os.listdir(self.tmp_path)) == {dir_name} + if not as_tar + else {dir_name, dir_name_tar} + ) + assert set(os.listdir(os.path.join(self.tmp_path, "sample-dummies"))) == { + f"dummy-000{i}.npz" for i in range(len(self.samples)) + } + + def tearDown(self): + shutil.rmtree(self.tmp_path) + + +# NOTE: Dummy smoke test + + +@pytest.mark.smoke +@requires_torch +class ExportDataDummySmokeTest(unittest.TestCase): + def setUp(self): + import torch + + self.samples = [torch.randn(1, 3, 224, 224) for _ in range(2)] + + class LabelNames(Enum): + basename = "sample-dummies" + filename = "dummy" + + self.names = LabelNames + + @parameterized.expand([["some_path"], [Path("some_path")]]) + def test_export_runs(self, target_path): + Path(target_path).mkdir(exist_ok=True) + export_data_sample( + samples=self.samples, names=self.names, target_path=target_path + ) + shutil.rmtree(target_path) diff --git a/tests/sparseml/export/transformers/generation_configs/tiny_stories.yaml b/tests/sparseml/export/transformers/generation_configs/tiny_stories.yaml new file mode 100644 index 00000000000..9ba62a3cced --- /dev/null +++ b/tests/sparseml/export/transformers/generation_configs/tiny_stories.yaml @@ -0,0 +1,4 @@ +cadence: "commit" +test_type: "sanity" +stub: "roneneldan/TinyStories-1M" +task: text-generation diff --git a/tests/sparseml/export/transformers/test_generation_export.py b/tests/sparseml/export/transformers/test_generation_export.py new file mode 100644 index 00000000000..f92778382da --- /dev/null +++ b/tests/sparseml/export/transformers/test_generation_export.py @@ -0,0 +1,58 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +import unittest +from pathlib import Path + +import pytest + +from huggingface_hub import snapshot_download +from parameterized import parameterized_class +from sparseml import export +from tests.testing_utils import parse_params + + +CONFIGS_DIRECTORY = "tests/sparseml/export/transformers/generation_configs" + +# NOTE: this integration test class has the same integration test written in +# test_geneeration_export, updated to use the new framework + + +@pytest.mark.integration +@parameterized_class(parse_params(CONFIGS_DIRECTORY)) +class TestGenerationExportIntegration(unittest.TestCase): + stub = None + task = None + + def setUp(self): + self.tmp_path = Path("tmp") + self.tmp_path.mkdir(exist_ok=True) + + model_path = self.tmp_path / "model" + self.target_path = self.tmp_path / "target" + self.source_path = snapshot_download(self.stub, local_dir=model_path) + + def test_export_with_external_data(self): + export( + source_path=self.source_path, + target_path=self.target_path, + task=self.task, + save_with_external_data=True, + ) + assert (self.target_path / "deployment" / "model.onnx").exists() + assert (self.target_path / "deployment" / "model.data").exists() + + def tearDown(self): + shutil.rmtree(self.tmp_path) diff --git a/tests/testing_utils.py b/tests/testing_utils.py new file mode 100644 index 00000000000..81853d0ca03 --- /dev/null +++ b/tests/testing_utils.py @@ -0,0 +1,109 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import enum +import logging +import os +import unittest +from typing import List, Optional, Union + +import yaml + +from tests.data import CustomTestConfig, TestConfig + + +# TODO: probably makes sense to move this type of function to a more central place, +# which can be used by __init__.py as well +def is_torch_available(): + try: + import torch # noqa: F401 + + return True + except ImportError: + return False + + +def is_gpu_available(): + return False + + +def requires_torch(test_case): + return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) + + +def requires_gpu(test_case): + return unittest.skipUnless(is_gpu_available(), "test requires GPU")(test_case) + + +def _load_yaml(configs_directory, file): + if file.endswith(".yaml") or file.endswith(".yml"): + config_path = os.path.join(configs_directory, file) + # reads the yaml file + with open(config_path, "r") as f: + config = yaml.safe_load(f) + return config + return None + + +def _validate_test_config(config: dict): + for f in dataclasses.fields(TestConfig): + if f.name not in config: + return False + config_value = config.get(f.name) + if issubclass(f.type, enum.Enum): + try: + f.type(config_value) + except ValueError: + raise False + return True + + +# Set cadence in the config. The environment must set if nightly, weekly or commit +# tests are running +def parse_params( + configs_directory: str, type: Optional[str] = None +) -> List[Union[dict, CustomTestConfig]]: + # parses the config file provided + assert os.path.isdir( + configs_directory + ), f"Config_directory {configs_directory} is not a directory" + + config_dicts = [] + for file in os.listdir(configs_directory): + config = _load_yaml(configs_directory, file) + if not config: + continue + + cadence = os.environ.get("CADENCE", "commit") + expected_cadence = config.get("cadence") + + if not isinstance(expected_cadence, list): + expected_cadence = [expected_cadence] + if cadence in expected_cadence: + if type == "custom": + config = CustomTestConfig(**config) + else: + if not _validate_test_config(config): + raise ValueError( + "The config provided does not comply with the expected " + "structure. See tests.data.TestConfig for the expected " + "fields." + ) + config_dicts.append(config) + else: + logging.info( + f"Skipping testing model: {file} for cadence: {config['cadence']}" + ) + return config_dicts