diff --git a/src/sparseml/transformers/finetune/example_alternating_recipe.yaml b/integrations/huggingface-transformers/tutorials/text-generation/example_alternating_recipe.yaml
similarity index 100%
rename from src/sparseml/transformers/finetune/example_alternating_recipe.yaml
rename to integrations/huggingface-transformers/tutorials/text-generation/example_alternating_recipe.yaml
diff --git a/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/README.md b/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/README.md
new file mode 100644
index 00000000000..25c3b54976b
--- /dev/null
+++ b/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/README.md
@@ -0,0 +1,46 @@
+
+
+# Sparse Finetuning with TRL's SFTTrainer
+
+The `SessionManagerMixIn` can be added to other Trainer classes that inherit from
+[Hugging Face's Trainer](https://huggingface.co/docs/transformers/en/main_classes/trainer).
+
+For example, we can add SparseML support to TRL's SFTTrainer like so:
+
+```python
+from sparseml.transformers.finetune.session_mixin import SessionManagerMixIn
+from trl import SFTTrainer as TRLSFTTrainer
+
+class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer):
+    ...
+```
+
+The new `SFTTrainer` class can now apply SparseML recipes and modifiers during
+supervised finetuning, with full support for all of the original TRL features. The full
+class is defined in [sft_trainer.py](sft_trainer.py) and requires very minimal
+additional code: just a dataset load override to support passing in tokenized datasets
+to the Trainer.
+
+### Examples
+
+[ex_trl_constant.py](ex_trl_constant.py): finetunes a 50% sparse Llama-2-7b model
+using TRL's dataset preprocessing. Sparsity is maintained throughout training by
+applying a `ConstantPruningModifier` recipe to the `SFTTrainer`.
+
+[ex_trl_distillation.py](ex_trl_distillation.py): finetunes a 50% sparse Llama-2-7b
+model using knowledge distillation from a dense Llama-2-7b model. Sparsity is
+maintained throughout training with a `ConstantPruningModifier`, and layer-wise
+knowledge distillation is handled by the `OutputDistillationModifier`.
\ No newline at end of file
diff --git a/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/ex_trl_constant.py b/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/ex_trl_constant.py
new file mode 100644
index 00000000000..bc87338e191
--- /dev/null
+++ b/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/ex_trl_constant.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
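+
+# Supervised finetuning of the 50% sparse neuralmagic/Llama-2-7b-pruned50-retrained
+# model on GSM8K, using TRL-style prompt formatting and a completion-only collator.
+# The ConstantPruningModifier recipe below keeps the pruning mask fixed, so the
+# model stays 50% sparse throughout training.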
+
+from datasets import load_dataset
+
+from sparseml.transformers import (
+    SFTTrainer,
+    SparseAutoModelForCausalLM,
+    SparseAutoTokenizer,
+    TrainingArguments,
+)
+from trl import DataCollatorForCompletionOnlyLM
+
+
+model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
+output_dir = "./output_trl_sft_test_7b_gsm8k_sft_data"
+model = SparseAutoModelForCausalLM.from_pretrained(
+    model_path, torch_dtype="auto", device_map="auto"
+)
+tokenizer = SparseAutoTokenizer.from_pretrained(model_path)
+tokenizer.pad_token = tokenizer.eos_token
+
+# recipe for maintaining model sparsity during finetuning
+recipe = """
+test_stage:
+  pruning_modifiers:
+    ConstantPruningModifier:
+      targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight',
+      're:.*o_proj.weight', 're:.*gate_proj.weight', 're:.*up_proj.weight',
+      're:.*down_proj.weight']
+      start: 0
+"""
+
+# Load gsm8k using TRL dataset tools
+dataset = load_dataset("gsm8k", "main", split="train")
+
+
+def formatting_prompts_func(example):
+    output_texts = []
+    for i in range(len(example["question"])):
+        text = f"Question: {example['question'][i]}\n Answer: {example['answer'][i]}"
+        output_texts.append(text)
+    return output_texts
+
+
+response_template = "Answer:"
+collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
+
+training_args = TrainingArguments(
+    output_dir=output_dir,
+    num_train_epochs=0.6,
+    logging_steps=50,
+    gradient_checkpointing=True,
+)
+
+trainer = SFTTrainer(
+    model=model,
+    tokenizer=tokenizer,
+    recipe=recipe,
+    train_dataset=dataset,
+    formatting_func=formatting_prompts_func,
+    data_collator=collator,
+    args=training_args,
+    max_seq_length=512,
+)
+trainer.train()
+trainer.save_model()
diff --git a/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/ex_trl_distillation.py b/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/ex_trl_distillation.py
new file mode 100644
index 00000000000..c5a2ac07c1b
--- /dev/null
+++ b/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/ex_trl_distillation.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
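+
+# Knowledge-distillation finetuning of the 50% sparse Llama-2-7b model on GSM8K with
+# a dense Llama-2-7b teacher. ConstantPruningModifier keeps the sparsity mask fixed,
+# and OutputDistillationModifier adds a layer-wise ("square_head") distillation loss
+# between the student and teacher decoder layers.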
+
+from transformers import DefaultDataCollator
+
+from sparseml.transformers import (
+    DataTrainingArguments,
+    SFTTrainer,
+    SparseAutoModelForCausalLM,
+    SparseAutoTokenizer,
+    TextGenerationDataset,
+    TrainingArguments,
+)
+
+
+model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
+teacher_path = "zoo:llama2-7b-gsm8k_llama2_pretrain-base"
+output_dir = "./output_trl_sft_test_7b_gsm8k"
+
+model = SparseAutoModelForCausalLM.from_pretrained(
+    model_path, torch_dtype="auto", device_map="auto"
+)
+teacher = SparseAutoModelForCausalLM.from_pretrained(
+    teacher_path, torch_dtype="auto", device_map="auto"
+)
+
+tokenizer = SparseAutoTokenizer.from_pretrained(model_path)
+
+# Load gsm8k using SparseML dataset tools
+data_args = DataTrainingArguments(
+    dataset="gsm8k", dataset_config_name="main", max_seq_length=512
+)
+dataset_manager = TextGenerationDataset.load_from_registry(
+    data_args.dataset,
+    data_args=data_args,
+    split="train",
+    tokenizer=tokenizer,
+)
+train_dataset = dataset_manager.tokenize_and_process()
+print(f"--> Training Set Length = {len(train_dataset)}")
+
+# recipe for maintaining model sparsity during finetuning
+recipe = """
+test_stage:
+  pruning_modifiers:
+    ConstantPruningModifier:
+      targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight',
+      're:.*o_proj.weight', 're:.*gate_proj.weight', 're:.*up_proj.weight',
+      're:.*down_proj.weight']
+      start: 0
+    OutputDistillationModifier:
+      targets: ['re:model.layers.\\d+$']
+      comparison: "square_head"
+      start: 0
+      orig_scale: 1.0
+      distill_scale: 1.0
+"""
+
+data_collator = DefaultDataCollator()
+training_args = TrainingArguments(
+    output_dir=output_dir,
+    num_train_epochs=0.6,
+    logging_steps=50,
+    gradient_checkpointing=True,
+    bf16=True,
+)
+trainer = SFTTrainer(
+    model=model,
+    teacher=teacher,
+    tokenizer=tokenizer,
+    recipe=recipe,
+    train_dataset=train_dataset,
+    data_collator=data_collator,
+    args=training_args,
+    data_args=data_args,
+    max_seq_length=data_args.max_seq_length,
+    packing=True,
+)
+trainer.train()
+trainer.save_model()
diff --git a/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/sft_trainer.py b/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/sft_trainer.py
new file mode 100644
index 00000000000..bc455b56aa1
--- /dev/null
+++ b/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/sft_trainer.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
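+
+# SessionManagerMixIn is listed first so its hooks take precedence in the MRO; the
+# only override needed is _prepare_dataset, which skips TRL's preprocessing when the
+# dataset already contains an "input_ids" column (i.e. it was tokenized by SparseML).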
+
+from sparseml.transformers.finetune.session_mixin import SessionManagerMixIn
+from trl import SFTTrainer as TRLSFTTrainer
+
+
+__all__ = ["SFTTrainer"]
+
+
+class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer):
+    def _prepare_dataset(self, dataset, *args, **kwargs):
+        if "input_ids" in dataset.column_names:
+            # dataset is already tokenized, skip preprocessing
+            return dataset
+
+        return super()._prepare_dataset(dataset, *args, **kwargs)
diff --git a/src/sparseml/transformers/finetune/README.md b/src/sparseml/transformers/finetune/README.md
index dc3a61e9ba3..aaee586d671 100644
--- a/src/sparseml/transformers/finetune/README.md
+++ b/src/sparseml/transformers/finetune/README.md
@@ -132,7 +132,7 @@ A recipe can be run stage-by-stage by setting `run_stages` to `True` or calling
 a `run_type` attribute set to either `oneshot` or `train` when running in sequential
 mode.
 
-See [example_alternating_recipe.yaml](example_alternating_recipe.yaml) for an example
+See [example_alternating_recipe.yaml](../../../../integrations/huggingface-transformers/tutorials/text-generation/example_alternating_recipe.yaml) for an example
 of a staged recipe for Llama.
 
 ### Python Example
diff --git a/src/sparseml/transformers/finetune/__init__.py b/src/sparseml/transformers/finetune/__init__.py
index 0995d54d97c..6329b5a4692 100644
--- a/src/sparseml/transformers/finetune/__init__.py
+++ b/src/sparseml/transformers/finetune/__init__.py
@@ -14,4 +14,9 @@
 
 # flake8: noqa
 
+from .data import DataTrainingArguments, TextGenerationDataset
+from .model_args import ModelArguments
+from .session_mixin import SessionManagerMixIn
 from .text_generation import apply, compress, eval, oneshot, train
+from .trainer import Trainer
+from .training_args import TrainingArguments
diff --git a/src/sparseml/transformers/finetune/data/__init__.py b/src/sparseml/transformers/finetune/data/__init__.py
index c9c7cb9e509..f6ac4bbe1c6 100644
--- a/src/sparseml/transformers/finetune/data/__init__.py
+++ b/src/sparseml/transformers/finetune/data/__init__.py
@@ -18,6 +18,7 @@
 from .c4 import C4Dataset
 from .cnn_dailymail import CNNDailyMailDataset
 from .custom import CustomDataset
+from .data_args import DataTrainingArguments
 from .evolcodealpaca import EvolCodeAlpacaDataset
 from .gsm8k import GSM8KDataset
 from .open_platypus import OpenPlatypusDataset
diff --git a/src/sparseml/transformers/finetune/data/base.py b/src/sparseml/transformers/finetune/data/base.py
index 6f34bc352d5..354dfcccbe1 100644
--- a/src/sparseml/transformers/finetune/data/base.py
+++ b/src/sparseml/transformers/finetune/data/base.py
@@ -111,7 +111,7 @@ def get_raw_dataset(self, cache_dir: Optional[str] = None) -> Dataset:
             **self.raw_kwargs,
         )
 
-    def tokenize_and_process(self, raw_dataset: Dataset) -> Dataset:
+    def tokenize_and_process(self, raw_dataset: Optional[Dataset] = None) -> Dataset:
         """
         Sets up the raw dataset for finetuning, performs tokenization, concatenates
         entries to max sequence length if desired, and adds labels to each entry
@@ -168,6 +168,9 @@ def label_fn(data):
             data["labels"][-padding:] = [LABELS_MASK_VALUE] * padding
             return data
 
+        if raw_dataset is None:
+            raw_dataset = self.get_raw_dataset()
+
         dataset = self.map(
             raw_dataset,
             function=tokenize_fn,
diff --git a/src/sparseml/transformers/finetune/session_mixin.py b/src/sparseml/transformers/finetune/session_mixin.py
index a696fc02a6c..80c0c6432b2 100644
--- a/src/sparseml/transformers/finetune/session_mixin.py
+++ b/src/sparseml/transformers/finetune/session_mixin.py
@@ -28,11 +28,7 @@
 import sparseml.core.session as session_manager
 from sparseml.core.framework import Framework
 from sparseml.core.session import callbacks
-from sparseml.pytorch.model_load.helpers import (
-    RECIPE_FILE_NAME,
-    get_session_model,
-    reload_model_state,
-)
+from sparseml.pytorch.model_load.helpers import RECIPE_FILE_NAME, get_session_model
 from sparseml.pytorch.utils import LoggerManager, ModuleSparsificationInfo
 from sparseml.transformers.finetune.callbacks import (
     DisableHalfPrecisionCallback,
@@ -49,6 +45,13 @@
 _LOGGER = logging.getLogger(__name__)
 
 TRAINER_STATE_NAME = "trainer_state.json"
+METADATA_ARGS = [
+    "per_device_train_batch_size",
+    "per_device_eval_batch_size",
+    "max_seq_length",
+    "save_safetensors",
+    "fp16",
+]
 
 
 class SessionManagerMixIn:
@@ -56,26 +59,20 @@ class SessionManagerMixIn:
     Mix-In class to extend the Hugging Face Trainer class to support SparseML recipes
     for one-shot and finetuning flows.
 
-    :param model_state_path: path to Pytorch model checkpoint or saved model
     :param recipe: path to recipe file to apply during training
     :param recipe_args: additional kwargs to use for evaluating recipe
-    :param metadata_args: additional kwargs for configuring training
     :param data_args: kwargs for configuring dataset loading
     :param teacher: optional teacher model to use for distillation
     """
 
     def __init__(
         self,
-        model_state_path: str,
         recipe: Optional[str] = None,
         recipe_args: Optional[Union[Dict[str, Any], str]] = None,
-        metadata_args: Optional[List[str]] = None,
         data_args: Optional["DataTrainingArguments"] = None,  # noqa: F821
         teacher: Optional[Union[Module, str]] = None,
         **kwargs,
     ):
-        # instantiate necessary state, like managers, so we can override args
-        self.model_state_path = str(model_state_path)
         self.recipe = recipe
         self.recipe_args = recipe_args
         self.teacher = teacher
@@ -84,11 +81,11 @@ def __init__(
         training_args = kwargs.get("args")
         self.metadata = (
             self._extract_metadata(
-                metadata_args=metadata_args,
+                metadata_args=METADATA_ARGS,
                 training_args_dict=training_args.to_dict(),
                 data_args_dict=asdict(data_args) if data_args else {},
             )
-            if training_args and metadata_args
+            if training_args and METADATA_ARGS
             else None
         )
 
@@ -98,6 +95,7 @@ def __init__(
 
         # call Trainer initialization
         super().__init__(**kwargs)
+        self.accelerator.wait_for_everyone()
 
         # setup callbacks and loss
         self.optim_callbacks = TrainingLoopCallbacks(self)
@@ -107,7 +105,7 @@ def __init__(
         self.criterion = torch.nn.CrossEntropyLoss()
 
         model_signature = inspect.signature(self.model.forward)
-        self._model_signature_columns = list(model_signature.parameters.keys())
+        self._signature_columns = list(model_signature.parameters.keys())
 
         if self.teacher is not None and teacher not in ("disable", "self"):
             teacher_signature = inspect.signature(self.teacher.forward)
@@ -115,6 +113,9 @@ def __init__(
         else:
             self._teacher_signature_columns = None
 
+        if self.is_fsdp_enabled:
+            self._prepare_model_for_fsdp()
+
     def initialize_session(
         self,
         epoch: float,
@@ -134,7 +135,6 @@ def initialize_session(
         if session.lifecycle.initialized_ or session.lifecycle.finalized:
             return False
 
-        orig_state_dict = self.model.state_dict()
         train_data = self.get_train_dataloader()
 
         self.accelerator.wait_for_everyone()
@@ -154,17 +154,7 @@ def initialize_session(
         )
         self.accelerator.wait_for_everyone()
         model = get_session_model()
-        self.model = model
-
-        # reload the state dict for the model now that architecture matches expected
-        # TODO: what if there is a quant modifier in the original recipe and we want to
-        # continue adjusting its zero point and range?
-        load_path = checkpoint or self.model_state_path
-        if reload_model_state(model, load_path, orig_state_dict):
-            _LOGGER.info(
-                "Reloaded model state after SparseML recipe structure modifications "
-                f"from {load_path}"
-            )
+        self.model_wrapped = self.model = model
 
         if self.recipe is None:
             _LOGGER.warning(
@@ -298,7 +288,7 @@ def compute_loss(
         self._check_super_defined("compute_loss")
 
         # TODO: do we need these model signature columns?
-        inputs = {k: inputs[k] for k in inputs if k in self._model_signature_columns}
+        inputs = {k: inputs[k] for k in inputs if k in self._signature_columns}
         loss = super().compute_loss(model, inputs, return_outputs=return_outputs)
 
         # take the mean across multiple GPUs
@@ -464,9 +454,8 @@ def save_model(
         )
 
         self.save_state()
-        self.tokenizer.save_pretrained(output_dir)
-        if not _is_oneshot:  # optimizer/scheduler not relevant to one-shot
-            self.save_optimizer_and_scheduler(output_dir)
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)
 
         if not self.recipe:
             return
@@ -508,7 +497,7 @@ def log_model_sparsification(self):
         sparsification_info = ModuleSparsificationInfo(self.model)
 
         _LOGGER.info(
-            f"Sparsification info for {self.model_state_path}: "
+            f"Sparsification info for {type(self.model).__name__}: "
             f"{sparsification_info.params_total} total params. "
         )
         _LOGGER.info(
diff --git a/src/sparseml/transformers/finetune/text_generation.py b/src/sparseml/transformers/finetune/text_generation.py
index a25778aa5fa..fbc1bcb146c 100644
--- a/src/sparseml/transformers/finetune/text_generation.py
+++ b/src/sparseml/transformers/finetune/text_generation.py
@@ -49,14 +49,6 @@
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
-metadata_args = [
-    "per_device_train_batch_size",
-    "per_device_eval_batch_size",
-    "max_seq_length",
-    "save_safetensors",
-    "fp16",
-]
-
 
 
 def train(**kwargs):
     """
@@ -134,10 +126,6 @@ def parse_args(**kwargs):
             arg_dict[key] = value
 
         training_args.recipe_args = arg_dict
-        # when set to true in FSDP mode this causes issues, the model arguments show up
-        # as *args and **kwargs so all columns get removed
-        training_args.remove_unused_columns = False
-
     return model_args, data_args, training_args
 
 
@@ -331,9 +319,7 @@ def main(
     trainer = Trainer(
         model_init=get_session_model,
         teacher=teacher,
-        model_state_path=model_path,
         recipe=training_args.recipe,
-        metadata_args=metadata_args,
         recipe_args=training_args.recipe_args,
         args=training_args,
         data_args=data_args,
@@ -342,8 +328,6 @@ def main(
         tokenizer=tokenizer,
         data_collator=data_collator,
     )
-    if trainer.is_fsdp_enabled:
-        trainer._prepare_model_for_fsdp()
     stage_runner.trainer = trainer
 
     # alternating Training/One-shot
diff --git a/src/sparseml/transformers/finetune/trainer.py b/src/sparseml/transformers/finetune/trainer.py
index 36a850f251b..d918f50880b 100644
--- a/src/sparseml/transformers/finetune/trainer.py
+++ b/src/sparseml/transformers/finetune/trainer.py
@@ -12,102 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
-import warnings
-from typing import Any, Callable, Dict, Optional, Union
-
-import torch
-from torch.nn import Module
 from transformers import Trainer as HFTransformersTrainer
-from transformers.trainer_pt_utils import reissue_pt_warnings
 
 from sparseml.transformers.finetune.session_mixin import SessionManagerMixIn
 
 
 __all__ = ["Trainer"]
 
-TRAINER_STATE_NAME = "trainer_state.json"
-OPTIMIZER_NAME = "optimizer.pt"
-SCHEDULER_NAME = "scheduler.pt"
-SCALER_NAME = "scaler.pt"
-
 
 class Trainer(SessionManagerMixIn, HFTransformersTrainer):
-    """
-    Training implementation for running sparsification recipes with HF Trainer.
-
-    :param model: the model to use with the trainer and apply sparsification to
-    :param model_state_path: the state path to the model,
-        used to load config and tokenizer settings
-    :param recipe: the recipe, if any, to apply to the modle and training
-        process
-    :param recipe_args: A json string, csv key=value string, or dictionary containing
-        arguments to override the root arguments within the recipe such as
-        learning rate or num epochs
-    :param teacher: teacher model for distillation. Set to 'self' to distill
-        from the loaded model or 'disable' to turn of distillation
-    :param kwargs: key word arguments passed to the parent class
-    """
-
-    def __init__(
-        self,
-        model_state_path: str,
-        model: Optional[Module] = None,
-        model_init: Optional[Callable] = None,
-        recipe: Optional[str] = None,
-        recipe_args: Optional[Union[Dict[str, Any], str]] = None,
-        teacher: Optional[Union[Module, str]] = None,
-        **kwargs,
-    ):
-        super().__init__(
-            model=model,
-            model_init=model_init,
-            model_state_path=model_state_path,
-            recipe=recipe,
-            recipe_args=recipe_args,
-            teacher=teacher,
-            **kwargs,
-        )
-
-    def save_optimizer_and_scheduler(self, output_dir: Optional[str] = None):
-        """
-        Save optimizer, scheduler and scaler
-
-        :param output_dir: The output model directory to save the above
-        """
-        if output_dir is None:
-            output_dir = self.args.output_dir
-
-        if self.is_world_process_zero():
-            if self.optimizer is not None:
-                torch.save(
-                    self.optimizer.state_dict(),
-                    os.path.join(output_dir, "optimizer.pt"),
-                )
-            with warnings.catch_warnings(record=True) as caught_warnings:
-                if self.lr_scheduler is not None:
-                    torch.save(
-                        self.lr_scheduler.state_dict(),
-                        os.path.join(output_dir, "scheduler.pt"),
-                    )
-            reissue_pt_warnings(caught_warnings)
-
-    def _save_checkpoint(self, model, trial, metrics=None):
-        # Call into the save checkpoint by HF Transformers, which saves the
-        # best metric if required
-        super()._save_checkpoint(model, trial, metrics=metrics)
-        if (
-            self.args.metric_for_best_model is None
-            or self.args.best_model_after_epoch is None
-        ):
-            return
-
-        if self.state.epoch <= self.args.best_model_after_epoch:
-            self.state.best_metric = None
-            self.state.best_model_checkpoint = None
-
-    def _dummy_lr_scheduler(self):
-        return torch.optim.lr_scheduler.MultiplicativeLR(
-            self.optimizer,
-            lambda _: 1.0,
-        )
+    pass
diff --git a/src/sparseml/transformers/finetune/training_args.py b/src/sparseml/transformers/finetune/training_args.py
index 083fb5c5e2b..49e4120572c 100644
--- a/src/sparseml/transformers/finetune/training_args.py
+++ b/src/sparseml/transformers/finetune/training_args.py
@@ -32,10 +32,6 @@ class TrainingArguments(HFTrainingArgs):
     arguments
     """
 
-    best_model_after_epoch: int = field(
-        default=None,
-        metadata={"help": "Epoch after which best model will be saved."},
-    )
     recipe: Optional[str] = field(
         default=None,
         metadata={
diff --git a/tests/sparseml/transformers/finetune/test_session_mixin.py b/tests/sparseml/transformers/finetune/test_session_mixin.py
index 24af397920b..cc74de299e0 100644
--- a/tests/sparseml/transformers/finetune/test_session_mixin.py
+++ b/tests/sparseml/transformers/finetune/test_session_mixin.py
@@ -27,7 +27,6 @@ class MixInTest(SessionManagerMixIn, Trainer):
     def __init__(
         self,
         model: Module,
-        model_state_path: str,
         recipe: Optional[str],
         recipe_args: Optional[Union[Dict[str, Any], str]] = None,
         teacher: Optional[Union[Module, str]] = None,
@@ -35,7 +34,6 @@ def __init__(
     ):
         super().__init__(
             model=model,
-            model_state_path=model_state_path,
             recipe=recipe,
             recipe_args=recipe_args,
             teacher=teacher,
@@ -48,9 +46,7 @@ def test_mixin_init():
     model = AutoModelForCausalLM.from_pretrained(model_state_path)
     recipe = "tests/sparseml/transformers/finetune/test_quantization.yaml"
 
-    session_mixin = MixInTest(
-        model=model, model_state_path=model_state_path, recipe=recipe
-    )
+    session_mixin = MixInTest(model=model, recipe=recipe)
     assert isinstance(session_mixin, SessionManagerMixIn)
     assert isinstance(session_mixin, Trainer)
     assert session_mixin.recipe == recipe
@@ -67,7 +63,6 @@ def mixin_trainer():
 
     return MixInTest(
         model=model,
-        model_state_path=model_state_path,
         recipe=recipe,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,