TRL SFTTrainer Examples #2211

Merged · 24 commits · Apr 24, 2024
Changes shown are from 15 commits.
4 changes: 2 additions & 2 deletions src/sparseml/transformers/finetune/README.md
@@ -132,7 +132,7 @@ A recipe can be run stage-by-stage by setting `run_stages` to `True` or calling
a `run_type` attribute set to either `oneshot` or `train` when running in sequential
mode.

See [example_alternating_recipe.yaml](example_alternating_recipe.yaml) for an example
See [example_alternating_recipe.yaml](examples/example_alternating_recipe.yaml) for an example
of a staged recipe for Llama.

### Python Example
@@ -147,7 +147,7 @@ dataset_name = "open_platypus"
concatenate_data = False
run_stages=True
output_dir = "./output_finetune_multi"
recipe = "example_alternating_recipe.yaml"
recipe = "examples/example_alternating_recipe.yaml"
num_train_epochs=1
overwrite_output_dir = True
splits = {
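Note: the README's Python example is truncated by the diff above. As a hedged sketch of how those settings feed the multi-stage entrypoint, the variables can be passed to `apply` (re-exported by this PR's `__init__.py` change below); the keyword names here are illustrative assumptions and may not match the actual signature:

```python
# Hedged sketch only: keyword names are assumptions, not taken from the README.
from sparseml.transformers.finetune import apply

model = "path/to/sparse-model-or-hf-id"  # placeholder; not specified in the visible snippet

apply(
    model=model,
    dataset_name=dataset_name,
    concatenate_data=concatenate_data,
    run_stages=run_stages,
    output_dir=output_dir,
    recipe=recipe,
    num_train_epochs=num_train_epochs,
    overwrite_output_dir=overwrite_output_dir,
    splits=splits,
)
```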
6 changes: 6 additions & 0 deletions src/sparseml/transformers/finetune/__init__.py
@@ -14,4 +14,10 @@

# flake8: noqa

from .data import DataTrainingArguments, TextGenerationDataset
from .model_args import ModelArguments
from .session_mixin import SessionManagerMixIn
from .sft_trainer import SFTTrainer
from .text_generation import apply, compress, eval, oneshot, train
from .trainer import Trainer
from .training_args import TrainingArguments
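These re-exports flatten the public surface: everything needed for a recipe-aware SFT run can now be imported from `sparseml.transformers.finetune` directly (the example scripts later in this diff import the same names from `sparseml.transformers`, one level up). For illustration:

```python
# Names below are exactly those re-exported by the __init__.py change above.
from sparseml.transformers.finetune import (
    DataTrainingArguments,
    ModelArguments,
    SFTTrainer,
    SessionManagerMixIn,
    TextGenerationDataset,
    Trainer,
    TrainingArguments,
    apply,
    oneshot,
    train,
)
```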
1 change: 1 addition & 0 deletions src/sparseml/transformers/finetune/data/__init__.py
@@ -18,6 +18,7 @@
from .c4 import C4Dataset
from .cnn_dailymail import CNNDailyMailDataset
from .custom import CustomDataset
from .data_args import DataTrainingArguments
from .evolcodealpaca import EvolCodeAlpacaDataset
from .gsm8k import GSM8KDataset
from .open_platypus import OpenPlatypusDataset
5 changes: 4 additions & 1 deletion src/sparseml/transformers/finetune/data/base.py
@@ -111,7 +111,7 @@ def get_raw_dataset(self, cache_dir: Optional[str] = None) -> Dataset:
**self.raw_kwargs,
)

def tokenize_and_process(self, raw_dataset: Dataset) -> Dataset:
def tokenize_and_process(self, raw_dataset: Optional[Dataset] = None) -> Dataset:
"""
Sets up the raw dataset for finetuning, performs tokenization, concatenates
entries to max sequence length if desired, and adds labels to each entry
@@ -168,6 +168,9 @@ def label_fn(data):
data["labels"][-padding:] = [LABELS_MASK_VALUE] * padding
return data

if raw_dataset is None:
raw_dataset = self.get_raw_dataset()

dataset = self.map(
raw_dataset,
function=tokenize_fn,
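With `raw_dataset` now optional, callers can let the dataset manager fetch the raw split itself via `get_raw_dataset()`; this is the calling pattern the TRL distillation example below relies on. A short sketch (setup mirrors that example):

```python
from sparseml.transformers import (
    DataTrainingArguments,
    SparseAutoTokenizer,
    TextGenerationDataset,
)

data_args = DataTrainingArguments(
    dataset="gsm8k", dataset_config_name="main", max_seq_length=512
)
tokenizer = SparseAutoTokenizer.from_pretrained("neuralmagic/Llama-2-7b-pruned50-retrained")

dataset_manager = TextGenerationDataset.load_from_registry(
    data_args.dataset, data_args=data_args, split="train", tokenizer=tokenizer
)

# Before this change, the raw dataset had to be fetched and passed explicitly:
#   train_dataset = dataset_manager.tokenize_and_process(dataset_manager.get_raw_dataset())
# Now the raw split is loaded internally when the argument is omitted:
train_dataset = dataset_manager.tokenize_and_process()
```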
@@ -0,0 +1,91 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import DefaultDataCollator

from sparseml.transformers import (
DataTrainingArguments,
SFTTrainer,
SparseAutoModelForCausalLM,
SparseAutoTokenizer,
TextGenerationDataset,
TrainingArguments,
)


model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
teacher_path = "zoo:llama2-7b-gsm8k_llama2_pretrain-base"
output_dir = "./output_trl_sft_test_7b_gsm8k"

model = SparseAutoModelForCausalLM.from_pretrained(
model_path, torch_dtype="auto", device_map="auto"
)
teacher = SparseAutoModelForCausalLM.from_pretrained(
teacher_path, torch_dtype="auto", device_map="auto"
)

tokenizer = SparseAutoTokenizer.from_pretrained(model_path)

# Load gsm8k using SparseML dataset tools
data_args = DataTrainingArguments(
dataset="gsm8k", dataset_config_name="main", max_seq_length=512
)
dataset_manager = TextGenerationDataset.load_from_registry(
data_args.dataset,
data_args=data_args,
split="train",
tokenizer=tokenizer,
)
train_dataset = dataset_manager.tokenize_and_process()
print(f"--> Training Set Length = {len(train_dataset)}")

# recipe for maintaining model sparsity during finetuning
recipe = """
test_stage:
pruning_modifiers:
ConstantPruningModifier:
targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight',
're:.*o_proj.weight', 're:.*gate_proj.weight', 're:.*up_proj.weight',
're:.*down_proj.weight']
start: 0
OutputDistillationModifier:
targets: ['re:model.layers.\\d+$']
comparison: "square_head"
start: 0
orig_scale: 1.0
distill_scale: 1.0
"""

data_collator = DefaultDataCollator()
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=0.6,
logging_steps=50,
gradient_checkpointing=True,
bf16=True,
)
trainer = SFTTrainer(
model=model,
teacher=teacher,
tokenizer=tokenizer,
recipe=recipe,
train_dataset=train_dataset,
data_collator=data_collator,
args=training_args,
data_args=data_args,
max_seq_length=data_args.max_seq_length,
packing=True,
)
trainer.train()
trainer.save_model(output_dir=trainer.args.output_dir)
78 changes: 78 additions & 0 deletions src/sparseml/transformers/finetune/examples/test_trl_sft_data.py
@@ -0,0 +1,78 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datasets import load_dataset

from sparseml.transformers import (
SFTTrainer,
SparseAutoModelForCausalLM,
SparseAutoTokenizer,
TrainingArguments,
)
from trl import DataCollatorForCompletionOnlyLM


model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
output_dir = "./output_trl_sft_test_7b_gsm8k_sft_data"
model = SparseAutoModelForCausalLM.from_pretrained(
model_path, torch_dtype="auto", device_map="auto"
)
tokenizer = SparseAutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

# recipe for maintaining model sparsity during finetuning
recipe = """
test_stage:
pruning_modifiers:
ConstantPruningModifier:
targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight',
're:.*o_proj.weight','re:.*gate_proj.weight', 're:.*up_proj.weight',
're:.*down_proj.weight']
start: 0
"""

# Load gsm8k using TRL dataset tools
dataset = load_dataset("gsm8k", "main", split="train")


def formatting_prompts_func(example):
output_texts = []
for i in range(len(example["question"])):
text = f"Question: {example['question'][i]}\n Answer: {example['answer'][i]}"
output_texts.append(text)
return output_texts


response_template = "Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=0.6,
logging_steps=50,
gradient_checkpointing=True,
)
Member:

Is it important at all that the TrainingArguments comes from SparseML?

Author:

A few things won't work if the native transformers TrainingArguments is used: no support for recipe overrides, no compressed save, and no multistage training runs. The mix-in uses these params, so if we wanted to support the non-SparseML TrainingArguments we would have to check each time we reference them. I don't think it's worth the extra lines, personally.


trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
recipe=recipe,
train_dataset=dataset,
formatting_func=formatting_prompts_func,
data_collator=collator,
args=training_args,
max_seq_length=512,
)
trainer.train()
trainer.save_model(output_dir=trainer.args.output_dir)
51 changes: 20 additions & 31 deletions src/sparseml/transformers/finetune/session_mixin.py
@@ -28,11 +28,7 @@
import sparseml.core.session as session_manager
from sparseml.core.framework import Framework
from sparseml.core.session import callbacks
from sparseml.pytorch.model_load.helpers import (
RECIPE_FILE_NAME,
get_session_model,
reload_model_state,
)
from sparseml.pytorch.model_load.helpers import RECIPE_FILE_NAME, get_session_model
from sparseml.pytorch.utils import LoggerManager, ModuleSparsificationInfo
from sparseml.transformers.finetune.callbacks import (
DisableHalfPrecisionCallback,
@@ -49,33 +45,34 @@

_LOGGER = logging.getLogger(__name__)
TRAINER_STATE_NAME = "trainer_state.json"
METADATA_ARGS = [
"per_device_train_batch_size",
"per_device_eval_batch_size",
"max_seq_length",
"save_safetensors",
"fp16",
]


class SessionManagerMixIn:
"""
Mix-In class to extend the Hugging Face Trainer class to support SparseML recipes
for one-shot and finetuning flows.

:param model_state_path: path to Pytorch model checkpoint or saved model
:param recipe: path to recipe file to apply during training
:param recipe_args: additional kwargs to use for evaluating recipe
:param metadata_args: additional kwargs for configuring training
:param data_args: kwargs for configuring dataset loading
:param teacher: optional teacher model to use for distillation
"""

def __init__(
self,
model_state_path: str,
recipe: Optional[str] = None,
recipe_args: Optional[Union[Dict[str, Any], str]] = None,
metadata_args: Optional[List[str]] = None,
data_args: Optional["DataTrainingArguments"] = None, # noqa: F821
teacher: Optional[Union[Module, str]] = None,
**kwargs,
):
# instantiate necessary state, like managers, so we can override args
self.model_state_path = str(model_state_path)
self.recipe = recipe
self.recipe_args = recipe_args
self.teacher = teacher
@@ -84,11 +81,11 @@ def __init__(
training_args = kwargs.get("args")
self.metadata = (
self._extract_metadata(
metadata_args=metadata_args,
metadata_args=METADATA_ARGS,
training_args_dict=training_args.to_dict(),
data_args_dict=asdict(data_args) if data_args else {},
)
if training_args and metadata_args
if training_args and METADATA_ARGS
else None
)

@@ -98,6 +95,7 @@

# call Trainer initialization
super().__init__(**kwargs)
self.accelerator.wait_for_everyone()

# setup callbacks and loss
self.optim_callbacks = TrainingLoopCallbacks(self)
@@ -107,14 +105,17 @@
self.criterion = torch.nn.CrossEntropyLoss()

model_signature = inspect.signature(self.model.forward)
self._model_signature_columns = list(model_signature.parameters.keys())
self._signature_columns = list(model_signature.parameters.keys())

if self.teacher is not None and teacher not in ("disable", "self"):
teacher_signature = inspect.signature(self.teacher.forward)
self._teacher_signature_columns = list(teacher_signature.parameters.keys())
else:
self._teacher_signature_columns = None

if self.is_fsdp_enabled:
self._prepare_model_for_fsdp()

def initialize_session(
self,
epoch: float,
@@ -134,7 +135,6 @@ def initialize_session(
if session.lifecycle.initialized_ or session.lifecycle.finalized:
return False

orig_state_dict = self.model.state_dict()
train_data = self.get_train_dataloader()

self.accelerator.wait_for_everyone()
@@ -154,17 +154,7 @@
)
self.accelerator.wait_for_everyone()
model = get_session_model()
self.model = model

# reload the state dict for the model now that architecture matches expected
# TODO: what if there is a quant modifier in the original recipe and we want to
# continue adjusting its zero point and range?
load_path = checkpoint or self.model_state_path
if reload_model_state(model, load_path, orig_state_dict):
_LOGGER.info(
"Reloaded model state after SparseML recipe structure modifications "
f"from {load_path}"
)
self.model_wrapped = self.model = model

if self.recipe is None:
_LOGGER.warning(
@@ -298,7 +288,7 @@ def compute_loss(
self._check_super_defined("compute_loss")

# TODO: do we need these model signature columns?
inputs = {k: inputs[k] for k in inputs if k in self._model_signature_columns}
inputs = {k: inputs[k] for k in inputs if k in self._signature_columns}
loss = super().compute_loss(model, inputs, return_outputs=return_outputs)

# take the mean across multiple GPUs
@@ -472,9 +462,8 @@ def save_model(
)

self.save_state()
self.tokenizer.save_pretrained(output_dir)
if not _is_oneshot: # optimizer/scheduler not relevant to one-shot
self.save_optimizer_and_scheduler(output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)

if not self.recipe:
return
@@ -497,7 +486,7 @@ def log_model_sparsification(self):
sparsification_info = ModuleSparsificationInfo(self.model)

_LOGGER.info(
f"Sparsification info for {self.model_state_path}: "
f"Sparsification info for {type(self.model).__name__}: "
f"{sparsification_info.params_total} total params. "
f"Of those there are {sparsification_info.params_prunable_total} prunable "
f"params which have {sparsification_info.params_prunable_sparse_percent} "
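For orientation: the `SessionManagerMixIn` docstring above describes it as a mix-in over the Hugging Face Trainer, and the example scripts earlier in this PR pass both mix-in arguments (`recipe`, `teacher`, `data_args`) and TRL/Transformers arguments to `SFTTrainer`. Below is a hypothetical sketch of how the exported `Trainer` and `SFTTrainer` might compose the mix-in; the real definitions live in `trainer.py` and `sft_trainer.py`, which are not shown in this diff:

```python
# Hypothetical composition sketch; not the actual contents of trainer.py / sft_trainer.py.
from transformers import Trainer as HFTrainer
from trl import SFTTrainer as TRLSFTTrainer

from sparseml.transformers.finetune.session_mixin import SessionManagerMixIn


class Trainer(SessionManagerMixIn, HFTrainer):
    """Hugging Face Trainer extended with SparseML recipe support."""


class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer):
    """TRL SFTTrainer extended with SparseML recipe support, as used in the examples above."""
```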