Lint tests #112

Merged 4 commits on Apr 10, 2024
24 changes: 12 additions & 12 deletions tests/build/test_utils.py
@@ -31,8 +31,8 @@


# Note: job_config dict gets modified during process_launch_training_args
-@pytest.fixture(scope="session")
-def job_config():
+@pytest.fixture(name="job_config", scope="session")
+def fixture_job_config():
with open(HAPPY_PATH_DUMMY_CONFIG_PATH, "r", encoding="utf-8") as f:
dummy_job_config_dict = json.load(f)
return dummy_job_config_dict
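
The fixture rename above follows the common pattern for avoiding pylint's redefined-outer-name (W0621) warning: pytest fixtures are requested by parameter name, so a fixture function and the test parameters that use it normally share a name, which pylint treats as shadowing. Passing name= to the decorator keeps the public fixture name while giving the function a distinct one. A minimal sketch of the pattern, with illustrative names not taken from this repository:

import pytest


@pytest.fixture(name="sample_config", scope="session")
def fixture_sample_config():
    # Tests still request the fixture as "sample_config", but the function
    # itself has a different name, so test parameters no longer shadow an
    # outer-scope definition.
    return {"learning_rate": 1e-4}


def test_reads_config(sample_config):
    assert sample_config["learning_rate"] == 1e-4
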
@@ -50,35 +50,35 @@ def test_process_launch_training_args(job_config):
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert data_args.dataset_text_field == "output"
assert training_args.output_dir == "bloom-twitter"
-assert tune_config == None
-assert merge_model == False
+assert tune_config is None
+assert merge_model is False


def test_process_launch_training_args_defaults(job_config):
job_config_defaults = copy.deepcopy(job_config)
assert "torch_dtype" not in job_config_defaults
-assert job_config_defaults["use_flash_attn"] == False
+assert job_config_defaults["use_flash_attn"] is False
assert "save_strategy" not in job_config_defaults
model_args, _, training_args, _, _ = process_launch_training_args(
job_config_defaults
)
assert str(model_args.torch_dtype) == "torch.bfloat16"
-assert model_args.use_flash_attn == False
+assert model_args.use_flash_attn is False
assert training_args.save_strategy.value == "epoch"


def test_process_launch_training_args_peft_method(job_config):
job_config_pt = copy.deepcopy(job_config)
job_config_pt["peft_method"] = "pt"
_, _, _, tune_config, merge_model = process_launch_training_args(job_config_pt)
-assert type(tune_config) == PromptTuningConfig
-assert merge_model == False
+assert isinstance(tune_config, PromptTuningConfig)
+assert merge_model is False

job_config_lora = copy.deepcopy(job_config)
job_config_lora["peft_method"] = "lora"
_, _, _, tune_config, merge_model = process_launch_training_args(job_config_lora)
-assert type(tune_config) == LoraConfig
-assert merge_model == True
+assert isinstance(tune_config, LoraConfig)
+assert merge_model is True


def test_process_accelerate_launch_args(job_config):
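
The assertion changes in this hunk line up with two stock pylint checks: singleton-comparison (C0121), which prefers identity checks against None, True, and False, and unidiomatic-typecheck (C0123), which prefers isinstance() over comparing type() results. A small before/after sketch using a made-up config class:

class DummyConfig:
    # Stand-in class for illustration only.
    pass


tune_config = DummyConfig()
merge_model = False

# Flagged by pylint:
#   assert merge_model == False               # C0121 singleton-comparison
#   assert type(tune_config) == DummyConfig   # C0123 unidiomatic-typecheck

# Preferred forms, as used in this diff:
assert merge_model is False
assert isinstance(tune_config, DummyConfig)
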
@@ -147,8 +147,8 @@ def test_process_accelerate_launch_custom_config_file(patch_path_exists):
assert args.config_file == dummy_config_path
assert args.num_processes is None

-# When user passes custom fsdp config file and also `num_processes` as a param, use custom config and
-# overwrite num_processes from config with param
+# When user passes custom fsdp config file and also `num_processes` as a param,
+# use custom config and overwrite num_processes from config with param
temp_job_config = {"accelerate_launch_args": {"config_file": dummy_config_path}}
args = process_accelerate_launch_args(temp_job_config)
assert args.config_file == dummy_config_path
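
The comment reflow in the last hunk is purely cosmetic: splitting the sentence across two comment lines keeps each line within the limit enforced by pylint's line-too-long (C0301) check, with no change to test behavior.
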
8 changes: 4 additions & 4 deletions tests/test_sft_trainer.py
@@ -77,7 +77,7 @@ def test_helper_causal_lm_train_kwargs():
)

assert model_args.model_name_or_path == MODEL_NAME
-assert model_args.use_flash_attn == False
+assert model_args.use_flash_attn is False
assert model_args.torch_dtype == "float32"

assert data_args.training_data_path == TWITTER_COMPLAINTS_DATA
@@ -295,10 +295,10 @@ def _validate_training(tempdir, check_eval=False):
assert any(x.startswith("checkpoint-") for x in os.listdir(tempdir))
train_logs_file_path = "{}/training_logs.jsonl".format(tempdir)
train_log_contents = ""
-with open(train_logs_file_path) as f:
+with open(train_logs_file_path, encoding="utf-8") as f:
train_log_contents = f.read()

-assert os.path.exists(train_logs_file_path) == True
+assert os.path.exists(train_logs_file_path) is True
assert os.path.getsize(train_logs_file_path) > 0
assert "training_loss" in train_log_contents

@@ -311,7 +311,7 @@ def _get_checkpoint_path(dir_path):


def _get_adapter_config(dir_path):
-with open(os.path.join(dir_path, "adapter_config.json")) as f:
+with open(os.path.join(dir_path, "adapter_config.json"), encoding="utf-8") as f:
return json.load(f)


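Both open() changes in this file add an explicit encoding, which is what pylint's unspecified-encoding (W1514) check asks for: without it, text files are read with a platform-dependent default encoding. A minimal sketch of the pattern (the path and helper name here are illustrative):

def read_training_logs(path="training_logs.jsonl"):
    # An explicit encoding keeps reads deterministic across platforms
    # and satisfies pylint's W1514 (unspecified-encoding) check.
    with open(path, encoding="utf-8") as f:
        return f.read()
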
4 changes: 2 additions & 2 deletions tests/utils/test_data_type_utils.py
@@ -32,7 +32,7 @@


def test_str_to_torch_dtype():
-for t in dtype_dict.keys():
+for t in dtype_dict:
assert data_type_utils.str_to_torch_dtype(t) == dtype_dict.get(t)


@@ -42,7 +42,7 @@ def test_str_to_torch_dtype_exit():


def test_get_torch_dtype():
-for t in dtype_dict.keys():
+for t in dtype_dict:
# When passed a string, it gets converted to torch.dtype
assert data_type_utils.get_torch_dtype(t) == dtype_dict.get(t)
# When passed a torch.dtype, we get the same torch.dtype returned
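Dropping .keys() addresses pylint's consider-iterating-dictionary (C0201) check: iterating a dict directly already yields its keys, so the call is redundant. A tiny example with an illustrative mapping:

dtype_names = {"float32": "torch.float32", "bfloat16": "torch.bfloat16"}

# pylint C0201 flags: for name in dtype_names.keys(): ...
for name in dtype_names:  # iterating the dict directly yields its keys
    assert dtype_names[name].startswith("torch.")
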
3 changes: 2 additions & 1 deletion tox.ini
@@ -17,6 +17,7 @@ allowlist_externals = ./scripts/fmt.sh
[testenv:lint]
description = lint with pylint
deps = pylint>=2.16.2,<=3.1.0
+pytest
-r requirements.txt
-commands = pylint tuning scripts/*.py build/*.py
+commands = pylint tuning scripts/*.py build/*.py tests
allowlist_externals = pylint
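
Two things change in the lint environment: pytest is added to deps so that pylint, now also pointed at the tests directory, can resolve the pytest imports in the test modules rather than reporting import errors, and the pylint command gains the tests target. With this tox configuration, the lint run would typically be invoked as tox -e lint.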