From 202f07c12ca4c49daa1b74c8fe162194c8c3322e Mon Sep 17 00:00:00 2001 From: ted chang Date: Mon, 26 Feb 2024 22:09:47 -0800 Subject: [PATCH] add sample unit test Signed-off-by: ted chang --- .github/workflows/test.yaml | 22 ++++++++++++ tests/test_data_type_utils.py | 64 +++++++++++++++++++++++++++++++++++ tox.ini | 11 +++++- 3 files changed, 96 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/test.yaml create mode 100644 tests/test_data_type_utils.py diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 000000000..1d8ee828e --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,22 @@ +name: Test + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.9 + uses: actions/setup-python@v4 + with: + python-version: 3.9 + - name: Install dependencies + run: | + python -m pip install -r setup_requirements.txt + - name: Run unit tests + run: tox -e py \ No newline at end of file diff --git a/tests/test_data_type_utils.py b/tests/test_data_type_utils.py new file mode 100644 index 000000000..1b316464c --- /dev/null +++ b/tests/test_data_type_utils.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Third Party +import pytest +import torch + +# Local +from tuning.utils import data_type_utils + +dtype_dict = { + "bfloat16": torch.bfloat16, + "bits16": torch.bits16, + "bits1x8": torch.bits1x8, + "bits2x4": torch.bits2x4, + "bits4x2": torch.bits4x2, + "bits8": torch.bits8, + "bool": torch.bool, + "cdouble": torch.cdouble, + "cfloat": torch.cfloat, + "chalf": torch.chalf, + "complex128": torch.complex128, + "complex32": torch.complex32, + "complex64": torch.complex64, + "double": torch.double, + "float": torch.float, + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + "float8_e4m3fn": torch.float8_e4m3fn, + "float8_e4m3fnuz": torch.float8_e4m3fnuz, + "float8_e5m2": torch.float8_e5m2, + "float8_e5m2fnuz": torch.float8_e5m2fnuz, + "half": torch.half, + "int": torch.int, + "int16": torch.int16, + "int32": torch.int32, + "int64": torch.int64, + "int8": torch.int8, + "long": torch.long, + "qint32": torch.qint32, + "qint8": torch.qint8, + "quint2x4": torch.quint2x4, + "quint4x2": torch.quint4x2, + "quint8": torch.quint8, + "short": torch.short, + "uint8": torch.uint8, +} + + +def test_str_to_torch_dtype(): + for t in dtype_dict.keys(): + assert data_type_utils.str_to_torch_dtype(t) == dtype_dict.get(t) + + +def test_str_to_torch_dtype_exit(): + with pytest.raises(SystemExit): + data_type_utils.str_to_torch_dtype("foo") + + +def test_get_torch_dtype(): + for t in dtype_dict.keys(): + assert data_type_utils.get_torch_dtype(t) == dtype_dict.get(t) + assert data_type_utils.get_torch_dtype(dtype_dict.get(t)) == dtype_dict.get(t) diff --git a/tox.ini b/tox.ini index bbcbba9b0..41959ec9b 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,14 @@ [tox] -envlist = lint, fmt +envlist = py, lint, fmt + +[testenv] +description = run unit tests +deps = + pytest>=7 + torch + transformers>=4.34.1 +commands = + pytest {posargs:tests} [testenv:fmt] description = format with pre-commit