fix tests
Felipe Mello committed Dec 3, 2024
1 parent 6f828ce commit a8cc992
Showing 8 changed files with 130 additions and 18 deletions.
19 changes: 17 additions & 2 deletions tests/recipes/test_full_finetune_distributed.py
@@ -27,6 +27,12 @@
TOKENIZER_PATHS,
)

from torchtune.training.checkpointing._utils import (
get_largest_iter_folder,
RECIPE_STATE_DIRNAME,
SHARD_FNAME,
)


class TestFullFinetuneDistributedRecipe:
def _get_test_config_overrides(self):
@@ -169,6 +175,15 @@ def test_training_state_on_resume(
runpy.run_path(TUNE_PATH, run_name="__main__")

# Resume training
epoch_folder = get_largest_iter_folder(tmpdir)
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
checkpoint_files = [
os.path.join(
tmpdir,
epoch_folder_minus_one,
SHARD_FNAME.format(cpt_idx=1, num_shards=1),
)
]
cmd_2 = f"""
tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \
--config {config} \
@@ -177,8 +192,8 @@ def test_training_state_on_resume(
output_dir={tmpdir} \
checkpointer._component_={ckpt_component} \
checkpointer.checkpoint_dir='{tmpdir}' \
checkpointer.checkpoint_files=[{os.path.join(tmpdir, "torchtune_model_0.pt")}]\
checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}\
checkpointer.checkpoint_files={checkpoint_files}\
checkpointer.recipe_checkpoint={os.path.join(tmpdir, RECIPE_STATE_DIRNAME, "recipe_state.pt")}\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
19 changes: 17 additions & 2 deletions tests/recipes/test_full_finetune_single_device.py
@@ -29,6 +29,12 @@
TOKENIZER_PATHS,
)

from torchtune.training.checkpointing._utils import (
get_largest_iter_folder,
RECIPE_STATE_DIRNAME,
SHARD_FNAME,
)


class TestFullFinetuneSingleDeviceRecipe:
def _get_test_config_overrides(self):
@@ -173,15 +179,24 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
runpy.run_path(TUNE_PATH, run_name="__main__")

# Resume training
epoch_folder = get_largest_iter_folder(tmpdir)
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
checkpoint_files = [
os.path.join(
tmpdir,
epoch_folder_minus_one,
SHARD_FNAME.format(cpt_idx=1, num_shards=1),
)
]
cmd_2 = f"""
tune run full_finetune_single_device \
--config llama2/7B_full_low_memory \
batch_size=8 \
output_dir={tmpdir} \
checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir={tmpdir} \
checkpointer.checkpoint_files=[{os.path.join(tmpdir, "hf_model_0001_0.pt")}]\
checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}\
checkpointer.checkpoint_files={checkpoint_files}\
checkpointer.recipe_checkpoint={os.path.join(tmpdir, RECIPE_STATE_DIRNAME, "recipe_state.pt")}\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
19 changes: 17 additions & 2 deletions tests/recipes/test_knowledge_distillation_distributed.py
@@ -28,6 +28,12 @@
)
from torchtune import config

from torchtune.training.checkpointing._utils import (
get_largest_iter_folder,
RECIPE_STATE_DIRNAME,
SHARD_FNAME,
)


class TestKDDistributedRecipe:
def _get_test_config_overrides(self, epochs: int = 2):
@@ -146,15 +152,24 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
runpy.run_path(TUNE_PATH, run_name="__main__")

# Resume training
epoch_folder = get_largest_iter_folder(tmpdir)
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
checkpoint_files = [
os.path.join(
tmpdir,
epoch_folder_minus_one,
SHARD_FNAME.format(cpt_idx=1, num_shards=1),
)
]
cmd_2 = f"""
tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \
--config llama3_2/knowledge_distillation_distributed \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir={tmpdir} \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")}
checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
checkpointer.adapter_checkpoint={os.path.join(tmpdir, epoch_folder_minus_one, "adapter.bin")}
checkpointer.recipe_checkpoint={os.path.join(tmpdir, RECIPE_STATE_DIRNAME, "recipe_state.pt")}
checkpointer.output_dir={tmpdir} \
teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \
19 changes: 17 additions & 2 deletions tests/recipes/test_knowledge_distillation_single_device.py
@@ -27,6 +27,12 @@
)
from torchtune import config

from torchtune.training.checkpointing._utils import (
get_largest_iter_folder,
RECIPE_STATE_DIRNAME,
SHARD_FNAME,
)


class TestKDSingleDeviceRecipe:
def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
@@ -184,15 +190,24 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
runpy.run_path(TUNE_PATH, run_name="__main__")

# Resume training
epoch_folder = get_largest_iter_folder(tmpdir)
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
checkpoint_files = [
os.path.join(
tmpdir,
epoch_folder_minus_one,
SHARD_FNAME.format(cpt_idx=1, num_shards=1),
)
]
cmd_2 = f"""
tune run knowledge_distillation_single_device \
--config qwen2/knowledge_distillation_single_device \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir={tmpdir} \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")}
checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
checkpointer.adapter_checkpoint={os.path.join(tmpdir, epoch_folder_minus_one, "adapter.bin")}
checkpointer.recipe_checkpoint={os.path.join(tmpdir, RECIPE_STATE_DIRNAME, "recipe_state.pt")}
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA3 \
teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
12 changes: 10 additions & 2 deletions tests/recipes/test_lora_dpo_single_device.py
@@ -25,6 +25,11 @@
)
from torchtune import config

from torchtune.training.checkpointing._utils import (
get_largest_iter_folder,
RECIPE_STATE_DIRNAME,
)


class TestLoRADPOSingleDeviceRecipe:
def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
@@ -99,7 +104,10 @@ def test_training_state_on_resume(

resumed_log_dir = (tmpdir / "resumed/").mkdir()
resumed_log_file = gen_log_file_name(resumed_log_dir)

# Resume training
epoch_folder = get_largest_iter_folder(tmpdir)
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
cmd_2 = f"""
tune run lora_dpo_single_device \
--config llama2/7B_lora_dpo_single_device \
@@ -109,8 +117,8 @@
checkpointer=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir={tmpdir} \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")}
checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
checkpointer.adapter_checkpoint={os.path.join(tmpdir, epoch_folder_minus_one, "adapter.bin")}
checkpointer.recipe_checkpoint={os.path.join(tmpdir, RECIPE_STATE_DIRNAME, "recipe_state.pt")}
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
resume_from_checkpoint=True \
11 changes: 9 additions & 2 deletions tests/recipes/test_lora_finetune_distributed.py
@@ -28,6 +28,11 @@
)
from torchtune import config

from torchtune.training.checkpointing._utils import (
get_largest_iter_folder,
RECIPE_STATE_DIRNAME,
)


class TestLoRAFinetuneDistributedRecipe:
def _get_test_config_overrides(self):
@@ -169,6 +174,8 @@ def test_training_state_on_resume(
runpy.run_path(TUNE_PATH, run_name="__main__")

# Resume training
epoch_folder = get_largest_iter_folder(tmpdir)
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
cmd_2 = f"""
tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed \
--config {config} \
@@ -180,8 +187,8 @@
checkpointer._component_={ckpt_component} \
checkpointer.checkpoint_dir={tmpdir} \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")}
checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
checkpointer.adapter_checkpoint={os.path.join(tmpdir, epoch_folder_minus_one, "adapter.bin")}
checkpointer.recipe_checkpoint={os.path.join(tmpdir, RECIPE_STATE_DIRNAME, "recipe_state.pt")}
checkpointer.output_dir={tmpdir} \
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
11 changes: 9 additions & 2 deletions tests/recipes/test_lora_finetune_single_device.py
@@ -27,6 +27,11 @@
)
from torchtune import config

from torchtune.training.checkpointing._utils import (
get_largest_iter_folder,
RECIPE_STATE_DIRNAME,
)


class TestLoRAFinetuneSingleDeviceRecipe:
def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
@@ -232,6 +237,8 @@ def test_training_state_on_resume(
runpy.run_path(TUNE_PATH, run_name="__main__")

# Resume training
epoch_folder = get_largest_iter_folder(tmpdir)
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
cmd_2 = f"""
tune run lora_finetune_single_device \
--config llama2/7B_lora_single_device \
@@ -241,8 +248,8 @@
checkpointer=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir={tmpdir} \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")}
checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
checkpointer.adapter_checkpoint={os.path.join(tmpdir, epoch_folder_minus_one, "adapter.bin")}
checkpointer.recipe_checkpoint={os.path.join(tmpdir, RECIPE_STATE_DIRNAME, "recipe_state.pt")}
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
resume_from_checkpoint=True \
38 changes: 34 additions & 4 deletions tests/recipes/test_ppo_full_finetune_single_device.py
@@ -27,6 +27,12 @@
mps_ignored_test,
)

from torchtune.training.checkpointing._utils import (
get_largest_iter_folder,
RECIPE_STATE_DIRNAME,
SHARD_FNAME,
)


class TestPPOFullFinetuneSingleDeviceRecipe:
def _get_test_config_overrides(self):
@@ -210,26 +216,50 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
# Resume training at step 2
resumed_log_dir = (tmpdir / "resumed/").mkdir()
resumed_log_file = gen_log_file_name(resumed_log_dir)

epoch_folder = get_largest_iter_folder(tmpdir)
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
policy_checkpoint_files = [
os.path.join(
policy_tmpdir,
epoch_folder_minus_one,
SHARD_FNAME.format(cpt_idx=1, num_shards=1),
)
]
value_checkpoint_files = [
os.path.join(
value_tmpdir,
epoch_folder_minus_one,
SHARD_FNAME.format(cpt_idx=1, num_shards=1),
)
]
rwd_checkpoint_files = [
os.path.join(
value_tmpdir,
epoch_folder_minus_one,
SHARD_FNAME.format(cpt_idx=1, num_shards=1),
)
]
cmd_2 = f"""
tune run ppo_full_finetune_single_device \
--config mistral/7B_full_ppo_low_memory \
output_dir={tmpdir} \
checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir='{policy_tmpdir}' \
checkpointer.checkpoint_files=[{os.path.join(policy_tmpdir, "hf_model_0001_0.pt")}]\
checkpointer.recipe_checkpoint={os.path.join(policy_tmpdir, "recipe_state.pt")}\
checkpointer.checkpoint_files={policy_checkpoint_files}\
checkpointer.recipe_checkpoint={os.path.join(policy_tmpdir, RECIPE_STATE_DIRNAME, "recipe_state.pt")}\
checkpointer.output_dir={policy_tmpdir} \
checkpointer.model_type=LLAMA2 \
ref_policy_checkpointer.checkpoint_dir='{ckpt_dir}' \
ref_policy_checkpointer.checkpoint_files=[{policy_ckpt_path}]\
value_checkpointer.checkpoint_dir='{value_tmpdir}' \
value_checkpointer.checkpoint_files=[{os.path.join(value_tmpdir, "hf_model_0001_0.pt")}]\
value_checkpointer.checkpoint_files={value_checkpoint_files}\
value_checkpointer.output_dir={value_tmpdir} \
reward_checkpointer.checkpoint_dir='{ckpt_dir}' \
reward_checkpointer.checkpoint_files=[{reward_ckpt_path}]\
reward_checkpointer.checkpoint_files={rwd_checkpoint_files}\
resume_from_checkpoint=True \
metric_logger._component_=torchtune.training.metric_logging.DiskLogger \
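Note: all eight resume tests now build their resume paths the same way, using the checkpointing helpers imported at the top of each file; the LoRA and KD variants additionally point adapter_checkpoint at adapter.bin inside the same previous-epoch folder. Below is a minimal sketch of that shared pattern, mirroring only the calls that appear in this diff (get_largest_iter_folder, RECIPE_STATE_DIRNAME, SHARD_FNAME); resume_paths is a hypothetical helper written here for illustration, not part of torchtune.

import os

from torchtune.training.checkpointing._utils import (
    get_largest_iter_folder,
    RECIPE_STATE_DIRNAME,
    SHARD_FNAME,
)


def resume_paths(output_dir: str) -> tuple[str, str]:
    # Latest epoch folder written by the first run, e.g. "epoch_1".
    epoch_folder = get_largest_iter_folder(output_dir)
    # The tests resume from the epoch before the last one saved.
    epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
    # Single-shard model checkpoint inside that epoch folder.
    checkpoint_file = os.path.join(
        output_dir,
        epoch_folder_minus_one,
        SHARD_FNAME.format(cpt_idx=1, num_shards=1),
    )
    # Recipe state now lives in its own directory under the output dir.
    recipe_checkpoint = os.path.join(
        output_dir, RECIPE_STATE_DIRNAME, "recipe_state.pt"
    )
    return checkpoint_file, recipe_checkpoint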
