Skip to content

Commit

Permalink
add sample unit test
Browse files Browse the repository at this point in the history
Signed-off-by: ted chang <[email protected]>
  • Loading branch information
tedhtchang committed Feb 27, 2024
1 parent f4e8eb4 commit 202f07c
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 1 deletion.
22 changes: 22 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
name: Test

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
uses: actions/setup-python@v4
with:
python-version: 3.9
- name: Install dependencies
run: |
python -m pip install -r setup_requirements.txt
- name: Run unit tests
run: tox -e py
64 changes: 64 additions & 0 deletions tests/test_data_type_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Third Party
import pytest
import torch

# Local
from tuning.utils import data_type_utils

dtype_dict = {
"bfloat16": torch.bfloat16,
"bits16": torch.bits16,
"bits1x8": torch.bits1x8,
"bits2x4": torch.bits2x4,
"bits4x2": torch.bits4x2,
"bits8": torch.bits8,
"bool": torch.bool,
"cdouble": torch.cdouble,
"cfloat": torch.cfloat,
"chalf": torch.chalf,
"complex128": torch.complex128,
"complex32": torch.complex32,
"complex64": torch.complex64,
"double": torch.double,
"float": torch.float,
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
"float8_e4m3fn": torch.float8_e4m3fn,
"float8_e4m3fnuz": torch.float8_e4m3fnuz,
"float8_e5m2": torch.float8_e5m2,
"float8_e5m2fnuz": torch.float8_e5m2fnuz,
"half": torch.half,
"int": torch.int,
"int16": torch.int16,
"int32": torch.int32,
"int64": torch.int64,
"int8": torch.int8,
"long": torch.long,
"qint32": torch.qint32,
"qint8": torch.qint8,
"quint2x4": torch.quint2x4,
"quint4x2": torch.quint4x2,
"quint8": torch.quint8,
"short": torch.short,
"uint8": torch.uint8,
}


def test_str_to_torch_dtype():
for t in dtype_dict.keys():
assert data_type_utils.str_to_torch_dtype(t) == dtype_dict.get(t)


def test_str_to_torch_dtype_exit():
with pytest.raises(SystemExit):
data_type_utils.str_to_torch_dtype("foo")


def test_get_torch_dtype():
for t in dtype_dict.keys():
assert data_type_utils.get_torch_dtype(t) == dtype_dict.get(t)
assert data_type_utils.get_torch_dtype(dtype_dict.get(t)) == dtype_dict.get(t)
11 changes: 10 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
[tox]
envlist = lint, fmt
envlist = py, lint, fmt

[testenv]
description = run unit tests
deps =
pytest>=7
torch
transformers>=4.34.1
commands =
pytest {posargs:tests}

[testenv:fmt]
description = format with pre-commit
Expand Down

0 comments on commit 202f07c

Please sign in to comment.