add tuning/config/configs unit tests
Signed-off-by: ted chang <[email protected]>
tedhtchang committed Mar 8, 2024
1 parent 3f83a3d commit 17ba8e0
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions tests/configs/test_configs.py
@@ -0,0 +1,59 @@
# Third Party
import torch

# Local
import tuning.config.configs as c


def test_model_argument_configs():
da = c.DataArguments
ma = c.ModelArguments
ta = c.TrainingArguments
# test model arguments default
assert ma.model_name_or_path == "facebook/opt-125m"
assert ma.use_flash_attn is True
assert isinstance(ma.torch_dtype, torch.dtype)

# test data arguments default
assert da.data_path is None
assert da.response_template is None
assert da.dataset_text_field is None
assert da.validation_data_path is None

# test training arguments default
assert ta.cache_dir is None
assert ta.model_max_length == c.DEFAULT_CONTEXT_LENGTH
assert ta.packing is False


def test_model_argument_configs_init():
# new data arguments
da = c.DataArguments(
data_path="foo/bar",
response_template="\n### Label:",
dataset_text_field="output",
validation_data_path="/foo/bar",
)
assert da.data_path == "foo/bar"
assert da.response_template == "\n### Label:"
assert da.validation_data_path == "/foo/bar"

# new model arguments
ma = c.ModelArguments(
model_name_or_path="big/bloom", use_flash_attn=False, torch_dtype=torch.int32
)
assert ma.model_name_or_path == "big/bloom"
assert ma.use_flash_attn is False
assert ma.torch_dtype == torch.int32

# new training arguments
ta = c.TrainingArguments(
cache_dir="/tmp/cache",
model_max_length=1024,
packing=True,
output_dir="/tmp/output",
)
assert ta.cache_dir == "/tmp/cache"
assert ta.model_max_length == 1024
assert ta.packing is True
assert ta.output_dir == "/tmp/output"
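
Note: the first test reads default values directly off the config classes (no instances are created), and the second re-checks each field after constructing instances with explicit overrides. Below is a minimal sketch for running just this module programmatically, assuming pytest is installed and the repository root is the current working directory:

# Minimal sketch: run only this test module with pytest
# (assumption: pytest is available and we execute from the repo root).
import sys

import pytest

if __name__ == "__main__":
    # pytest.main returns an exit code; pass it back to the shell.
    sys.exit(pytest.main(["tests/configs/test_configs.py", "-v"]))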
