Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let Huggingface Properly Initialize Arguments, and Fix FSDP-LORA Checkpoint-Saves and Resumption #53

Merged
merged 6 commits into from
Mar 9, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tuning/config/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,6 @@ class TrainingArguments(transformers.TrainingArguments):
default=False,
metadata={"help": "Packing to be enabled in SFT Trainer, default is False"},
)

def __post_init__(self):
super().__post_init__()
fabianlim marked this conversation as resolved.
Show resolved Hide resolved
24 changes: 2 additions & 22 deletions tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,6 @@
from tuning.utils.data_type_utils import get_torch_dtype


class PeftSavingCallback(TrainerCallback):
def on_save(self, args, state, control, **kwargs):
checkpoint_path = os.path.join(
args.output_dir, f"checkpoint-{state.global_step}"
)
kwargs["model"].save_pretrained(checkpoint_path)

if "pytorch_model.bin" in os.listdir(checkpoint_path):
os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))


class FileLoggingCallback(TrainerCallback):
"""Exports metrics, e.g., training loss to a file in the checkpoint directory."""

Expand Down Expand Up @@ -118,7 +107,6 @@ def train(
None for fine tuning
The peft configuration to pass to trainer
"""
run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1

logger = logging.get_logger("sft_trainer")

Expand All @@ -132,11 +120,6 @@ def train(
):
raise ValueError("gradient_accumulation_steps has to be an integer >= 1")

# make sure to unset FSDP args when running on single gpu
if not run_distributed:
train_args.fsdp = ""
Ssukriti marked this conversation as resolved.
Show resolved Hide resolved
train_args.fsdp_config = {"xla": False}

task_type = "CAUSAL_LM"
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
Expand All @@ -147,8 +130,6 @@ def train(

peft_config = get_hf_peft_config(task_type, peft_config)

model.gradient_checkpointing_enable()

# TODO: Move these to a config as well
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, cache_dir=train_args.cache_dir, use_fast=True
Expand Down Expand Up @@ -239,8 +220,7 @@ def train(

aim_callback = get_aimstack_callback()
file_logger_callback = FileLoggingCallback(logger)
peft_saving_callback = PeftSavingCallback()
callbacks = [aim_callback, peft_saving_callback, file_logger_callback]
callbacks = [aim_callback, file_logger_callback]

if train_args.packing:
logger.info("Packing is set to True")
Expand Down Expand Up @@ -281,7 +261,7 @@ def train(
peft_config=peft_config,
)

if run_distributed and peft_config is not None:
if trainer.is_fsdp_enabled and peft_config is not None:
Ssukriti marked this conversation as resolved.
Show resolved Hide resolved
trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
model
)
Expand Down
Loading