[DRAFT]: Add checkpointing for FSDP2 #11893

Open
wants to merge 14 commits into base: main
98 changes: 77 additions & 21 deletions examples/llm/peft/hf.py
@@ -13,14 +13,46 @@
# limitations under the License.

import fiddle as fdl
import torch.nn as nn
from lightning.pytorch.loggers import WandbLogger
from torch.nn import Module

from nemo import lightning as nl
from nemo.collections import llm
from nemo.lightning import NeMoLogger
from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform


def is_mlp_layer(module):
"""
Check if a module is an MLP layer.
"""
if isinstance(module, nn.Sequential):
for submodule in module:
if not isinstance(submodule, (nn.Linear, nn.ReLU, nn.Dropout, nn.BatchNorm1d)):
return False
return True
elif isinstance(module, nn.Linear):
return True
return False


def mlp_activation_checkpointing_policy(module: Module, recurse: bool, nonwrapped_numel: int) -> bool:
"""
Custom activation checkpointing policy for FSDPStrategy.
Returns True for MLP layers.
Args:
module (Module): The module being inspected.
recurse (bool): Whether to recurse into submodules.
nonwrapped_numel (int): The number of elements in non-wrapped parameters.
Returns:
bool: True if the module should be wrapped for activation checkpointing.
"""
# Check if the current module is an MLP layer
return is_mlp_layer(module)
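
A quick, self-contained sanity check of the two helpers above (illustrative only, not part of this diff; it assumes is_mlp_layer and mlp_activation_checkpointing_policy are in scope as defined here):

import torch.nn as nn

toy_mlp = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Dropout(0.1))
assert is_mlp_layer(toy_mlp)              # Sequential of Linear/ReLU/Dropout counts as an MLP
assert is_mlp_layer(nn.Linear(8, 8))      # a bare Linear also counts
assert not is_mlp_layer(nn.LayerNorm(8))  # anything else is not selected

# The policy follows the FSDP auto-wrap callable signature (module, recurse, nonwrapped_numel).
assert mlp_activation_checkpointing_policy(nn.Linear(8, 8), recurse=False, nonwrapped_numel=64)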


def make_squad_hf_dataset(tokenizer):
EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN

@@ -37,25 +69,25 @@ def formatting_prompts_func(examples):
{}"""
instruction = examples["context"]
input = examples["question"]
output = examples["answers"]['text']
output = examples["answers"]["text"]
if isinstance(output, list):
output = output[0]
text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
ans = tokenizer(text)
# 'input_ids' is a list, we want to remove EOS_TOKEN from input_ids and the first token from
# labels to align the two:
ans['labels'] = list(ans['input_ids'][1:])
ans['input_ids'] = ans['input_ids'][:-1]
ans['attention_mask'] = ans['attention_mask'][:-1]
ans["labels"] = list(ans["input_ids"][1:])
ans["input_ids"] = ans["input_ids"][:-1]
ans["attention_mask"] = ans["attention_mask"][:-1]
return ans

tokenizer = getattr(tokenizer, 'tokenizer', tokenizer)
tokenizer = getattr(tokenizer, "tokenizer", tokenizer)
datamodule = llm.HFDatasetDataModule("rajpurkar/squad", split="train[:100]", pad_token_id=tokenizer.eos_token_id)
datamodule.map(
formatting_prompts_func,
batched=False,
batch_size=2,
remove_columns=["id", "title", "context", "question", 'answers'],
remove_columns=["id", "title", "context", "question", "answers"],
)
return datamodule
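
To see why the slicing above leaves inputs and labels the same length, here is a toy walk-through with made-up token ids (illustrative only):

input_ids = [101, 7, 8, 9, 2]   # 2 stands in for the EOS token id
labels = list(input_ids[1:])    # [7, 8, 9, 2]   -> each position predicts the next token
input_ids = input_ids[:-1]      # [101, 7, 8, 9] -> EOS dropped from the model input
assert len(labels) == len(input_ids)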

@@ -64,28 +96,47 @@ def main():
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='meta-llama/Llama-3.2-1B')
parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp'])
parser.add_argument('--devices', default=1)
parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
parser.add_argument('--max-steps', type=int, default=100)
parser.add_argument('--wandb-project', type=str, default=None)
parser.add_argument('--use-torch-jit', action='store_true')
parser.add_argument('--ckpt-folder', type=str, default=None)
parser.add_argument("--model", default="meta-llama/Llama-3.2-1B")
parser.add_argument("--strategy", type=str, default="auto", choices=["auto", "ddp", "fsdp"])
parser.add_argument("--strategy", type=str, default="auto", choices=["auto", "ddp", "fsdp", "fsdp2"])
parser.add_argument("--devices", type=int, default=1)
parser.add_argument("--accelerator", default="gpu", choices=["gpu"])
parser.add_argument("--max-steps", type=int, default=100)
parser.add_argument("--wandb-project", type=str, default=None)
parser.add_argument("--use-torch-jit", action="store_true")
parser.add_argument("--ckpt-folder", type=str, default=None)
parser.add_argument(
"--ckpting-layers",
type=str,
nargs="+", # Accepts one or more strings as a list
default=None,
help="Checkpointing layers as a list of strings or a single string",
)
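
The hunk shown here does not include where --ckpting-layers is consumed. One plausible way to resolve the strings into module classes for an activation checkpointing policy could look like the following; the helper name and the restriction to Llama classes are hypothetical, not part of this PR:

import transformers.models.llama.modeling_llama as llama_modeling

def resolve_ckpting_layers(names):
    # e.g. ["LlamaMLP", "LlamaAttention"] -> {LlamaMLP, LlamaAttention}
    return {getattr(llama_modeling, name) for name in (names or [])}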

args = parser.parse_args()

wandb = None
if args.wandb_project is not None:
model = '_'.join(args.model.split('/')[-2:])
model = "_".join(args.model.split("/")[-2:])
wandb = WandbLogger(
project=args.wandb_project,
name=f'{model}_dev{args.devices}_strat_{args.strategy}',
name=f"{model}_dev{args.devices}_strat_{args.strategy}",
)
grad_clip = 0.5
if args.strategy == 'fsdp':
if args.strategy == "fsdp":
# See:
# https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81
grad_clip = None
from torch.nn import Linear

from nemo.lightning.pytorch.strategies import FSDPStrategy

args.strategy = FSDPStrategy(
# Alternative considered here: pass the custom callable policy defined above, e.g.
# activation_checkpointing_policy=mlp_activation_checkpointing_policy
activation_checkpointing_policy={Linear}  # checkpoint every nn.Linear
)
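
For reference, activation_checkpointing_policy={Linear} asks the strategy to wrap every nn.Linear in a checkpoint wrapper. A rough standalone equivalent using PyTorch's own utility (an illustrative sketch, not what NeMo/Lightning literally executes):

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)

toy = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
# Recompute activations of every nn.Linear during backward instead of storing them.
apply_activation_checkpointing(toy, check_fn=lambda m: isinstance(m, nn.Linear))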

use_dist_samp = False

import tempfile
@@ -98,11 +149,13 @@ def main():

callbacks = []
if args.use_torch_jit:
jit_config = JitConfig(use_torch=True, torch_kwargs={'dynamic': True}, use_thunder=False)
jit_config = JitConfig(use_torch=True, torch_kwargs={"dynamic": True}, use_thunder=False)
callbacks = [JitTransform(jit_config)]

model = llm.HFAutoModelForCausalLM(args.model)

llm.api.finetune(
model=llm.HFAutoModelForCausalLM(args.model),
model=model,
data=make_squad_hf_dataset(tokenizer.tokenizer),
trainer=nl.Trainer(
devices=args.devices,
@@ -122,11 +175,14 @@ def main():
optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
log=NeMoLogger(log_dir=args.ckpt_folder, use_datetime_version=False),
peft=llm.peft.LoRA(
target_modules=['*_proj'],
target_modules=["*_proj"],
dim=8,
),
)
import torch

print(torch.cuda.memory_summary())
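
The target_modules=["*_proj"] pattern above is a name glob over the model's submodules; a quick illustration of what it matches on a Llama-style layout (the exact matching logic inside NeMo's LoRA implementation may differ):

import fnmatch

names = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.mlp.gate_proj",
    "model.layers.0.input_layernorm",
]
print([n for n in names if fnmatch.fnmatch(n, "*_proj")])
# ['model.layers.0.self_attn.q_proj', 'model.layers.0.mlp.gate_proj']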


if __name__ == '__main__':
if __name__ == "__main__":
main()
96 changes: 80 additions & 16 deletions examples/llm/sft/hf.py
@@ -61,34 +61,94 @@ def squad(tokenizer) -> pl.LightningDataModule:
)


DATA_PATH = "/workspace/squad/"  # local copy of the SQuAD dataset; adjust to your environment


def make_squad_hf_dataset(data_path, tokenizer):
EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN

def formatting_prompts_func(examples):
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{}
### Input:
{}
### Response:
{}"""
instruction = examples["context"]
input = examples["question"]
output = examples["answers"]["text"]
if isinstance(output, list):
output = output[0]
text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
ans = tokenizer(text)
ans["labels"] = ans["input_ids"]
return ans

tokenizer = getattr(tokenizer, "tokenizer", tokenizer)
datamodule = llm.HFDatasetDataModule(
data_path,
split="train[:100]",
pad_token_id=tokenizer.eos_token_id,
seq_length=512,
micro_batch_size=2,
global_batch_size=128,
)

datamodule.map(
formatting_prompts_func,
batched=False,
batch_size=2,
remove_columns=["id", "title", "context", "question", "answers"],
)

return datamodule
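
A note on the batch-size settings above: with micro_batch_size=2 and global_batch_size=128 on a single data-parallel rank, the usual NeMo convention (assumed here, it is not spelled out in this diff) implies 64 gradient-accumulation steps per optimizer update:

micro_batch_size, global_batch_size, dp_ranks = 2, 128, 1
print(global_batch_size // (micro_batch_size * dp_ranks))  # 64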


def main():
"""Example script to run SFT with a HF transformers-instantiated model on squad."""
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='meta-llama/Llama-3.2-1B')
parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp'])
parser.add_argument('--devices', default=1)
parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
parser.add_argument('--model-accelerator', default=None, choices=['te'])
parser.add_argument('--max-steps', type=int, default=100)
parser.add_argument("--fp8-autocast", default=False, action='store_true')
parser.add_argument('--wandb-project', type=str, default=None)
parser.add_argument('--model-save-path', type=str, default=None)
parser.add_argument('--use-torch-jit', action='store_true')
parser.add_argument("--model", default="meta-llama/Llama-3.2-1B")
parser.add_argument("--strategy", type=str, default="auto", choices=["auto", "ddp", "fsdp", "fsdp2"])
parser.add_argument("--devices", type=int, default=1)
parser.add_argument("--accelerator", default="gpu", choices=["gpu"])
parser.add_argument("--model-accelerator", default=None, choices=["te"])
parser.add_argument("--max-steps", type=int, default=100)
parser.add_argument("--fp8-autocast", default=False, action="store_true")
parser.add_argument("--wandb-project", type=str, default=None)
parser.add_argument("--model-save-path", type=str, default=None)
parser.add_argument("--use-torch-jit", action="store_true")
args = parser.parse_args()

wandb = None
if args.wandb_project is not None:
model = '_'.join(args.model.split('/')[-2:])
model = "_".join(args.model.split("/")[-2:])
wandb = WandbLogger(
project=args.wandb_project,
name=f'{model}_dev{args.devices}_strat_{args.strategy}',
name=f"{model}_dev{args.devices}_strat_{args.strategy}",
)
grad_clip = 0.5
if args.strategy == 'fsdp':
if args.strategy == "fsdp":
# See: https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81
grad_clip = None

if args.strategy == "fsdp2":
from nemo.lightning.pytorch.strategies import FSDP2Strategy

grad_clip = None
from transformers.models.llama.modeling_llama import LlamaMLP

args.strategy = FSDP2Strategy(
data_parallel_size=args.devices,
tensor_parallel_size=1,
activation_checkpointing_policy={LlamaMLP},
)
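
With activation_checkpointing_policy={LlamaMLP}, every LlamaMLP block (one per decoder layer in a HF Llama model) has its activations recomputed during backward. A small, self-contained way to see which modules that class-based policy targets (illustrative only; a tiny random-weight config, so nothing is downloaded):

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaMLP

cfg = LlamaConfig(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=4, vocab_size=128,
)
tiny = LlamaForCausalLM(cfg)
print([name for name, m in tiny.named_modules() if isinstance(m, LlamaMLP)])
# ['model.layers.0.mlp', 'model.layers.1.mlp']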

use_dist_samp = False

model_accelerator = None
@@ -105,12 +165,12 @@ def main():

callbacks = []
if args.use_torch_jit:
jit_config = JitConfig(use_torch=True, torch_kwargs={'dynamic': False}, use_thunder=False)
jit_config = JitConfig(use_torch=True, torch_kwargs={"dynamic": False}, use_thunder=False)
callbacks = [JitTransform(jit_config)]

llm.api.finetune(
model=model,
data=squad(tokenizer),
data=make_squad_hf_dataset(DATA_PATH, tokenizer),
trainer=nl.Trainer(
devices=args.devices,
max_steps=args.max_steps,
@@ -130,6 +190,10 @@
log=None,
)

import torch

print(torch.cuda.memory_summary())

if args.model_accelerator:
if args.model_accelerator == "te":
te_acc = is_te_accelerated(model.model)
@@ -140,5 +204,5 @@
model.save_pretrained(args.model_save_path)


if __name__ == '__main__':
if __name__ == "__main__":
main()