From f013c2c3033e4d562f6311b2496b51df53ab4697 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 5 Mar 2024 21:52:19 +0000 Subject: [PATCH 01/17] WIP sft mixin --- .../transformers/finetune/sft_trainer.py | 99 +++++++++++++++++++ .../transformers/finetune/text_generation.py | 15 ++- .../transformers/sparsification/trainer.py | 6 +- 3 files changed, 116 insertions(+), 4 deletions(-) create mode 100644 src/sparseml/transformers/finetune/sft_trainer.py diff --git a/src/sparseml/transformers/finetune/sft_trainer.py b/src/sparseml/transformers/finetune/sft_trainer.py new file mode 100644 index 00000000000..475540311c7 --- /dev/null +++ b/src/sparseml/transformers/finetune/sft_trainer.py @@ -0,0 +1,99 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import warnings +from typing import Any, Callable, Dict, Optional, Union + +import torch +from torch.nn import Module +from trl import SFTTrainer as TRLSFTTrainer +from peft import PeftConfig +from transformers.trainer_pt_utils import reissue_pt_warnings + +from sparseml.transformers.finetune.session_mixin import SessionManagerMixIn + + +__all__ = ["SFTTrainer"] + +TRAINER_STATE_NAME = "trainer_state.json" +OPTIMIZER_NAME = "optimizer.pt" +SCHEDULER_NAME = "scheduler.pt" +SCALER_NAME = "scaler.pt" + + +class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer): + """ + Training implementation for running sparsification recipes with HF Trainer. + + :param model: the model to use with the trainer and apply sparsification to + :param model_state_path: the state path to the model, + used to load config and tokenizer settings + :param recipe: the recipe, if any, to apply to the modle and training + process + :param recipe_args: A json string, csv key=value string, or dictionary containing + arguments to override the root arguments within the recipe such as + learning rate or num epochs + :param teacher: teacher model for distillation. 
Set to 'self' to distill + from the loaded model or 'disable' to turn of distillation + :param kwargs: key word arguments passed to the parent class + """ + + def __init__( + self, + model_state_path: str, + model: Optional[Module] = None, + model_init: Optional[Callable] = None, + recipe: Optional[str] = None, + recipe_args: Optional[Union[Dict[str, Any], str]] = None, + teacher: Optional[Union[Module, str]] = None, + peft_config: Optional[PeftConfig] = None, + **kwargs, + ): + super().__init__( + model=model, + model_init=model_init, + model_state_path=model_state_path, + recipe=recipe, + recipe_args=recipe_args, + teacher=teacher, + **kwargs, + ) + + def save_optimizer_and_scheduler(self, output_dir: Optional[str] = None): + """ + Save optimizer, scheduler and scaler + + :param output_dir: The output model directory to save the above + """ + if output_dir is None: + output_dir = self.args.output_dir + + if self.is_world_process_zero(): + if self.optimizer is not None: + torch.save( + self.optimizer.state_dict(), + os.path.join(output_dir, "optimizer.pt"), + ) + with warnings.catch_warnings(record=True) as caught_warnings: + if self.lr_scheduler is not None: + torch.save( + self.lr_scheduler.state_dict(), + os.path.join(output_dir, "scheduler.pt"), + ) + reissue_pt_warnings(caught_warnings) + if self.use_cuda_amp: + torch.save( + self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt") + ) \ No newline at end of file diff --git a/src/sparseml/transformers/finetune/text_generation.py b/src/sparseml/transformers/finetune/text_generation.py index 8a8bd1d12b7..e084cd24f7f 100644 --- a/src/sparseml/transformers/finetune/text_generation.py +++ b/src/sparseml/transformers/finetune/text_generation.py @@ -39,6 +39,7 @@ from sparseml.transformers.finetune.model_args import ModelArguments from sparseml.transformers.finetune.runner import StageRunner from sparseml.transformers.finetune.trainer import Trainer +from sparseml.transformers.finetune.sft_trainer import SFTTrainer from sparseml.transformers.finetune.training_args import TrainingArguments from sparseml.transformers.utils import SparseAutoModel, get_shared_tokenizer_src from sparseml.transformers.utils.helpers import detect_last_checkpoint @@ -331,7 +332,17 @@ def main( # Initialize our Trainer data_collator = DefaultDataCollator() - trainer = Trainer( + + from peft import LoraConfig + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + trainer = SFTTrainer( model_init=get_session_model, teacher=teacher, model_state_path=model_path, @@ -344,6 +355,8 @@ def main( eval_dataset=eval_dataset, tokenizer=tokenizer, data_collator=data_collator, + peft_config=lora_config, + dataset_text_field="text" ) if trainer.is_fsdp_enabled: trainer._prepare_model_for_fsdp() diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py index bc45bec6d97..61393d31deb 100644 --- a/src/sparseml/transformers/sparsification/trainer.py +++ b/src/sparseml/transformers/sparsification/trainer.py @@ -35,7 +35,7 @@ from transformers.integrations import TensorBoardCallback from transformers.trainer_callback import TrainerState from transformers.trainer_pt_utils import reissue_pt_warnings -from transformers.trainer_utils import ShardedDDPOption, get_last_checkpoint +#from transformers.trainer_utils import ShardedDDPOption, get_last_checkpoint from sparseml.pytorch.model_load.helpers import log_model_load from sparseml.pytorch.optim import 
ScheduledModifierManager, ScheduledOptimizer @@ -894,8 +894,8 @@ def save_optimizer_and_scheduler(self, output_dir: Optional[str] = None): if output_dir is None: output_dir = self.args.output_dir - if self.sharded_ddp == ShardedDDPOption.SIMPLE and self.optimizer is not None: - self.optimizer.consolidate_state_dict() + #if self.sharded_ddp == ShardedDDPOption.SIMPLE and self.optimizer is not None: + # self.optimizer.consolidate_state_dict() if self.is_world_process_zero(): if self.optimizer is not None: From 9986f34cf04bd931a22c19ad9f3b75ebb9b64227 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 5 Mar 2024 22:00:17 +0000 Subject: [PATCH 02/17] its running at least --- .../transformers/finetune/sft_trainer.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/sparseml/transformers/finetune/sft_trainer.py b/src/sparseml/transformers/finetune/sft_trainer.py index 475540311c7..5e6626f7362 100644 --- a/src/sparseml/transformers/finetune/sft_trainer.py +++ b/src/sparseml/transformers/finetune/sft_trainer.py @@ -96,4 +96,20 @@ def save_optimizer_and_scheduler(self, output_dir: Optional[str] = None): if self.use_cuda_amp: torch.save( self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt") - ) \ No newline at end of file + ) + + def _prepare_dataset( + self, + dataset, + tokenizer, + packing, + dataset_text_field, + max_seq_length, + formatting_func, + num_of_sequences, + chars_per_token, + remove_unused_columns=True, + append_concat_token=True, + add_special_tokens=True, + ): + return dataset \ No newline at end of file From 51dd1091037e1aed5754b42d813da4cb564591a5 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 5 Mar 2024 22:05:20 +0000 Subject: [PATCH 03/17] clean up --- src/sparseml/transformers/finetune/sft_trainer.py | 7 ++++--- src/sparseml/transformers/finetune/text_generation.py | 6 ++++-- src/sparseml/transformers/sparsification/trainer.py | 4 ---- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/sparseml/transformers/finetune/sft_trainer.py b/src/sparseml/transformers/finetune/sft_trainer.py index 5e6626f7362..0f14cfa38ae 100644 --- a/src/sparseml/transformers/finetune/sft_trainer.py +++ b/src/sparseml/transformers/finetune/sft_trainer.py @@ -18,11 +18,11 @@ import torch from torch.nn import Module -from trl import SFTTrainer as TRLSFTTrainer -from peft import PeftConfig from transformers.trainer_pt_utils import reissue_pt_warnings +from peft import PeftConfig from sparseml.transformers.finetune.session_mixin import SessionManagerMixIn +from trl import SFTTrainer as TRLSFTTrainer __all__ = ["SFTTrainer"] @@ -68,6 +68,7 @@ def __init__( recipe=recipe, recipe_args=recipe_args, teacher=teacher, + peft_config=peft_config, **kwargs, ) @@ -112,4 +113,4 @@ def _prepare_dataset( append_concat_token=True, add_special_tokens=True, ): - return dataset \ No newline at end of file + return dataset diff --git a/src/sparseml/transformers/finetune/text_generation.py b/src/sparseml/transformers/finetune/text_generation.py index e084cd24f7f..74d10dce077 100644 --- a/src/sparseml/transformers/finetune/text_generation.py +++ b/src/sparseml/transformers/finetune/text_generation.py @@ -38,8 +38,9 @@ from sparseml.transformers.finetune.data.data_args import DataTrainingArguments from sparseml.transformers.finetune.model_args import ModelArguments from sparseml.transformers.finetune.runner import StageRunner -from sparseml.transformers.finetune.trainer import Trainer from sparseml.transformers.finetune.sft_trainer import SFTTrainer + +# 
from sparseml.transformers.finetune.trainer import Trainer from sparseml.transformers.finetune.training_args import TrainingArguments from sparseml.transformers.utils import SparseAutoModel, get_shared_tokenizer_src from sparseml.transformers.utils.helpers import detect_last_checkpoint @@ -334,6 +335,7 @@ def main( data_collator = DefaultDataCollator() from peft import LoraConfig + lora_config = LoraConfig( r=16, lora_alpha=32, @@ -356,7 +358,7 @@ def main( tokenizer=tokenizer, data_collator=data_collator, peft_config=lora_config, - dataset_text_field="text" + dataset_text_field="text", ) if trainer.is_fsdp_enabled: trainer._prepare_model_for_fsdp() diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py index 61393d31deb..69c8ac5616b 100644 --- a/src/sparseml/transformers/sparsification/trainer.py +++ b/src/sparseml/transformers/sparsification/trainer.py @@ -35,7 +35,6 @@ from transformers.integrations import TensorBoardCallback from transformers.trainer_callback import TrainerState from transformers.trainer_pt_utils import reissue_pt_warnings -#from transformers.trainer_utils import ShardedDDPOption, get_last_checkpoint from sparseml.pytorch.model_load.helpers import log_model_load from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer @@ -894,9 +893,6 @@ def save_optimizer_and_scheduler(self, output_dir: Optional[str] = None): if output_dir is None: output_dir = self.args.output_dir - #if self.sharded_ddp == ShardedDDPOption.SIMPLE and self.optimizer is not None: - # self.optimizer.consolidate_state_dict() - if self.is_world_process_zero(): if self.optimizer is not None: torch.save( From 086c2fffcf59a8eb64920392534fc1bc276f8549 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Mon, 1 Apr 2024 14:06:03 +0000 Subject: [PATCH 04/17] revert debugging changes --- .../transformers/finetune/text_generation.py | 19 ++----------------- .../transformers/sparsification/trainer.py | 4 ++++ 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/src/sparseml/transformers/finetune/text_generation.py b/src/sparseml/transformers/finetune/text_generation.py index fa60c096bfa..6005c26f034 100644 --- a/src/sparseml/transformers/finetune/text_generation.py +++ b/src/sparseml/transformers/finetune/text_generation.py @@ -38,9 +38,7 @@ from sparseml.transformers.finetune.data.data_args import DataTrainingArguments from sparseml.transformers.finetune.model_args import ModelArguments from sparseml.transformers.finetune.runner import StageRunner -from sparseml.transformers.finetune.sft_trainer import SFTTrainer - -# from sparseml.transformers.finetune.trainer import Trainer +from sparseml.transformers.finetune.trainer import Trainer from sparseml.transformers.finetune.training_args import TrainingArguments from sparseml.transformers.sparsification.sparse_model import ( SparseAutoModel, @@ -333,18 +331,7 @@ def main( # Initialize our Trainer data_collator = DefaultDataCollator() - - from peft import LoraConfig - - lora_config = LoraConfig( - r=16, - lora_alpha=32, - lora_dropout=0.05, - bias="none", - task_type="CAUSAL_LM", - ) - - trainer = SFTTrainer( + trainer = Trainer( model_init=get_session_model, teacher=teacher, model_state_path=model_path, @@ -357,8 +344,6 @@ def main( eval_dataset=eval_dataset, tokenizer=tokenizer, data_collator=data_collator, - peft_config=lora_config, - dataset_text_field="text", ) if trainer.is_fsdp_enabled: trainer._prepare_model_for_fsdp() diff --git 
a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py index 69c8ac5616b..bc45bec6d97 100644 --- a/src/sparseml/transformers/sparsification/trainer.py +++ b/src/sparseml/transformers/sparsification/trainer.py @@ -35,6 +35,7 @@ from transformers.integrations import TensorBoardCallback from transformers.trainer_callback import TrainerState from transformers.trainer_pt_utils import reissue_pt_warnings +from transformers.trainer_utils import ShardedDDPOption, get_last_checkpoint from sparseml.pytorch.model_load.helpers import log_model_load from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer @@ -893,6 +894,9 @@ def save_optimizer_and_scheduler(self, output_dir: Optional[str] = None): if output_dir is None: output_dir = self.args.output_dir + if self.sharded_ddp == ShardedDDPOption.SIMPLE and self.optimizer is not None: + self.optimizer.consolidate_state_dict() + if self.is_world_process_zero(): if self.optimizer is not None: torch.save( From 64096c71fee010650f44e5f788cdd94b404f8d47 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Mon, 1 Apr 2024 14:20:21 +0000 Subject: [PATCH 05/17] example script --- test_trl_trainer.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 test_trl_trainer.py diff --git a/test_trl_trainer.py b/test_trl_trainer.py new file mode 100644 index 00000000000..4f66ee5bf83 --- /dev/null +++ b/test_trl_trainer.py @@ -0,0 +1,44 @@ +from sparseml.transformers import SparseAutoModelForCausalLM, SparseAutoTokenizer +from sparseml.transformers.finetune.sft_trainer import SFTTrainer +from transformers import DefaultDataCollator +from sparseml.transformers.finetune.data.data_args import DataTrainingArguments +from sparseml.transformers.finetune.data import TextGenerationDataset +from peft import LoraConfig + +model_path = "facebook/opt-350m" +output_dir = "./output_trl_sft_test" + +model = SparseAutoModelForCausalLM.from_pretrained(model_path) +tokenizer = SparseAutoTokenizer.from_pretrained(model_path) + +data_args = DataTrainingArguments(dataset = "open_platypus") +dataset_manager = TextGenerationDataset.load_from_registry( + data_args.dataset, + data_args=data_args, + split="train", + tokenizer=tokenizer, +) +raw_dataset = dataset_manager.get_raw_dataset() +train_dataset = dataset_manager.tokenize_and_process(raw_dataset) +print(f"--> Training Set Length = {len(train_dataset)}") + + +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +data_collator = DefaultDataCollator() +trainer = SFTTrainer( + model=model, + model_state_path=model_path, + train_dataset=train_dataset, + tokenizer=tokenizer, + data_collator=data_collator, + peft_config=lora_config, + dataset_text_field="text" +) +trainer.train() \ No newline at end of file From e243b46c875a5d2f732595beaf5eb48e75efa5eb Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 2 Apr 2024 16:51:15 +0000 Subject: [PATCH 06/17] POC SFT sparse trainer --- .../transformers/finetune/__init__.py | 6 ++ .../transformers/finetune/data/__init__.py | 1 + .../transformers/finetune/data/base.py | 5 +- .../transformers/finetune/session_mixin.py | 28 ++---- .../transformers/finetune/sft_trainer.py | 97 +------------------ .../transformers/finetune/text_generation.py | 1 - src/sparseml/transformers/finetune/trainer.py | 53 ---------- .../transformers/finetune/training_args.py | 4 - test_trl_trainer.py | 52 ++++++---- 9 files changed, 54 insertions(+), 
193 deletions(-) diff --git a/src/sparseml/transformers/finetune/__init__.py b/src/sparseml/transformers/finetune/__init__.py index 0995d54d97c..7fbbbcccab4 100644 --- a/src/sparseml/transformers/finetune/__init__.py +++ b/src/sparseml/transformers/finetune/__init__.py @@ -14,4 +14,10 @@ # flake8: noqa +from .data import DataTrainingArguments, TextGenerationDataset +from .model_args import ModelArguments +from .session_mixin import SessionManagerMixIn +from .sft_trainer import SFTTrainer from .text_generation import apply, compress, eval, oneshot, train +from .trainer import Trainer +from .training_args import TrainingArguments diff --git a/src/sparseml/transformers/finetune/data/__init__.py b/src/sparseml/transformers/finetune/data/__init__.py index c9c7cb9e509..f6ac4bbe1c6 100644 --- a/src/sparseml/transformers/finetune/data/__init__.py +++ b/src/sparseml/transformers/finetune/data/__init__.py @@ -18,6 +18,7 @@ from .c4 import C4Dataset from .cnn_dailymail import CNNDailyMailDataset from .custom import CustomDataset +from .data_args import DataTrainingArguments from .evolcodealpaca import EvolCodeAlpacaDataset from .gsm8k import GSM8KDataset from .open_platypus import OpenPlatypusDataset diff --git a/src/sparseml/transformers/finetune/data/base.py b/src/sparseml/transformers/finetune/data/base.py index 6f34bc352d5..354dfcccbe1 100644 --- a/src/sparseml/transformers/finetune/data/base.py +++ b/src/sparseml/transformers/finetune/data/base.py @@ -111,7 +111,7 @@ def get_raw_dataset(self, cache_dir: Optional[str] = None) -> Dataset: **self.raw_kwargs, ) - def tokenize_and_process(self, raw_dataset: Dataset) -> Dataset: + def tokenize_and_process(self, raw_dataset: Optional[Dataset] = None) -> Dataset: """ Sets up the raw dataset for finetuning, performs tokenization, concatenates entries to max sequence length if desired, and adds labels to each entry @@ -168,6 +168,9 @@ def label_fn(data): data["labels"][-padding:] = [LABELS_MASK_VALUE] * padding return data + if raw_dataset is None: + raw_dataset = self.get_raw_dataset() + dataset = self.map( raw_dataset, function=tokenize_fn, diff --git a/src/sparseml/transformers/finetune/session_mixin.py b/src/sparseml/transformers/finetune/session_mixin.py index 72d18d98a9b..54585ff8045 100644 --- a/src/sparseml/transformers/finetune/session_mixin.py +++ b/src/sparseml/transformers/finetune/session_mixin.py @@ -30,8 +30,7 @@ from sparseml.core.session import callbacks from sparseml.pytorch.model_load.helpers import ( RECIPE_FILE_NAME, - get_session_model, - reload_model_state, + get_session_model ) from sparseml.pytorch.utils import LoggerManager, ModuleSparsificationInfo from sparseml.transformers.finetune.callbacks import ( @@ -56,7 +55,6 @@ class SessionManagerMixIn: Mix-In class to extend the Hugging Face Trainer class to support SparseML recipes for one-shot and finetuning flows. 
- :param model_state_path: path to Pytorch model checkpoint or saved model :param recipe: path to recipe file to apply during training :param recipe_args: additional kwargs to use for evaluating recipe :param metadata_args: additional kwargs for configuring training @@ -66,7 +64,6 @@ class SessionManagerMixIn: def __init__( self, - model_state_path: str, recipe: Optional[str] = None, recipe_args: Optional[Union[Dict[str, Any], str]] = None, metadata_args: Optional[List[str]] = None, @@ -74,8 +71,6 @@ def __init__( teacher: Optional[Union[Module, str]] = None, **kwargs, ): - # instantiate necessary state, like managers, so we can override args - self.model_state_path = str(model_state_path) self.recipe = recipe self.recipe_args = recipe_args self.teacher = teacher @@ -134,7 +129,6 @@ def initialize_session( if session.lifecycle.initialized_ or session.lifecycle.finalized: return False - orig_state_dict = self.model.state_dict() train_data = self.get_train_dataloader() self.accelerator.wait_for_everyone() @@ -156,16 +150,6 @@ def initialize_session( model = get_session_model() self.model = model - # reload the state dict for the model now that architecture matches expected - # TODO: what if there is a quant modifier in the original recipe and we want to - # continue adjusting its zero point and range? - load_path = checkpoint or self.model_state_path - if reload_model_state(model, load_path, orig_state_dict): - _LOGGER.info( - "Reloaded model state after SparseML recipe structure modifications " - f"from {load_path}" - ) - if self.recipe is None: _LOGGER.warning( "No training recipe was provided, finetuning will be run " @@ -185,6 +169,9 @@ def initialize_structure(self, stage: Optional[str] = None): session = session_manager.active_session() if session.lifecycle.initialized_: return False + + if isinstance(self.model, str): + self.model_path_or_stub = self.model session_manager.pre_initialize_structure( model=self.model, @@ -479,9 +466,8 @@ def save_model( ) self.save_state() - self.tokenizer.save_pretrained(output_dir) - if not _is_oneshot: # optimizer/scheduler not relevant to one-shot - self.save_optimizer_and_scheduler(output_dir) + if self.tokenizer is not None: + self.tokenizer.save_pretrained(output_dir) if not self.recipe: return @@ -504,7 +490,7 @@ def log_model_sparsification(self): sparsification_info = ModuleSparsificationInfo(self.model) _LOGGER.info( - f"Sparsification info for {self.model_state_path}: " + f"Sparsification info for {str(type(self.model))}: " f"{sparsification_info.params_total} total params. " f"Of those there are {sparsification_info.params_prunable_total} prunable " f"params which have {sparsification_info.params_prunable_sparse_percent} " diff --git a/src/sparseml/transformers/finetune/sft_trainer.py b/src/sparseml/transformers/finetune/sft_trainer.py index 0f14cfa38ae..e7fca11a211 100644 --- a/src/sparseml/transformers/finetune/sft_trainer.py +++ b/src/sparseml/transformers/finetune/sft_trainer.py @@ -12,105 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os -import warnings -from typing import Any, Callable, Dict, Optional, Union +from datasets import Dataset, IterableDataset -import torch -from torch.nn import Module -from transformers.trainer_pt_utils import reissue_pt_warnings - -from peft import PeftConfig from sparseml.transformers.finetune.session_mixin import SessionManagerMixIn from trl import SFTTrainer as TRLSFTTrainer __all__ = ["SFTTrainer"] -TRAINER_STATE_NAME = "trainer_state.json" -OPTIMIZER_NAME = "optimizer.pt" -SCHEDULER_NAME = "scheduler.pt" -SCALER_NAME = "scaler.pt" - class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer): - """ - Training implementation for running sparsification recipes with HF Trainer. - - :param model: the model to use with the trainer and apply sparsification to - :param model_state_path: the state path to the model, - used to load config and tokenizer settings - :param recipe: the recipe, if any, to apply to the modle and training - process - :param recipe_args: A json string, csv key=value string, or dictionary containing - arguments to override the root arguments within the recipe such as - learning rate or num epochs - :param teacher: teacher model for distillation. Set to 'self' to distill - from the loaded model or 'disable' to turn of distillation - :param kwargs: key word arguments passed to the parent class - """ - - def __init__( - self, - model_state_path: str, - model: Optional[Module] = None, - model_init: Optional[Callable] = None, - recipe: Optional[str] = None, - recipe_args: Optional[Union[Dict[str, Any], str]] = None, - teacher: Optional[Union[Module, str]] = None, - peft_config: Optional[PeftConfig] = None, - **kwargs, - ): - super().__init__( - model=model, - model_init=model_init, - model_state_path=model_state_path, - recipe=recipe, - recipe_args=recipe_args, - teacher=teacher, - peft_config=peft_config, - **kwargs, - ) - - def save_optimizer_and_scheduler(self, output_dir: Optional[str] = None): - """ - Save optimizer, scheduler and scaler - - :param output_dir: The output model directory to save the above - """ - if output_dir is None: - output_dir = self.args.output_dir - - if self.is_world_process_zero(): - if self.optimizer is not None: - torch.save( - self.optimizer.state_dict(), - os.path.join(output_dir, "optimizer.pt"), - ) - with warnings.catch_warnings(record=True) as caught_warnings: - if self.lr_scheduler is not None: - torch.save( - self.lr_scheduler.state_dict(), - os.path.join(output_dir, "scheduler.pt"), - ) - reissue_pt_warnings(caught_warnings) - if self.use_cuda_amp: - torch.save( - self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt") - ) + def _prepare_dataset(self, dataset, *args, **kwargs): + if isinstance(dataset, Dataset) or isinstance(dataset, IterableDataset): + return dataset - def _prepare_dataset( - self, - dataset, - tokenizer, - packing, - dataset_text_field, - max_seq_length, - formatting_func, - num_of_sequences, - chars_per_token, - remove_unused_columns=True, - append_concat_token=True, - add_special_tokens=True, - ): - return dataset + return super()._prepare_dataset(dataset, *args, **kwargs) diff --git a/src/sparseml/transformers/finetune/text_generation.py b/src/sparseml/transformers/finetune/text_generation.py index 6005c26f034..ebc11bffa8d 100644 --- a/src/sparseml/transformers/finetune/text_generation.py +++ b/src/sparseml/transformers/finetune/text_generation.py @@ -334,7 +334,6 @@ def main( trainer = Trainer( model_init=get_session_model, teacher=teacher, - model_state_path=model_path, recipe=training_args.recipe, 
metadata_args=metadata_args, recipe_args=training_args.recipe_args, diff --git a/src/sparseml/transformers/finetune/trainer.py b/src/sparseml/transformers/finetune/trainer.py index cf920e1feb6..94c9b51f94d 100644 --- a/src/sparseml/transformers/finetune/trainer.py +++ b/src/sparseml/transformers/finetune/trainer.py @@ -12,33 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import warnings from typing import Any, Callable, Dict, Optional, Union import torch from torch.nn import Module from transformers import Trainer as HFTransformersTrainer -from transformers.trainer_pt_utils import reissue_pt_warnings from sparseml.transformers.finetune.session_mixin import SessionManagerMixIn __all__ = ["Trainer"] -TRAINER_STATE_NAME = "trainer_state.json" -OPTIMIZER_NAME = "optimizer.pt" -SCHEDULER_NAME = "scheduler.pt" -SCALER_NAME = "scaler.pt" - class Trainer(SessionManagerMixIn, HFTransformersTrainer): """ Training implementation for running sparsification recipes with HF Trainer. :param model: the model to use with the trainer and apply sparsification to - :param model_state_path: the state path to the model, - used to load config and tokenizer settings :param recipe: the recipe, if any, to apply to the modle and training process :param recipe_args: A json string, csv key=value string, or dictionary containing @@ -51,7 +41,6 @@ class Trainer(SessionManagerMixIn, HFTransformersTrainer): def __init__( self, - model_state_path: str, model: Optional[Module] = None, model_init: Optional[Callable] = None, recipe: Optional[str] = None, @@ -62,54 +51,12 @@ def __init__( super().__init__( model=model, model_init=model_init, - model_state_path=model_state_path, recipe=recipe, recipe_args=recipe_args, teacher=teacher, **kwargs, ) - def save_optimizer_and_scheduler(self, output_dir: Optional[str] = None): - """ - Save optimizer, scheduler and scaler - - :param output_dir: The output model directory to save the above - """ - if output_dir is None: - output_dir = self.args.output_dir - - if self.is_world_process_zero(): - if self.optimizer is not None: - torch.save( - self.optimizer.state_dict(), - os.path.join(output_dir, "optimizer.pt"), - ) - with warnings.catch_warnings(record=True) as caught_warnings: - if self.lr_scheduler is not None: - torch.save( - self.lr_scheduler.state_dict(), - os.path.join(output_dir, "scheduler.pt"), - ) - reissue_pt_warnings(caught_warnings) - if self.use_cuda_amp: - torch.save( - self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt") - ) - - def _save_checkpoint(self, model, trial, metrics=None): - # Call into the save checkpoint by HF Transformers, which saves the - # best metric if required - super()._save_checkpoint(model, trial, metrics=metrics) - if ( - self.args.metric_for_best_model is None - or self.args.best_model_after_epoch is None - ): - return - - if self.state.epoch <= self.args.best_model_after_epoch: - self.state.best_metric = None - self.state.best_model_checkpoint = None - def _dummy_lr_scheduler(self): return torch.optim.lr_scheduler.MultiplicativeLR( self.optimizer, diff --git a/src/sparseml/transformers/finetune/training_args.py b/src/sparseml/transformers/finetune/training_args.py index 083fb5c5e2b..49e4120572c 100644 --- a/src/sparseml/transformers/finetune/training_args.py +++ b/src/sparseml/transformers/finetune/training_args.py @@ -32,10 +32,6 @@ class TrainingArguments(HFTrainingArgs): arguments """ - best_model_after_epoch: int = field( - default=None, - 
metadata={"help": "Epoch after which best model will be saved."}, - ) recipe: Optional[str] = field( default=None, metadata={ diff --git a/test_trl_trainer.py b/test_trl_trainer.py index 4f66ee5bf83..c96c7a4bf59 100644 --- a/test_trl_trainer.py +++ b/test_trl_trainer.py @@ -1,14 +1,20 @@ -from sparseml.transformers import SparseAutoModelForCausalLM, SparseAutoTokenizer -from sparseml.transformers.finetune.sft_trainer import SFTTrainer from transformers import DefaultDataCollator -from sparseml.transformers.finetune.data.data_args import DataTrainingArguments -from sparseml.transformers.finetune.data import TextGenerationDataset -from peft import LoraConfig +from datasets import load_dataset -model_path = "facebook/opt-350m" +from sparseml.transformers import ( + Trainer, + SFTTrainer, + DataTrainingArguments, + TrainingArguments, + TextGenerationDataset, + SparseAutoModelForCausalLM, + SparseAutoTokenizer +) + +model_path = "neuralmagic/TinyLlama-1.1B-Chat-v1.0-pruned2.4" output_dir = "./output_trl_sft_test" -model = SparseAutoModelForCausalLM.from_pretrained(model_path) +model = SparseAutoModelForCausalLM.from_pretrained(model_path, device_map="auto") tokenizer = SparseAutoTokenizer.from_pretrained(model_path) data_args = DataTrainingArguments(dataset = "open_platypus") @@ -18,27 +24,31 @@ split="train", tokenizer=tokenizer, ) -raw_dataset = dataset_manager.get_raw_dataset() -train_dataset = dataset_manager.tokenize_and_process(raw_dataset) +train_dataset = dataset_manager.tokenize_and_process() print(f"--> Training Set Length = {len(train_dataset)}") +dataset = load_dataset("imdb", split="train") -lora_config = LoraConfig( - r=16, - lora_alpha=32, - lora_dropout=0.05, - bias="none", - task_type="CAUSAL_LM", -) +recipe = """ +test_stage: + pruning_modifiers: + ConstantPruningModifier: + targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight', 're:.*o_proj.weight', + 're:.*gate_proj.weight', 're:.*up_proj.weight', 're:.*down_proj.weight'] + start: 0 +""" data_collator = DefaultDataCollator() trainer = SFTTrainer( model=model, - model_state_path=model_path, - train_dataset=train_dataset, tokenizer=tokenizer, + recipe=recipe, + train_dataset=train_dataset, data_collator=data_collator, - peft_config=lora_config, - dataset_text_field="text" + args=TrainingArguments(output_dir=output_dir, num_train_epochs=0.01, logging_steps=50), + max_seq_length=data_args.max_seq_length, + packing=True + #dataset_text_field="text", ) -trainer.train() \ No newline at end of file +trainer.train() +trainer.save_model(output_dir=trainer.args.output_dir) \ No newline at end of file From c84cd604f7f09511b7c144817ece6145d081066b Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 2 Apr 2024 18:40:11 +0000 Subject: [PATCH 07/17] use sft data functionality --- .../transformers/finetune/session_mixin.py | 10 +--- .../transformers/finetune/sft_trainer.py | 5 +- test_trl_sft_data.py | 55 +++++++++++++++++++ test_trl_trainer.py | 21 +++---- 4 files changed, 70 insertions(+), 21 deletions(-) create mode 100644 test_trl_sft_data.py diff --git a/src/sparseml/transformers/finetune/session_mixin.py b/src/sparseml/transformers/finetune/session_mixin.py index 54585ff8045..8b27cb7182c 100644 --- a/src/sparseml/transformers/finetune/session_mixin.py +++ b/src/sparseml/transformers/finetune/session_mixin.py @@ -28,10 +28,7 @@ import sparseml.core.session as session_manager from sparseml.core.framework import Framework from sparseml.core.session import callbacks -from sparseml.pytorch.model_load.helpers import ( 
- RECIPE_FILE_NAME, - get_session_model -) +from sparseml.pytorch.model_load.helpers import RECIPE_FILE_NAME, get_session_model from sparseml.pytorch.utils import LoggerManager, ModuleSparsificationInfo from sparseml.transformers.finetune.callbacks import ( DisableHalfPrecisionCallback, @@ -169,9 +166,6 @@ def initialize_structure(self, stage: Optional[str] = None): session = session_manager.active_session() if session.lifecycle.initialized_: return False - - if isinstance(self.model, str): - self.model_path_or_stub = self.model session_manager.pre_initialize_structure( model=self.model, @@ -490,7 +484,7 @@ def log_model_sparsification(self): sparsification_info = ModuleSparsificationInfo(self.model) _LOGGER.info( - f"Sparsification info for {str(type(self.model))}: " + f"Sparsification info for {type(self.model).__name__}: " f"{sparsification_info.params_total} total params. " f"Of those there are {sparsification_info.params_prunable_total} prunable " f"params which have {sparsification_info.params_prunable_sparse_percent} " diff --git a/src/sparseml/transformers/finetune/sft_trainer.py b/src/sparseml/transformers/finetune/sft_trainer.py index e7fca11a211..bc455b56aa1 100644 --- a/src/sparseml/transformers/finetune/sft_trainer.py +++ b/src/sparseml/transformers/finetune/sft_trainer.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from datasets import Dataset, IterableDataset - from sparseml.transformers.finetune.session_mixin import SessionManagerMixIn from trl import SFTTrainer as TRLSFTTrainer @@ -23,7 +21,8 @@ class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer): def _prepare_dataset(self, dataset, *args, **kwargs): - if isinstance(dataset, Dataset) or isinstance(dataset, IterableDataset): + if "input_ids" in dataset.column_names: + # dataset is already tokenized, skip preprocessing return dataset return super()._prepare_dataset(dataset, *args, **kwargs) diff --git a/test_trl_sft_data.py b/test_trl_sft_data.py new file mode 100644 index 00000000000..4f19ac474b6 --- /dev/null +++ b/test_trl_sft_data.py @@ -0,0 +1,55 @@ +from datasets import load_dataset +from trl import SFTTrainer, DataCollatorForCompletionOnlyLM + +from sparseml.transformers import ( + SFTTrainer, + TrainingArguments, + SparseAutoModelForCausalLM, + SparseAutoTokenizer +) + +dataset = load_dataset("gsm8k", "main", split="train") +model_path = "neuralmagic/Llama-2-7b-pruned50-retrained" +output_dir = "./output_trl_sft_test_7b_gsm8k_sft_data" +model = SparseAutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto") +tokenizer = SparseAutoTokenizer.from_pretrained(model_path) +tokenizer.pad_token = tokenizer.eos_token + +recipe = """ +test_stage: + pruning_modifiers: + ConstantPruningModifier: + targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight', 're:.*o_proj.weight', + 're:.*gate_proj.weight', 're:.*up_proj.weight', 're:.*down_proj.weight'] + start: 0 +""" + + +def formatting_prompts_func(example): + output_texts = [] + for i in range(len(example['question'])): + text = f"Question: {example['question'][i]}\n Answer: {example['answer'][i]}" + output_texts.append(text) + return output_texts + +response_template = "Answer:" +collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer) +training_args = TrainingArguments( + output_dir=output_dir, + num_train_epochs=0.6, + logging_steps=50, + gradient_checkpointing=True +) + +trainer = SFTTrainer( + model=model, + 
tokenizer=tokenizer, + recipe=recipe, + train_dataset=dataset, + formatting_func=formatting_prompts_func, + data_collator=collator, + args=training_args, + max_seq_length=512 +) +trainer.train() +trainer.save_model(output_dir=trainer.args.output_dir) \ No newline at end of file diff --git a/test_trl_trainer.py b/test_trl_trainer.py index c96c7a4bf59..dce588fb418 100644 --- a/test_trl_trainer.py +++ b/test_trl_trainer.py @@ -1,8 +1,6 @@ from transformers import DefaultDataCollator -from datasets import load_dataset from sparseml.transformers import ( - Trainer, SFTTrainer, DataTrainingArguments, TrainingArguments, @@ -11,13 +9,13 @@ SparseAutoTokenizer ) -model_path = "neuralmagic/TinyLlama-1.1B-Chat-v1.0-pruned2.4" -output_dir = "./output_trl_sft_test" +model_path = "neuralmagic/Llama-2-7b-pruned50-retrained" +output_dir = "./output_trl_sft_test_7b_gsm8k" -model = SparseAutoModelForCausalLM.from_pretrained(model_path, device_map="auto") +model = SparseAutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto") tokenizer = SparseAutoTokenizer.from_pretrained(model_path) -data_args = DataTrainingArguments(dataset = "open_platypus") +data_args = DataTrainingArguments(dataset = "gsm8k", dataset_config_name="main") dataset_manager = TextGenerationDataset.load_from_registry( data_args.dataset, data_args=data_args, @@ -27,8 +25,6 @@ train_dataset = dataset_manager.tokenize_and_process() print(f"--> Training Set Length = {len(train_dataset)}") -dataset = load_dataset("imdb", split="train") - recipe = """ test_stage: pruning_modifiers: @@ -39,16 +35,21 @@ """ data_collator = DefaultDataCollator() +training_args = TrainingArguments( + output_dir=output_dir, + num_train_epochs=0.6, + logging_steps=50, + gradient_checkpointing=True +) trainer = SFTTrainer( model=model, tokenizer=tokenizer, recipe=recipe, train_dataset=train_dataset, data_collator=data_collator, - args=TrainingArguments(output_dir=output_dir, num_train_epochs=0.01, logging_steps=50), + args=training_args, max_seq_length=data_args.max_seq_length, packing=True - #dataset_text_field="text", ) trainer.train() trainer.save_model(output_dir=trainer.args.output_dir) \ No newline at end of file From 4f938cb6934430cd85934bdd7ea5d10c6c37915b Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 2 Apr 2024 21:10:01 +0000 Subject: [PATCH 08/17] update unit tests --- test_trl_sft_data.py | 3 ++- test_trl_trainer.py | 2 +- tests/sparseml/transformers/finetune/test_session_mixin.py | 7 +------ 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/test_trl_sft_data.py b/test_trl_sft_data.py index 4f19ac474b6..7d7109a16b0 100644 --- a/test_trl_sft_data.py +++ b/test_trl_sft_data.py @@ -8,7 +8,6 @@ SparseAutoTokenizer ) -dataset = load_dataset("gsm8k", "main", split="train") model_path = "neuralmagic/Llama-2-7b-pruned50-retrained" output_dir = "./output_trl_sft_test_7b_gsm8k_sft_data" model = SparseAutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto") @@ -25,6 +24,7 @@ """ +dataset = load_dataset("gsm8k", "main", split="train") def formatting_prompts_func(example): output_texts = [] for i in range(len(example['question'])): @@ -34,6 +34,7 @@ def formatting_prompts_func(example): response_template = "Answer:" collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer) + training_args = TrainingArguments( output_dir=output_dir, num_train_epochs=0.6, diff --git a/test_trl_trainer.py b/test_trl_trainer.py index dce588fb418..3d0f8a42ad4 100644 --- 
a/test_trl_trainer.py +++ b/test_trl_trainer.py @@ -15,7 +15,7 @@ model = SparseAutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto") tokenizer = SparseAutoTokenizer.from_pretrained(model_path) -data_args = DataTrainingArguments(dataset = "gsm8k", dataset_config_name="main") +data_args = DataTrainingArguments(dataset = "gsm8k", dataset_config_name="main", max_seq_length=512) dataset_manager = TextGenerationDataset.load_from_registry( data_args.dataset, data_args=data_args, diff --git a/tests/sparseml/transformers/finetune/test_session_mixin.py b/tests/sparseml/transformers/finetune/test_session_mixin.py index 24af397920b..cc74de299e0 100644 --- a/tests/sparseml/transformers/finetune/test_session_mixin.py +++ b/tests/sparseml/transformers/finetune/test_session_mixin.py @@ -27,7 +27,6 @@ class MixInTest(SessionManagerMixIn, Trainer): def __init__( self, model: Module, - model_state_path: str, recipe: Optional[str], recipe_args: Optional[Union[Dict[str, Any], str]] = None, teacher: Optional[Union[Module, str]] = None, @@ -35,7 +34,6 @@ def __init__( ): super().__init__( model=model, - model_state_path=model_state_path, recipe=recipe, recipe_args=recipe_args, teacher=teacher, @@ -48,9 +46,7 @@ def test_mixin_init(): model = AutoModelForCausalLM.from_pretrained(model_state_path) recipe = "tests/sparseml/transformers/finetune/test_quantization.yaml" - session_mixin = MixInTest( - model=model, model_state_path=model_state_path, recipe=recipe - ) + session_mixin = MixInTest(model=model, recipe=recipe) assert isinstance(session_mixin, SessionManagerMixIn) assert isinstance(session_mixin, Trainer) assert session_mixin.recipe == recipe @@ -67,7 +63,6 @@ def mixin_trainer(): return MixInTest( model=model, - model_state_path=model_state_path, recipe=recipe, train_dataset=train_dataset, eval_dataset=eval_dataset, From 382160657a03c558b3ed2c265cac98c92c85b462 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 2 Apr 2024 21:13:02 +0000 Subject: [PATCH 09/17] move examples folder --- src/sparseml/export/validators.py | 3 ++- src/sparseml/transformers/finetune/README.md | 4 ++-- .../finetune/{ => examples}/example_alternating_recipe.yaml | 0 .../transformers/finetune/examples/test_trl_sft_data.py | 0 .../transformers/finetune/examples/test_trl_trainer.py | 0 5 files changed, 4 insertions(+), 3 deletions(-) rename src/sparseml/transformers/finetune/{ => examples}/example_alternating_recipe.yaml (100%) rename test_trl_sft_data.py => src/sparseml/transformers/finetune/examples/test_trl_sft_data.py (100%) rename test_trl_trainer.py => src/sparseml/transformers/finetune/examples/test_trl_trainer.py (100%) diff --git a/src/sparseml/export/validators.py b/src/sparseml/export/validators.py index 52c9fa05ee0..f513bda21aa 100644 --- a/src/sparseml/export/validators.py +++ b/src/sparseml/export/validators.py @@ -17,8 +17,9 @@ import os.path from collections import OrderedDict from pathlib import Path -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional from typing import OrderedDict as OrderedDictType +from typing import Union import numpy import onnx diff --git a/src/sparseml/transformers/finetune/README.md b/src/sparseml/transformers/finetune/README.md index dc3a61e9ba3..7022b9ccc10 100644 --- a/src/sparseml/transformers/finetune/README.md +++ b/src/sparseml/transformers/finetune/README.md @@ -132,7 +132,7 @@ A recipe can be run stage-by-stage by setting `run_stages` to `True` or calling a `run_type` attribute set to either `oneshot` or 
`train` when running in sequential mode. -See [example_alternating_recipe.yaml](example_alternating_recipe.yaml) for an example +See [example_alternating_recipe.yaml](examples/example_alternating_recipe.yaml) for an example of a staged recipe for Llama. ### Python Example @@ -147,7 +147,7 @@ dataset_name = "open_platypus" concatenate_data = False run_stages=True output_dir = "./output_finetune_multi" -recipe = "example_alternating_recipe.yaml" +recipe = "examples/example_alternating_recipe.yaml" num_train_epochs=1 overwrite_output_dir = True splits = { diff --git a/src/sparseml/transformers/finetune/example_alternating_recipe.yaml b/src/sparseml/transformers/finetune/examples/example_alternating_recipe.yaml similarity index 100% rename from src/sparseml/transformers/finetune/example_alternating_recipe.yaml rename to src/sparseml/transformers/finetune/examples/example_alternating_recipe.yaml diff --git a/test_trl_sft_data.py b/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py similarity index 100% rename from test_trl_sft_data.py rename to src/sparseml/transformers/finetune/examples/test_trl_sft_data.py diff --git a/test_trl_trainer.py b/src/sparseml/transformers/finetune/examples/test_trl_trainer.py similarity index 100% rename from test_trl_trainer.py rename to src/sparseml/transformers/finetune/examples/test_trl_trainer.py From 4f619bdb9e9188e535728cd8e59afe02db38f1ac Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 2 Apr 2024 21:17:17 +0000 Subject: [PATCH 10/17] clarity comments --- .../transformers/finetune/examples/test_trl_sft_data.py | 4 ++-- .../transformers/finetune/examples/test_trl_trainer.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py b/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py index 7d7109a16b0..68292fc85cf 100644 --- a/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py +++ b/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py @@ -14,6 +14,7 @@ tokenizer = SparseAutoTokenizer.from_pretrained(model_path) tokenizer.pad_token = tokenizer.eos_token +# recipe for maintaining model sparsity during finetuning recipe = """ test_stage: pruning_modifiers: @@ -23,7 +24,7 @@ start: 0 """ - +# Load gsm8k using TRL dataset tools dataset = load_dataset("gsm8k", "main", split="train") def formatting_prompts_func(example): output_texts = [] @@ -31,7 +32,6 @@ def formatting_prompts_func(example): text = f"Question: {example['question'][i]}\n Answer: {example['answer'][i]}" output_texts.append(text) return output_texts - response_template = "Answer:" collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer) diff --git a/src/sparseml/transformers/finetune/examples/test_trl_trainer.py b/src/sparseml/transformers/finetune/examples/test_trl_trainer.py index 3d0f8a42ad4..c1af0b67500 100644 --- a/src/sparseml/transformers/finetune/examples/test_trl_trainer.py +++ b/src/sparseml/transformers/finetune/examples/test_trl_trainer.py @@ -15,6 +15,7 @@ model = SparseAutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto") tokenizer = SparseAutoTokenizer.from_pretrained(model_path) +# Load gsm8k using SparseML dataset tools data_args = DataTrainingArguments(dataset = "gsm8k", dataset_config_name="main", max_seq_length=512) dataset_manager = TextGenerationDataset.load_from_registry( data_args.dataset, @@ -25,6 +26,7 @@ train_dataset = dataset_manager.tokenize_and_process() print(f"--> Training Set 
Length = {len(train_dataset)}") +# recipe for maintaining model sparsity during finetuning recipe = """ test_stage: pruning_modifiers: From e425523b306730ee4837013ee7492d906b591de6 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 2 Apr 2024 21:27:27 +0000 Subject: [PATCH 11/17] barest bones trainer --- src/sparseml/transformers/finetune/trainer.py | 38 +------------------ 1 file changed, 1 insertion(+), 37 deletions(-) diff --git a/src/sparseml/transformers/finetune/trainer.py b/src/sparseml/transformers/finetune/trainer.py index 94c9b51f94d..9fed69d82be 100644 --- a/src/sparseml/transformers/finetune/trainer.py +++ b/src/sparseml/transformers/finetune/trainer.py @@ -25,40 +25,4 @@ class Trainer(SessionManagerMixIn, HFTransformersTrainer): - """ - Training implementation for running sparsification recipes with HF Trainer. - - :param model: the model to use with the trainer and apply sparsification to - :param recipe: the recipe, if any, to apply to the modle and training - process - :param recipe_args: A json string, csv key=value string, or dictionary containing - arguments to override the root arguments within the recipe such as - learning rate or num epochs - :param teacher: teacher model for distillation. Set to 'self' to distill - from the loaded model or 'disable' to turn of distillation - :param kwargs: key word arguments passed to the parent class - """ - - def __init__( - self, - model: Optional[Module] = None, - model_init: Optional[Callable] = None, - recipe: Optional[str] = None, - recipe_args: Optional[Union[Dict[str, Any], str]] = None, - teacher: Optional[Union[Module, str]] = None, - **kwargs, - ): - super().__init__( - model=model, - model_init=model_init, - recipe=recipe, - recipe_args=recipe_args, - teacher=teacher, - **kwargs, - ) - - def _dummy_lr_scheduler(self): - return torch.optim.lr_scheduler.MultiplicativeLR( - self.optimizer, - lambda _: 1.0, - ) + pass \ No newline at end of file From 3c9a2e311132ea966be822861bdd33b4f34afd8e Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 3 Apr 2024 12:46:34 +0000 Subject: [PATCH 12/17] style --- .../finetune/examples/test_trl_sft_data.py | 50 +++++++++++++------ .../finetune/examples/test_trl_trainer.py | 50 +++++++++++++------ src/sparseml/transformers/finetune/trainer.py | 6 +-- 3 files changed, 72 insertions(+), 34 deletions(-) diff --git a/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py b/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py index 68292fc85cf..ddbbd7ae623 100644 --- a/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py +++ b/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py @@ -1,16 +1,33 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from datasets import load_dataset -from trl import SFTTrainer, DataCollatorForCompletionOnlyLM from sparseml.transformers import ( SFTTrainer, - TrainingArguments, - SparseAutoModelForCausalLM, - SparseAutoTokenizer + SparseAutoModelForCausalLM, + SparseAutoTokenizer, + TrainingArguments, ) +from trl import DataCollatorForCompletionOnlyLM + model_path = "neuralmagic/Llama-2-7b-pruned50-retrained" output_dir = "./output_trl_sft_test_7b_gsm8k_sft_data" -model = SparseAutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto") +model = SparseAutoModelForCausalLM.from_pretrained( + model_path, torch_dtype="auto", device_map="auto" +) tokenizer = SparseAutoTokenizer.from_pretrained(model_path) tokenizer.pad_token = tokenizer.eos_token @@ -19,27 +36,32 @@ test_stage: pruning_modifiers: ConstantPruningModifier: - targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight', 're:.*o_proj.weight', - 're:.*gate_proj.weight', 're:.*up_proj.weight', 're:.*down_proj.weight'] + targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight', + 're:.*o_proj.weight','re:.*gate_proj.weight', 're:.*up_proj.weight', + 're:.*down_proj.weight'] start: 0 """ # Load gsm8k using TRL dataset tools dataset = load_dataset("gsm8k", "main", split="train") + + def formatting_prompts_func(example): output_texts = [] - for i in range(len(example['question'])): + for i in range(len(example["question"])): text = f"Question: {example['question'][i]}\n Answer: {example['answer'][i]}" output_texts.append(text) return output_texts + + response_template = "Answer:" collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer) training_args = TrainingArguments( - output_dir=output_dir, - num_train_epochs=0.6, - logging_steps=50, - gradient_checkpointing=True + output_dir=output_dir, + num_train_epochs=0.6, + logging_steps=50, + gradient_checkpointing=True, ) trainer = SFTTrainer( @@ -50,7 +72,7 @@ def formatting_prompts_func(example): formatting_func=formatting_prompts_func, data_collator=collator, args=training_args, - max_seq_length=512 + max_seq_length=512, ) trainer.train() -trainer.save_model(output_dir=trainer.args.output_dir) \ No newline at end of file +trainer.save_model(output_dir=trainer.args.output_dir) diff --git a/src/sparseml/transformers/finetune/examples/test_trl_trainer.py b/src/sparseml/transformers/finetune/examples/test_trl_trainer.py index c1af0b67500..7b4ecda49b5 100644 --- a/src/sparseml/transformers/finetune/examples/test_trl_trainer.py +++ b/src/sparseml/transformers/finetune/examples/test_trl_trainer.py @@ -1,22 +1,41 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from transformers import DefaultDataCollator from sparseml.transformers import ( + DataTrainingArguments, SFTTrainer, - DataTrainingArguments, - TrainingArguments, - TextGenerationDataset, - SparseAutoModelForCausalLM, - SparseAutoTokenizer + SparseAutoModelForCausalLM, + SparseAutoTokenizer, + TextGenerationDataset, + TrainingArguments, ) + model_path = "neuralmagic/Llama-2-7b-pruned50-retrained" output_dir = "./output_trl_sft_test_7b_gsm8k" -model = SparseAutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto") +model = SparseAutoModelForCausalLM.from_pretrained( + model_path, torch_dtype="auto", device_map="auto" +) tokenizer = SparseAutoTokenizer.from_pretrained(model_path) # Load gsm8k using SparseML dataset tools -data_args = DataTrainingArguments(dataset = "gsm8k", dataset_config_name="main", max_seq_length=512) +data_args = DataTrainingArguments( + dataset="gsm8k", dataset_config_name="main", max_seq_length=512 +) dataset_manager = TextGenerationDataset.load_from_registry( data_args.dataset, data_args=data_args, @@ -31,17 +50,18 @@ test_stage: pruning_modifiers: ConstantPruningModifier: - targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight', 're:.*o_proj.weight', - 're:.*gate_proj.weight', 're:.*up_proj.weight', 're:.*down_proj.weight'] + targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight', + 're:.*o_proj.weight','re:.*gate_proj.weight', 're:.*up_proj.weight', + 're:.*down_proj.weight'] start: 0 """ data_collator = DefaultDataCollator() training_args = TrainingArguments( - output_dir=output_dir, - num_train_epochs=0.6, - logging_steps=50, - gradient_checkpointing=True + output_dir=output_dir, + num_train_epochs=0.6, + logging_steps=50, + gradient_checkpointing=True, ) trainer = SFTTrainer( model=model, @@ -51,7 +71,7 @@ data_collator=data_collator, args=training_args, max_seq_length=data_args.max_seq_length, - packing=True + packing=True, ) trainer.train() -trainer.save_model(output_dir=trainer.args.output_dir) \ No newline at end of file +trainer.save_model(output_dir=trainer.args.output_dir) diff --git a/src/sparseml/transformers/finetune/trainer.py b/src/sparseml/transformers/finetune/trainer.py index 9fed69d82be..d918f50880b 100644 --- a/src/sparseml/transformers/finetune/trainer.py +++ b/src/sparseml/transformers/finetune/trainer.py @@ -12,10 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, Optional, Union - -import torch -from torch.nn import Module from transformers import Trainer as HFTransformersTrainer from sparseml.transformers.finetune.session_mixin import SessionManagerMixIn @@ -25,4 +21,4 @@ class Trainer(SessionManagerMixIn, HFTransformersTrainer): - pass \ No newline at end of file + pass From 2ed1bb281ba2e16d3f1e5e6721c30e3508717ccb Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 3 Apr 2024 19:12:33 +0000 Subject: [PATCH 13/17] tweaks to work with distillation and FSDP --- ...rl_trainer.py => test_trl_distillation.py} | 18 +++++++++++++-- .../finetune/examples/test_trl_sft_data.py | 4 ++-- .../transformers/finetune/session_mixin.py | 23 +++++++++++++------ .../transformers/finetune/text_generation.py | 15 ------------ 4 files changed, 34 insertions(+), 26 deletions(-) rename src/sparseml/transformers/finetune/examples/{test_trl_trainer.py => test_trl_distillation.py} (82%) diff --git a/src/sparseml/transformers/finetune/examples/test_trl_trainer.py b/src/sparseml/transformers/finetune/examples/test_trl_distillation.py similarity index 82% rename from src/sparseml/transformers/finetune/examples/test_trl_trainer.py rename to src/sparseml/transformers/finetune/examples/test_trl_distillation.py index 7b4ecda49b5..0b8e737edda 100644 --- a/src/sparseml/transformers/finetune/examples/test_trl_trainer.py +++ b/src/sparseml/transformers/finetune/examples/test_trl_distillation.py @@ -25,11 +25,16 @@ model_path = "neuralmagic/Llama-2-7b-pruned50-retrained" +teacher_path = "zoo:llama2-7b-gsm8k_llama2_pretrain-base" output_dir = "./output_trl_sft_test_7b_gsm8k" model = SparseAutoModelForCausalLM.from_pretrained( model_path, torch_dtype="auto", device_map="auto" ) +teacher = SparseAutoModelForCausalLM.from_pretrained( + teacher_path, torch_dtype="auto", device_map="auto" +) + tokenizer = SparseAutoTokenizer.from_pretrained(model_path) # Load gsm8k using SparseML dataset tools @@ -50,10 +55,16 @@ test_stage: pruning_modifiers: ConstantPruningModifier: - targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight', - 're:.*o_proj.weight','re:.*gate_proj.weight', 're:.*up_proj.weight', + targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight', + 're:.*o_proj.weight', 're:.*gate_proj.weight', 're:.*up_proj.weight', 're:.*down_proj.weight'] start: 0 + OutputDistillationModifier: + targets: ['re:model.layers.\\d+$'] + comparison: "square_head" + start: 0 + orig_scale: 1.0 + distill_scale: 1.0 """ data_collator = DefaultDataCollator() @@ -62,14 +73,17 @@ num_train_epochs=0.6, logging_steps=50, gradient_checkpointing=True, + bf16=True, ) trainer = SFTTrainer( model=model, + teacher=teacher, tokenizer=tokenizer, recipe=recipe, train_dataset=train_dataset, data_collator=data_collator, args=training_args, + data_args=data_args, max_seq_length=data_args.max_seq_length, packing=True, ) diff --git a/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py b/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py index ddbbd7ae623..a4109ce9901 100644 --- a/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py +++ b/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py @@ -36,8 +36,8 @@ test_stage: pruning_modifiers: ConstantPruningModifier: - targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight', - 're:.*o_proj.weight','re:.*gate_proj.weight', 're:.*up_proj.weight', + targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight', + 
're:.*o_proj.weight','re:.*gate_proj.weight', 're:.*up_proj.weight', 're:.*down_proj.weight'] start: 0 """ diff --git a/src/sparseml/transformers/finetune/session_mixin.py b/src/sparseml/transformers/finetune/session_mixin.py index 8b27cb7182c..0d3d70baa2e 100644 --- a/src/sparseml/transformers/finetune/session_mixin.py +++ b/src/sparseml/transformers/finetune/session_mixin.py @@ -45,6 +45,13 @@ _LOGGER = logging.getLogger(__name__) TRAINER_STATE_NAME = "trainer_state.json" +METADATA_ARGS = [ + "per_device_train_batch_size", + "per_device_eval_batch_size", + "max_seq_length", + "save_safetensors", + "fp16", +] class SessionManagerMixIn: @@ -54,7 +61,6 @@ class SessionManagerMixIn: :param recipe: path to recipe file to apply during training :param recipe_args: additional kwargs to use for evaluating recipe - :param metadata_args: additional kwargs for configuring training :param data_args: kwargs for configuring dataset loading :param teacher: optional teacher model to use for distillation """ @@ -63,7 +69,6 @@ def __init__( self, recipe: Optional[str] = None, recipe_args: Optional[Union[Dict[str, Any], str]] = None, - metadata_args: Optional[List[str]] = None, data_args: Optional["DataTrainingArguments"] = None, # noqa: F821 teacher: Optional[Union[Module, str]] = None, **kwargs, @@ -76,11 +81,11 @@ def __init__( training_args = kwargs.get("args") self.metadata = ( self._extract_metadata( - metadata_args=metadata_args, + metadata_args=METADATA_ARGS, training_args_dict=training_args.to_dict(), data_args_dict=asdict(data_args) if data_args else {}, ) - if training_args and metadata_args + if training_args and METADATA_ARGS else None ) @@ -90,6 +95,7 @@ def __init__( # call Trainer initialization super().__init__(**kwargs) + self.accelerator.wait_for_everyone() # setup callbacks and loss self.optim_callbacks = TrainingLoopCallbacks(self) @@ -99,7 +105,7 @@ def __init__( self.criterion = torch.nn.CrossEntropyLoss() model_signature = inspect.signature(self.model.forward) - self._model_signature_columns = list(model_signature.parameters.keys()) + self._signature_columns = list(model_signature.parameters.keys()) if self.teacher is not None and teacher not in ("disable", "self"): teacher_signature = inspect.signature(self.teacher.forward) @@ -107,6 +113,9 @@ def __init__( else: self._teacher_signature_columns = None + if self.is_fsdp_enabled: + self._prepare_model_for_fsdp() + def initialize_session( self, epoch: float, @@ -145,7 +154,7 @@ def initialize_session( ) self.accelerator.wait_for_everyone() model = get_session_model() - self.model = model + self.model_wrapped = self.model = model if self.recipe is None: _LOGGER.warning( @@ -279,7 +288,7 @@ def compute_loss( self._check_super_defined("compute_loss") # TODO: do we need these model signature columns? 
- inputs = {k: inputs[k] for k in inputs if k in self._model_signature_columns} + inputs = {k: inputs[k] for k in inputs if k in self._signature_columns} loss = super().compute_loss(model, inputs, return_outputs=return_outputs) # take the mean across multiple GPUs diff --git a/src/sparseml/transformers/finetune/text_generation.py b/src/sparseml/transformers/finetune/text_generation.py index ebc11bffa8d..c33c5c18559 100644 --- a/src/sparseml/transformers/finetune/text_generation.py +++ b/src/sparseml/transformers/finetune/text_generation.py @@ -49,14 +49,6 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -metadata_args = [ - "per_device_train_batch_size", - "per_device_eval_batch_size", - "max_seq_length", - "save_safetensors", - "fp16", -] - def train(**kwargs): """ @@ -134,10 +126,6 @@ def parse_args(**kwargs): arg_dict[key] = value training_args.recipe_args = arg_dict - # when set to true in FSDP mode this causes issues, the model arguments show up - # as *args and **kwargs so all columns get removed - training_args.remove_unused_columns = False - return model_args, data_args, training_args @@ -335,7 +323,6 @@ def main( model_init=get_session_model, teacher=teacher, recipe=training_args.recipe, - metadata_args=metadata_args, recipe_args=training_args.recipe_args, args=training_args, data_args=data_args, @@ -344,8 +331,6 @@ def main( tokenizer=tokenizer, data_collator=data_collator, ) - if trainer.is_fsdp_enabled: - trainer._prepare_model_for_fsdp() stage_runner.trainer = trainer # alternating Training/One-shot From 1f61081096787fd10ae70d86f1fc778f4c6e5090 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 9 Apr 2024 19:02:03 +0000 Subject: [PATCH 14/17] tests and readme --- .../transformers/finetune/__init__.py | 1 - .../finetune/examples/trl_mixin/README.md | 27 +++++++++++++++++++ .../{ => examples/trl_mixin}/sft_trainer.py | 0 .../{ => trl_mixin}/test_trl_distillation.py | 2 +- .../{ => trl_mixin}/test_trl_sft_data.py | 2 +- 5 files changed, 29 insertions(+), 3 deletions(-) create mode 100644 src/sparseml/transformers/finetune/examples/trl_mixin/README.md rename src/sparseml/transformers/finetune/{ => examples/trl_mixin}/sft_trainer.py (100%) rename src/sparseml/transformers/finetune/examples/{ => trl_mixin}/test_trl_distillation.py (98%) rename src/sparseml/transformers/finetune/examples/{ => trl_mixin}/test_trl_sft_data.py (97%) diff --git a/src/sparseml/transformers/finetune/__init__.py b/src/sparseml/transformers/finetune/__init__.py index 7fbbbcccab4..6329b5a4692 100644 --- a/src/sparseml/transformers/finetune/__init__.py +++ b/src/sparseml/transformers/finetune/__init__.py @@ -17,7 +17,6 @@ from .data import DataTrainingArguments, TextGenerationDataset from .model_args import ModelArguments from .session_mixin import SessionManagerMixIn -from .sft_trainer import SFTTrainer from .text_generation import apply, compress, eval, oneshot, train from .trainer import Trainer from .training_args import TrainingArguments diff --git a/src/sparseml/transformers/finetune/examples/trl_mixin/README.md b/src/sparseml/transformers/finetune/examples/trl_mixin/README.md new file mode 100644 index 00000000000..39e51b63d60 --- /dev/null +++ b/src/sparseml/transformers/finetune/examples/trl_mixin/README.md @@ -0,0 +1,27 @@ +# Sparse Finetuning with TRL's SFTTrainer + +The `SessionManagerMixin` can be added to other Trainer classes that inherit from +[Hugging Face's Trainer](https://huggingface.co/docs/transformers/en/main_classes/trainer). 
+ +For example, we can add SparseML support to TRL's SFTTrainer like so: + +```python +from trl import SFTTrainer as TRLSFTTrainer + +class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer): +    ... +``` + +The new `SFTTrainer` class can now apply SparseML recipes and modifiers during +supervised finetuning, with full support for all of the original TRL features. + +### Examples + +[test_trl_sft_data.py](test_trl_sft_data.py): finetunes a 50% sparse Llama-7b model, +using TRL's dataset preprocessing. Sparsity is maintained throughout training by +applying a `ConstantPruningModifier` recipe to the `SFTTrainer` + +[test_trl_distillation.py](test_trl_distillation.py): finetunes a 50% sparse Llama-7b +model using knowledge distillation from a dense Llama-7b model. Sparsity is maintained +throughout training with a `ConstantPruningModifier` and layer-wise knowledge +distillation is handled by the `OutputDistillationModifier` \ No newline at end of file diff --git a/src/sparseml/transformers/finetune/sft_trainer.py b/src/sparseml/transformers/finetune/examples/trl_mixin/sft_trainer.py similarity index 100% rename from src/sparseml/transformers/finetune/sft_trainer.py rename to src/sparseml/transformers/finetune/examples/trl_mixin/sft_trainer.py diff --git a/src/sparseml/transformers/finetune/examples/test_trl_distillation.py b/src/sparseml/transformers/finetune/examples/trl_mixin/test_trl_distillation.py similarity index 98% rename from src/sparseml/transformers/finetune/examples/test_trl_distillation.py rename to src/sparseml/transformers/finetune/examples/trl_mixin/test_trl_distillation.py index 0b8e737edda..c5a2ac07c1b 100644 --- a/src/sparseml/transformers/finetune/examples/test_trl_distillation.py +++ b/src/sparseml/transformers/finetune/examples/trl_mixin/test_trl_distillation.py @@ -88,4 +88,4 @@ packing=True, ) trainer.train() -trainer.save_model(output_dir=trainer.args.output_dir) +trainer.save_model() diff --git a/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py b/src/sparseml/transformers/finetune/examples/trl_mixin/test_trl_sft_data.py similarity index 97% rename from src/sparseml/transformers/finetune/examples/test_trl_sft_data.py rename to src/sparseml/transformers/finetune/examples/trl_mixin/test_trl_sft_data.py index a4109ce9901..bc87338e191 100644 --- a/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py +++ b/src/sparseml/transformers/finetune/examples/trl_mixin/test_trl_sft_data.py @@ -75,4 +75,4 @@ def formatting_prompts_func(example): max_seq_length=512, ) trainer.train() -trainer.save_model(output_dir=trainer.args.output_dir) +trainer.save_model() From 44cb166af2bb2557559f0c7c5aba9cb60ecf7fb0 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 9 Apr 2024 19:16:26 +0000 Subject: [PATCH 15/17] naming --- .../transformers/finetune/examples/trl_mixin/README.md | 9 ++++++--- .../{test_trl_sft_data.py => ex_trl_constant.py} | 0 .../{test_trl_distillation.py => ex_trl_distillation.py} | 0 3 files changed, 6 insertions(+), 3 deletions(-) rename src/sparseml/transformers/finetune/examples/trl_mixin/{test_trl_sft_data.py => ex_trl_constant.py} (100%) rename src/sparseml/transformers/finetune/examples/trl_mixin/{test_trl_distillation.py => ex_trl_distillation.py} (100%) diff --git a/src/sparseml/transformers/finetune/examples/trl_mixin/README.md b/src/sparseml/transformers/finetune/examples/trl_mixin/README.md index 39e51b63d60..64ad73e3510 100644 --- a/src/sparseml/transformers/finetune/examples/trl_mixin/README.md +++ 
b/src/sparseml/transformers/finetune/examples/trl_mixin/README.md @@ -13,15 +13,18 @@ class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer): ``` The new `SFTTrainer` class can now apply SparseML recipes and modifiers during -supervised finetuning, with full support for all of the original TRL features. +supervised finetuning, with full support for all of the original TRL features. The full +class is defined in [sft_trainer.py](sft_trainer.py) and requires very minimal +additional code: just a dataset load override to support passing in tokenized datasets +to the Trainer. ### Examples -[test_trl_sft_data.py](test_trl_sft_data.py): finetunes a 50% sparse Llama-7b model, +[ex_trl_constant.py](ex_trl_constant.py): finetunes a 50% sparse Llama-7b model, using TRL's dataset preprocessing. Sparsity is maintained throughout training by applying a `ConstantPruningModifier` recipe to the `SFTTrainer` -[test_trl_distillation.py](test_trl_distillation.py): finetunes a 50% sparse Llama-7b +[ex_trl_distillation.py](ex_trl_distillation.py): finetunes a 50% sparse Llama-7b model using knowledge distillation from a dense Llama-7b model. Sparsity is maintained throughout training with a `ConstantPruningModifier` and layer-wise knowledge distillation is handled by the `OutputDistillationModifier` \ No newline at end of file diff --git a/src/sparseml/transformers/finetune/examples/trl_mixin/test_trl_sft_data.py b/src/sparseml/transformers/finetune/examples/trl_mixin/ex_trl_constant.py similarity index 100% rename from src/sparseml/transformers/finetune/examples/trl_mixin/test_trl_sft_data.py rename to src/sparseml/transformers/finetune/examples/trl_mixin/ex_trl_constant.py diff --git a/src/sparseml/transformers/finetune/examples/trl_mixin/test_trl_distillation.py b/src/sparseml/transformers/finetune/examples/trl_mixin/ex_trl_distillation.py similarity index 100% rename from src/sparseml/transformers/finetune/examples/trl_mixin/test_trl_distillation.py rename to src/sparseml/transformers/finetune/examples/trl_mixin/ex_trl_distillation.py From 0051eaf2c6aabf97e26f5bf504460e9ead44b728 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 16 Apr 2024 03:20:15 +0000 Subject: [PATCH 16/17] move examples --- .../text-generation}/example_alternating_recipe.yaml | 0 .../tutorials/text-generation}/trl_mixin/README.md | 0 .../tutorials/text-generation}/trl_mixin/ex_trl_constant.py | 0 .../text-generation}/trl_mixin/ex_trl_distillation.py | 0 .../tutorials/text-generation}/trl_mixin/sft_trainer.py | 0 src/sparseml/transformers/finetune/README.md | 4 ++-- 6 files changed, 2 insertions(+), 2 deletions(-) rename {src/sparseml/transformers/finetune/examples => integrations/huggingface-transformers/tutorials/text-generation}/example_alternating_recipe.yaml (100%) rename {src/sparseml/transformers/finetune/examples => integrations/huggingface-transformers/tutorials/text-generation}/trl_mixin/README.md (100%) rename {src/sparseml/transformers/finetune/examples => integrations/huggingface-transformers/tutorials/text-generation}/trl_mixin/ex_trl_constant.py (100%) rename {src/sparseml/transformers/finetune/examples => integrations/huggingface-transformers/tutorials/text-generation}/trl_mixin/ex_trl_distillation.py (100%) rename {src/sparseml/transformers/finetune/examples => integrations/huggingface-transformers/tutorials/text-generation}/trl_mixin/sft_trainer.py (100%) diff --git a/src/sparseml/transformers/finetune/examples/example_alternating_recipe.yaml 
b/integrations/huggingface-transformers/tutorials/text-generation/example_alternating_recipe.yaml similarity index 100% rename from src/sparseml/transformers/finetune/examples/example_alternating_recipe.yaml rename to integrations/huggingface-transformers/tutorials/text-generation/example_alternating_recipe.yaml diff --git a/src/sparseml/transformers/finetune/examples/trl_mixin/README.md b/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/README.md similarity index 100% rename from src/sparseml/transformers/finetune/examples/trl_mixin/README.md rename to integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/README.md diff --git a/src/sparseml/transformers/finetune/examples/trl_mixin/ex_trl_constant.py b/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/ex_trl_constant.py similarity index 100% rename from src/sparseml/transformers/finetune/examples/trl_mixin/ex_trl_constant.py rename to integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/ex_trl_constant.py diff --git a/src/sparseml/transformers/finetune/examples/trl_mixin/ex_trl_distillation.py b/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/ex_trl_distillation.py similarity index 100% rename from src/sparseml/transformers/finetune/examples/trl_mixin/ex_trl_distillation.py rename to integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/ex_trl_distillation.py diff --git a/src/sparseml/transformers/finetune/examples/trl_mixin/sft_trainer.py b/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/sft_trainer.py similarity index 100% rename from src/sparseml/transformers/finetune/examples/trl_mixin/sft_trainer.py rename to integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/sft_trainer.py diff --git a/src/sparseml/transformers/finetune/README.md b/src/sparseml/transformers/finetune/README.md index 7022b9ccc10..aaee586d671 100644 --- a/src/sparseml/transformers/finetune/README.md +++ b/src/sparseml/transformers/finetune/README.md @@ -132,7 +132,7 @@ A recipe can be run stage-by-stage by setting `run_stages` to `True` or calling a `run_type` attribute set to either `oneshot` or `train` when running in sequential mode. -See [example_alternating_recipe.yaml](examples/example_alternating_recipe.yaml) for an example +See [example_alternating_recipe.yaml](../../../../integrations/huggingface-transformers/tutorials/text-generation/example_alternating_recipe.yaml) for an example of a staged recipe for Llama. 
### Python Example @@ -147,7 +147,7 @@ dataset_name = "open_platypus" concatenate_data = False run_stages=True output_dir = "./output_finetune_multi" -recipe = "examples/example_alternating_recipe.yaml" +recipe = "example_alternating_recipe.yaml" num_train_epochs=1 overwrite_output_dir = True splits = { From 7630dd7ce7b1019c4601f5be0416dbe3461ae80b Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 16 Apr 2024 14:48:37 +0000 Subject: [PATCH 17/17] quality --- .../text-generation/trl_mixin/README.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/README.md b/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/README.md index 64ad73e3510..25c3b54976b 100644 --- a/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/README.md +++ b/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/README.md @@ -1,3 +1,19 @@ + + # Sparse Finetuning with TRL's SFTTrainer The `SessionManagerMixin` can be added to other Trainer classes that inherit from
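
The README above notes that the mixin-based `SFTTrainer` needs "just a dataset load override to support passing in tokenized datasets to the Trainer", but that override is not shown in this excerpt. Below is a minimal, illustrative sketch of what such an override could look like; it assumes TRL's `SFTTrainer` exposes a `_prepare_dataset` hook (the exact hook name and signature vary across TRL releases) and is not the code added by these patches — see sft_trainer.py in the patch series for the actual implementation.

```python
from trl import SFTTrainer as TRLSFTTrainer

from sparseml.transformers.finetune.session_mixin import SessionManagerMixIn


class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer):
    """TRL's SFTTrainer with SparseML recipe support mixed in (sketch)."""

    def _prepare_dataset(self, dataset, *args, **kwargs):
        # Hypothetical override: if the dataset was already tokenized (for example
        # by SparseML's TextGenerationDataset tools), pass it through unchanged
        # instead of re-applying TRL's text-field preprocessing.
        if "input_ids" in dataset.column_names:
            return dataset
        return super()._prepare_dataset(dataset, *args, **kwargs)
```

With an override along these lines, the same trainer accepts either a raw text dataset preprocessed by TRL (as in ex_trl_constant.py) or a pre-tokenized dataset built with SparseML's dataset tools (as in ex_trl_distillation.py).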