# Third Party
import torch

# Local
import tuning.config.configs as c


def test_model_argument_configs():
    """Verify the class-level defaults of the tuning config dataclasses.

    Accesses the classes directly (not instances): dataclass field defaults
    are exposed as class attributes, so no construction is needed.
    """
    da = c.DataArguments
    ma = c.ModelArguments
    ta = c.TrainingArguments

    # Model argument defaults.
    assert ma.model_name_or_path == "facebook/opt-125m"
    assert ma.use_flash_attn is True
    assert isinstance(ma.torch_dtype, torch.dtype)

    # Data argument defaults — all optional fields start unset.
    assert da.data_path is None
    assert da.response_template is None
    assert da.dataset_text_field is None
    assert da.validation_data_path is None

    # Training argument defaults.
    assert ta.cache_dir is None
    assert ta.model_max_length == c.DEFAULT_CONTEXT_LENGTH
    assert ta.packing is False


def test_model_argument_configs_init():
    """Verify explicitly-passed constructor values are stored on config instances."""
    # New data arguments.
    da = c.DataArguments(
        data_path="foo/bar",
        response_template="\n### Label:",
        dataset_text_field="output",
        validation_data_path="/foo/bar",
    )
    assert da.data_path == "foo/bar"
    assert da.response_template == "\n### Label:"
    # Was passed to the constructor but never checked in the original test.
    assert da.dataset_text_field == "output"
    assert da.validation_data_path == "/foo/bar"

    # New model arguments.
    ma = c.ModelArguments(
        model_name_or_path="big/bloom", use_flash_attn=False, torch_dtype=torch.int32
    )
    assert ma.model_name_or_path == "big/bloom"
    assert ma.use_flash_attn is False
    assert ma.torch_dtype == torch.int32

    # New training arguments.
    ta = c.TrainingArguments(
        cache_dir="/tmp/cache",
        model_max_length=1024,
        packing=True,
        output_dir="/tmp/output",
    )
    assert ta.cache_dir == "/tmp/cache"
    assert ta.model_max_length == 1024
    assert ta.packing is True
    assert ta.output_dir == "/tmp/output"