Skip to content

Commit

Permalink
Lint tests (#112)
Browse files Browse the repository at this point in the history
* Pylint on tests

Signed-off-by: Thara Palanivel <[email protected]>

* Fix lint

Signed-off-by: Thara Palanivel <[email protected]>

---------

Signed-off-by: Thara Palanivel <[email protected]>
  • Loading branch information
tharapalanivel authored Apr 10, 2024
1 parent b04db32 commit 56a6a8d
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 19 deletions.
24 changes: 12 additions & 12 deletions tests/build/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@


# Note: job_config dict gets modified during process_launch_training_args
@pytest.fixture(scope="session")
def job_config():
@pytest.fixture(name="job_config", scope="session")
def fixture_job_config():
with open(HAPPY_PATH_DUMMY_CONFIG_PATH, "r", encoding="utf-8") as f:
dummy_job_config_dict = json.load(f)
return dummy_job_config_dict
Expand All @@ -50,35 +50,35 @@ def test_process_launch_training_args(job_config):
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert data_args.dataset_text_field == "output"
assert training_args.output_dir == "bloom-twitter"
assert tune_config == None
assert merge_model == False
assert tune_config is None
assert merge_model is False


def test_process_launch_training_args_defaults(job_config):
job_config_defaults = copy.deepcopy(job_config)
assert "torch_dtype" not in job_config_defaults
assert job_config_defaults["use_flash_attn"] == False
assert job_config_defaults["use_flash_attn"] is False
assert "save_strategy" not in job_config_defaults
model_args, _, training_args, _, _ = process_launch_training_args(
job_config_defaults
)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert model_args.use_flash_attn == False
assert model_args.use_flash_attn is False
assert training_args.save_strategy.value == "epoch"


def test_process_launch_training_args_peft_method(job_config):
job_config_pt = copy.deepcopy(job_config)
job_config_pt["peft_method"] = "pt"
_, _, _, tune_config, merge_model = process_launch_training_args(job_config_pt)
assert type(tune_config) == PromptTuningConfig
assert merge_model == False
assert isinstance(tune_config, PromptTuningConfig)
assert merge_model is False

job_config_lora = copy.deepcopy(job_config)
job_config_lora["peft_method"] = "lora"
_, _, _, tune_config, merge_model = process_launch_training_args(job_config_lora)
assert type(tune_config) == LoraConfig
assert merge_model == True
assert isinstance(tune_config, LoraConfig)
assert merge_model is True


def test_process_accelerate_launch_args(job_config):
Expand Down Expand Up @@ -147,8 +147,8 @@ def test_process_accelerate_launch_custom_config_file(patch_path_exists):
assert args.config_file == dummy_config_path
assert args.num_processes is None

# When user passes custom fsdp config file and also `num_processes` as a param, use custom config and
# overwrite num_processes from config with param
# When user passes custom fsdp config file and also `num_processes` as a param,
# use custom config and overwrite num_processes from config with param
temp_job_config = {"accelerate_launch_args": {"config_file": dummy_config_path}}
args = process_accelerate_launch_args(temp_job_config)
assert args.config_file == dummy_config_path
8 changes: 4 additions & 4 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_helper_causal_lm_train_kwargs():
)

assert model_args.model_name_or_path == MODEL_NAME
assert model_args.use_flash_attn == False
assert model_args.use_flash_attn is False
assert model_args.torch_dtype == "float32"

assert data_args.training_data_path == TWITTER_COMPLAINTS_DATA
Expand Down Expand Up @@ -295,10 +295,10 @@ def _validate_training(tempdir, check_eval=False):
assert any(x.startswith("checkpoint-") for x in os.listdir(tempdir))
train_logs_file_path = "{}/training_logs.jsonl".format(tempdir)
train_log_contents = ""
with open(train_logs_file_path) as f:
with open(train_logs_file_path, encoding="utf-8") as f:
train_log_contents = f.read()

assert os.path.exists(train_logs_file_path) == True
assert os.path.exists(train_logs_file_path) is True
assert os.path.getsize(train_logs_file_path) > 0
assert "training_loss" in train_log_contents

Expand All @@ -311,7 +311,7 @@ def _get_checkpoint_path(dir_path):


def _get_adapter_config(dir_path):
with open(os.path.join(dir_path, "adapter_config.json")) as f:
with open(os.path.join(dir_path, "adapter_config.json"), encoding="utf-8") as f:
return json.load(f)


Expand Down
4 changes: 2 additions & 2 deletions tests/utils/test_data_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@


def test_str_to_torch_dtype():
for t in dtype_dict.keys():
for t in dtype_dict:
assert data_type_utils.str_to_torch_dtype(t) == dtype_dict.get(t)


Expand All @@ -42,7 +42,7 @@ def test_str_to_torch_dtype_exit():


def test_get_torch_dtype():
for t in dtype_dict.keys():
for t in dtype_dict:
# When passed a string, it gets converted to torch.dtype
assert data_type_utils.get_torch_dtype(t) == dtype_dict.get(t)
# When passed a torch.dtype, we get the same torch.dtype returned
Expand Down
3 changes: 2 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ allowlist_externals = ./scripts/fmt.sh
[testenv:lint]
description = lint with pylint
deps = pylint>=2.16.2,<=3.1.0
pytest
-r requirements.txt
commands = pylint tuning scripts/*.py build/*.py
commands = pylint tuning scripts/*.py build/*.py tests
allowlist_externals = pylint

0 comments on commit 56a6a8d

Please sign in to comment.