From b09c6d0efdc44e8d793568155fbdf75ebb76e8ab Mon Sep 17 00:00:00 2001
From: Benjamin Fineran
Date: Wed, 2 Feb 2022 16:52:46 -0500
Subject: [PATCH] [cherry-pick] transformers refactor (#538)

* Refactor of Transformers SparseML CLI and integrations (#536)
* Refactor of Transformers SparseML CLI and integrations
* Refactor export.py to use new pathways, fix make quality
* Update src/sparseml/optim/manager.py
Co-authored-by: Rahul Tuli
* Update src/sparseml/transformers/utils/model.py
Co-authored-by: Rahul Tuli
* fixes from review
* fixes from review and testing
* bug fixes and logging
* bug fixes for export and distillation
* review fixes, quality fixes, style fixes
* fix dependency issue
* fix distillation tests
* fix distillation tests
* fix distillation tests
* fill in docs and update style
* fix issue with distillation improperly updating students inputs
* fix quality
* Update src/sparseml/pytorch/optim/modifier_distillation.py
* add in better logging for missing and unexpected keys in model reload for transformers trainer
* fix logging for transformers export
Co-authored-by: Rahul Tuli
* Fix model load bug and add logging to catch potential future issues (#537)
* Fix model load bug and add logging to catch potential future issues
* initial migration to generalize module sparsification information
* propagate ModuleSparsificationInfo
* report type of input tensors in export.py
* minor bug fixes
* ModuleSparsificationInfo docs
* export onnx bugfix
* bug fixes
* make style
* bug fix for quantization
* revert to use ScheduledOptimizer due to bug with torch LambdaLR
* remove language_modeling script
* add end model sparsification log

Co-authored-by: Benjamin
Co-authored-by: Mark Kurtz
Co-authored-by: Rahul Tuli
---
 setup.py | 58 +-
 src/sparseml/keras/optim/manager.py | 7 +-
 src/sparseml/optim/helpers.py | 60 +-
 src/sparseml/optim/manager.py | 89 ++-
 src/sparseml/optim/modifier.py | 8 +
 src/sparseml/pytorch/optim/manager.py | 46 +-
 src/sparseml/pytorch/optim/modifier.py | 31 +
 src/sparseml/pytorch/optim/modifier_as.py | 9 +
 .../pytorch/optim/modifier_distillation.py | 184 +++--
 src/sparseml/pytorch/optim/modifier_lr.py | 9 +
 src/sparseml/pytorch/optim/modifier_params.py | 17 +-
 .../pytorch/optim/modifier_pruning.py | 16 +
 .../pytorch/optim/modifier_quantization.py | 10 +-
 .../pytorch/optim/modifier_regularizer.py | 10 +-
 .../pytorch/optim/modifier_thinning.py | 10 +-
 src/sparseml/pytorch/utils/__init__.py | 1 +
 src/sparseml/pytorch/utils/helpers.py | 73 +-
 src/sparseml/pytorch/utils/sparsification.py | 169 +++++
 src/sparseml/sparsification/__init__.py | 1 +
 src/sparseml/sparsification/modifier_epoch.py | 10 +
 src/sparseml/sparsification/modifier_lr.py | 15 +
 .../sparsification/modifier_params.py | 8 +
 .../sparsification/modifier_pruning.py | 15 +
 src/sparseml/sparsification/types.py | 40 +
 src/sparseml/tensorflow_v1/optim/manager.py | 14 +-
 src/sparseml/transformers/__init__.py | 4 +-
 .../transformers/{utils => }/export.py | 136 ++--
 .../{train => }/question_answering.py | 69 +-
 .../{train => sparsification}/__init__.py | 6 +-
 .../question_answering.py | 47 +-
 .../transformers/sparsification/trainer.py | 671 +++++++++++++++++
 .../{train => }/text_classification.py | 69 +-
 .../{train => }/token_classification.py | 71 +-
 .../transformers/train/language_modeling.py | 700 ------------------
 src/sparseml/transformers/utils/__init__.py | 9 +-
 src/sparseml/transformers/utils/helpers.py | 73 +-
 .../transformers/utils/language_modeling.py | 66 --
 src/sparseml/transformers/utils/model.py | 383
++++++++++ .../transformers/utils/text_classification.py | 66 -- .../utils/token_classification.py | 66 -- src/sparseml/transformers/utils/trainer.py | 333 --------- tests/sparseml/pytorch/optim/test_modifier.py | 9 +- .../optim/test_modifier_distillation.py | 92 ++- 43 files changed, 2059 insertions(+), 1721 deletions(-) create mode 100644 src/sparseml/pytorch/utils/sparsification.py create mode 100644 src/sparseml/sparsification/types.py rename src/sparseml/transformers/{utils => }/export.py (68%) rename src/sparseml/transformers/{train => }/question_answering.py (94%) rename src/sparseml/transformers/{train => sparsification}/__init__.py (79%) rename src/sparseml/transformers/{utils => sparsification}/question_answering.py (92%) create mode 100644 src/sparseml/transformers/sparsification/trainer.py rename src/sparseml/transformers/{train => }/text_classification.py (93%) rename src/sparseml/transformers/{train => }/token_classification.py (92%) delete mode 100644 src/sparseml/transformers/train/language_modeling.py delete mode 100644 src/sparseml/transformers/utils/language_modeling.py create mode 100644 src/sparseml/transformers/utils/model.py delete mode 100644 src/sparseml/transformers/utils/text_classification.py delete mode 100644 src/sparseml/transformers/utils/token_classification.py delete mode 100644 src/sparseml/transformers/utils/trainer.py diff --git a/setup.py b/setup.py index 15110848658..f3d2ec580f4 100644 --- a/setup.py +++ b/setup.py @@ -66,22 +66,22 @@ _dev_deps = [ "beautifulsoup4==4.9.3", - "black>=20.8b1", - "flake8>=3.8.3", - "isort>=5.7.0", + "black==21.5b2", + "flake8==3.9.2", + "isort==5.8.0", "m2r2~=0.2.7", "myst-parser~=0.14.0", - "rinohtype>=0.4.2", - "sphinx>=3.4.0", - "sphinx-copybutton>=0.3.0", - "sphinx-markdown-tables>=0.0.15", - "sphinx-multiversion==0.2.4", - "sphinx-pydantic>=0.1.0", - "sphinx-rtd-theme>=0.5.0", + "rinohtype~=0.4.2", + "sphinx~=3.5.0", + "sphinx-copybutton~=0.3.0", + "sphinx-markdown-tables~=0.0.15", + "sphinx-multiversion~=0.2.4", + "sphinx-pydantic~=0.1.0", + "sphinx-rtd-theme~=0.5.0", "wheel>=0.36.2", - "pytest>=6.0.0", - "pytest-mock>=3.6.1", - "flaky>=3.0.0", + "pytest~=6.2.0", + "pytest-mock~=3.6.0", + "flaky~=3.7.0", "sphinx-rtd-theme", ] @@ -112,25 +112,35 @@ def _setup_extras() -> Dict: } -_transformers_entry_point_template = ( - "sparseml.transformers.train.{task}=sparseml.transformers.train.{task}:main" -) - - def _setup_entry_points() -> Dict: - return { + entry_points = { "console_scripts": [ + # sparsification "sparseml.benchmark=sparseml.benchmark.info:_main", "sparseml.framework=sparseml.framework.info:_main", "sparseml.sparsification=sparseml.sparsification.info:_main", - _transformers_entry_point_template.format(task="question_answering"), - _transformers_entry_point_template.format(task="text_classification"), - _transformers_entry_point_template.format(task="token_classification"), - _transformers_entry_point_template.format(task="language_modeling"), - "sparseml.transformers.export_onnx=sparseml.transformers.utils.export:main", ] } + # transformers integration + for task in [ + "question_answering", + "text_classification", + "token_classification", + ]: + entry_points["console_scripts"].extend( + [ + f"sparseml.transformers.{task}=sparseml.transformers.{task}:main", + f"sparseml.transformers.train.{task}=sparseml.transformers.{task}:main", + ] + ) + + entry_points["console_scripts"].append( + "sparseml.transformers.export_onnx=sparseml.transformers.export:main" + ) + + return entry_points + def 
_setup_long_description() -> Tuple[str, str]: return open("README.md", "r", encoding="utf-8").read(), "text/markdown" diff --git a/src/sparseml/keras/optim/manager.py b/src/sparseml/keras/optim/manager.py index 56af5051318..f1d20761620 100644 --- a/src/sparseml/keras/optim/manager.py +++ b/src/sparseml/keras/optim/manager.py @@ -18,14 +18,14 @@ Also handles loading modifiers from yaml files """ -from typing import List, Union +from typing import Any, Dict, List, Optional, Union from tensorflow import Tensor from sparseml.keras.optim.modifier import Modifier, ScheduledModifier from sparseml.keras.utils.compat import keras from sparseml.keras.utils.logger import KerasLogger -from sparseml.optim import BaseManager, load_recipe_yaml_str +from sparseml.optim import BaseManager, load_recipe_yaml_str, parse_recipe_variables from sparsezoo.objects import Recipe @@ -41,7 +41,7 @@ class ScheduledModifierManager(BaseManager, Modifier): def from_yaml( file_path: Union[str, Recipe], add_modifiers: List[Modifier] = None, - **recipe_variables, + recipe_variables: Optional[Union[Dict[str, Any], str]] = None, ): """ Convenience function used to create the manager of multiple modifiers from a @@ -59,6 +59,7 @@ def from_yaml( with (i.e. num_epochs, init_lr) :return: ScheduledModifierManager() created from the recipe file """ + recipe_variables = parse_recipe_variables(recipe_variables) yaml_str = load_recipe_yaml_str(file_path, **recipe_variables) modifiers = Modifier.load_list(yaml_str) if add_modifiers: diff --git a/src/sparseml/optim/helpers.py b/src/sparseml/optim/helpers.py index 59960959baa..5274fa7de80 100644 --- a/src/sparseml/optim/helpers.py +++ b/src/sparseml/optim/helpers.py @@ -16,8 +16,10 @@ Helper functions for base Modifier and Manger utilities """ +import json import re -from typing import Any, Dict, Tuple, Union +from contextlib import suppress +from typing import Any, Dict, Optional, Tuple, Union import yaml @@ -32,6 +34,7 @@ "rewrite_recipe_yaml_string_with_classes", "update_recipe_variables", "evaluate_recipe_yaml_str_equations", + "parse_recipe_variables", ] @@ -137,6 +140,61 @@ def rewrite_recipe_yaml_string_with_classes(recipe_contianer: Any) -> str: return pattern.sub(r"!\g", updated_yaml_str) +def parse_recipe_variables( + recipe_variables: Optional[Union[Dict[str, Any], str]] = None +) -> Dict[str, Any]: + """ + Parse input recipe_variables into a dictionary that can be used to overload + variables at the root of a recipe. 
+ Supports dictionaries as well as parsing a string in either json or + csv key=value format + + :param recipe_variables: the recipe_variables string or dictionary to parse + for variables used with overloading recipes + :return: the parsed recipe variables + """ + if not recipe_variables: + return {} + + if isinstance(recipe_variables, Dict): + return recipe_variables + + if not isinstance(recipe_variables, str): + raise ValueError( + f"recipe_args must be a string for parsing, given {recipe_variables}" + ) + + # assume json first, try and parse + with suppress(Exception): + recipe_variables = json.loads(recipe_variables) + return recipe_variables + + # assume csv, and standardize to format key=val + orig_recipe_variables = recipe_variables + recipe_vars_str = recipe_variables.replace(":", "=") + recipe_variables = {} + for arg_val in recipe_vars_str.split(","): + vals = arg_val.split("=") + if len(vals) != 2: + raise ValueError( + "Improper key=val given in csv for recipe variables with value " + f"{arg_val} in {orig_recipe_variables}" + ) + key = vals[0].strip() + if any(char in key for char in ["{", "!", "=", "}"]): + raise ValueError( + "Improper key given in csv for recipe variables with value " + f"{key} in {orig_recipe_variables}" + ) + val = vals[1].strip() + with suppress(Exception): + # check if val should be a number, otherwise fall back on string + val = float(val) + recipe_variables[key] = val + + return recipe_variables + + def update_recipe_variables(recipe_yaml_str: str, variables: Dict[str, Any]) -> str: """ :param recipe_yaml_str: YAML string of a SparseML recipe diff --git a/src/sparseml/optim/manager.py b/src/sparseml/optim/manager.py index 8c205ba864e..28aa17f4dab 100644 --- a/src/sparseml/optim/manager.py +++ b/src/sparseml/optim/manager.py @@ -22,12 +22,8 @@ from functools import cmp_to_key from typing import List -from sparseml.optim.modifier import ( - BaseModifier, - BaseObject, - BaseScheduled, - ModifierProp, -) +from sparseml.optim.modifier import BaseModifier, BaseObject, ModifierProp +from sparseml.sparsification.types import SparsificationTypes from sparseml.utils import clean_path, create_parent_dirs @@ -42,7 +38,7 @@ class BaseManager(BaseObject): :param modifiers: the modifiers to wrap """ - def __init__(self, modifiers: List[BaseScheduled], **kwargs): + def __init__(self, modifiers: List[BaseModifier], **kwargs): super().__init__(**kwargs) # sort modifiers by when they start and end so that later modifiers # can overwrite in a deterministic order such as when initializing @@ -57,44 +53,88 @@ def __del__(self): def __str__(self) -> str: return "\n".join(self.to_string_lines()) + def __eq__(self, compare: object) -> bool: + return str(self) == str(compare) + @ModifierProp(serializable=False) - def modifiers(self) -> List[BaseScheduled]: + def modifiers(self) -> List[BaseModifier]: """ :return: list of all SparseML modifiers in the managed recipe """ return self._modifiers @ModifierProp(serializable=False) - def epoch_modifiers(self) -> List[BaseScheduled]: + def epoch_modifiers(self) -> List[BaseModifier]: """ :return: list of all SparseML modifiers in the managed recipe that modify the epoch range """ - return [mod for mod in self._modifiers if "Epoch" in str(type(mod))] + return [ + mod + for mod in self._modifiers + if SparsificationTypes.epoch in mod.sparsification_types + ] @ModifierProp(serializable=False) - def learning_rate_modifiers(self) -> List[BaseScheduled]: + def learning_rate_modifiers(self) -> List[BaseModifier]: """ :return: list of all 
SparseML modifiers in the managed recipe that modify the LearningRate schedule """ - return [mod for mod in self._modifiers if "LearningRate" in str(type(mod))] + return [ + mod + for mod in self._modifiers + if SparsificationTypes.learning_rate in mod.sparsification_types + ] @ModifierProp(serializable=False) - def pruning_modifiers(self) -> List[BaseScheduled]: + def pruning_modifiers(self) -> List[BaseModifier]: """ :return: list of all SparseML modifiers in the managed recipe that manage model sparsity """ - return [mod for mod in self._modifiers if "Pruning" in str(type(mod))] + return [ + mod + for mod in self._modifiers + if SparsificationTypes.pruning in mod.sparsification_types + ] @ModifierProp(serializable=False) - def quantization_modifiers(self) -> List[BaseScheduled]: + def quantization_modifiers(self) -> List[BaseModifier]: """ :return: list of all SparseML modifiers in the managed recipe that manage model quantization """ - return [mod for mod in self._modifiers if "Quantization" in str(type(mod))] + return [ + mod + for mod in self._modifiers + if SparsificationTypes.quantization in mod.sparsification_types + ] + + @ModifierProp(serializable=False) + def distillation_modifiers(self) -> List[BaseModifier]: + """ + :return: list of all SparseML modifiers in the managed recipe that manage + Distillation + """ + return [ + mod + for mod in self._modifiers + if SparsificationTypes.distillation in mod.sparsification_types + ] + + @ModifierProp(serializable=False) + def structured_modifiers(self) -> List[BaseModifier]: + """ + :return: list of all SparseML modifiers in the managed recipe that manage + structure changes to a model such as layer pruning, fitler pruning, + and quantization + """ + return [ + mod + for mod in self._modifiers + if SparsificationTypes.structured in mod.sparsification_types + ] @ModifierProp(serializable=False) def min_epochs(self) -> int: @@ -154,7 +194,7 @@ def to_string_lines(self) -> List[str]: return yaml_str_lines - def modifiers_to_string_lines(self, modifiers: List[BaseScheduled]) -> List[str]: + def modifiers_to_string_lines(self, modifiers: List[BaseModifier]) -> List[str]: """ :param modifiers: the modifiers to convert into string / yaml representation for within the manage @@ -176,3 +216,18 @@ def modifiers_to_string_lines(self, modifiers: List[BaseScheduled]) -> List[str] yaml_str_lines.append("") return yaml_str_lines + + def qat_active(self, epoch: float) -> bool: + """ + :param epoch: the epoch to check if quantization aware training will be + active during + :return: True if quantization aware training will be active at the start + of or within the given epoch, False otherwise + """ + quant_modifiers = self.quantization_modifiers + + return ( + min(mod.start_epoch for mod in quant_modifiers) < epoch + 1 + if quant_modifiers + else False + ) diff --git a/src/sparseml/optim/modifier.py b/src/sparseml/optim/modifier.py index 1de7a0a0af2..9a622c98e0a 100644 --- a/src/sparseml/optim/modifier.py +++ b/src/sparseml/optim/modifier.py @@ -25,6 +25,7 @@ import yaml from sparseml.optim.helpers import evaluate_recipe_yaml_str_equations +from sparseml.sparsification.types import SparsificationTypes from sparseml.utils import ALL_TOKEN, validate_str_iterable @@ -466,6 +467,13 @@ def __repr__(self): self.props(only_serializable=False, format_repr=True), ) + @ModifierProp(serializable=False) + def sparsification_types(self) -> List[SparsificationTypes]: + """ + :return: the sparsification types this modifier instance will apply + """ + return [] + 
@ModifierProp(serializable=True) def log_types(self) -> Union[None, str, List[str]]: """ diff --git a/src/sparseml/pytorch/optim/manager.py b/src/sparseml/pytorch/optim/manager.py index ec3fdc0294f..83d5a82f6b8 100644 --- a/src/sparseml/pytorch/optim/manager.py +++ b/src/sparseml/pytorch/optim/manager.py @@ -25,7 +25,7 @@ from torch.nn import Module from torch.optim.optimizer import Optimizer -from sparseml.optim import BaseManager, load_recipe_yaml_str +from sparseml.optim import BaseManager, load_recipe_yaml_str, parse_recipe_variables from sparseml.pytorch.optim.modifier import Modifier, ScheduledModifier from sparseml.pytorch.utils import BaseLogger, is_parallel_model from sparsezoo.objects import Recipe @@ -250,8 +250,8 @@ class ScheduledModifierManager(BaseManager, Modifier): @staticmethod def from_yaml( file_path: Union[str, Recipe], - add_modifiers: List[Modifier] = None, - **recipe_variables, + add_modifiers: Optional[List[Modifier]] = None, + recipe_variables: Optional[Union[Dict[str, Any], str]] = None, ): """ Convenience function used to create the manager of multiple modifiers from a @@ -266,10 +266,11 @@ def from_yaml( yaml str is also supported in place of a file path. :param add_modifiers: additional modifiers that should be added to the returned manager alongside the ones loaded from the recipe file - :param recipe_variables: additional variable values to override the recipe - with (i.e. num_epochs, init_lr) + :param recipe_variables: additional arguments to override any root variables + in the recipe with (i.e. num_epochs, init_lr) :return: ScheduledModifierManager() created from the recipe file """ + recipe_variables = parse_recipe_variables(recipe_variables) yaml_str = load_recipe_yaml_str(file_path, **recipe_variables) modifiers = Modifier.load_list(yaml_str) @@ -325,6 +326,32 @@ def load_state_dict(self, state_dict: Dict[str, Dict], strict: bool = True): modifiers_index[key].load_state_dict(val) + def apply_structure( + self, + module: Module, + epoch: float = 0.0, + loggers: Optional[List[BaseLogger]] = None, + finalize: bool = False, + **kwargs, + ): + """ + Initialize/apply the modifier for a given model/module at the given epoch + if the modifier affects the structure of the module such as + quantization, layer pruning, or filter pruning. + Calls into initialize(module, epoch, loggers, **kwargs) if structured. + + :param module: the PyTorch model/module to modify + :param epoch: the epoch to apply the modifier at, defaults to 0.0 (start) + :param loggers: Optional list of loggers to log the modification process to + :param finalize: True to invoke finalize after initialize, False otherwise. + Set finalize to True and epoch to math.inf for one shot application. + :param kwargs: Optional kwargs to support specific arguments + for individual modifiers (passed to initialize and finalize). 
+ """ + self._initialize_epoch = epoch + for mod in self._modifiers: + mod.apply_structure(module, epoch, loggers, finalize, **kwargs) + def initialize( self, module: Module, @@ -349,6 +376,10 @@ def initialize( self._initialize_epoch = epoch for mod in self._modifiers: + if mod.initialized: + # check in case modifier was initialized from apply_structure + continue + mod.initialize(module, epoch, loggers, **kwargs) def initialize_loggers(self, loggers: Union[None, List[BaseLogger]]): @@ -371,6 +402,7 @@ def modify( wrap_optim: Any = None, epoch: float = None, allow_parallel_module: bool = True, + **kwargs, ) -> RecipeManagerStepWrapper: """ Modify the given module and optimizer for training aware algorithms such as @@ -393,6 +425,8 @@ def modify( module.module. This is useful so a recipe may reference the base module parameters instead of the wrapped distributed ones. Set to True to not unwrap the distributed module. Default is True + :param kwargs: Key word arguments that are passed to the intialize call + if initilaize has not been called yet :return: A wrapped optimizer object. The wrapped object makes all the original properties for the wrapped object available so it can be used without any additional code changes. @@ -414,7 +448,7 @@ def modify( module = module.module # unwrap parallel module if not self.initialized: - self.initialize(module, epoch) + self.initialize(module, epoch, **kwargs) if wrap_optim is None: wrap_optim = optimizer diff --git a/src/sparseml/pytorch/optim/modifier.py b/src/sparseml/pytorch/optim/modifier.py index b4ad053f441..a03295b8218 100644 --- a/src/sparseml/pytorch/optim/modifier.py +++ b/src/sparseml/pytorch/optim/modifier.py @@ -34,6 +34,7 @@ ModifierYAML, ) from sparseml.pytorch.utils import BaseLogger +from sparseml.sparsification import SparsificationTypes from sparseml.utils import ALL_TOKEN, PYTORCH_FRAMEWORK @@ -174,6 +175,36 @@ def apply( if finalize: self.finalize(module, **kwargs) + def apply_structure( + self, + module: Module, + epoch: float = 0.0, + loggers: Optional[List[BaseLogger]] = None, + finalize: bool = False, + **kwargs, + ): + """ + Initialize/apply the modifier for a given model/module at the given epoch + if the modifier affects the structure of the module such as + quantization, layer pruning, or filter pruning. + Calls into initialize(module, epoch, loggers, **kwargs) if structured. + + :param module: the PyTorch model/module to modify + :param epoch: the epoch to apply the modifier at, defaults to 0.0 (start) + :param loggers: Optional list of loggers to log the modification process to + :param finalize: True to invoke finalize after initialize, False otherwise. + Set finalize to True and epoch to math.inf for one shot application. + :param kwargs: Optional kwargs to support specific arguments + for individual modifiers (passed to initialize and finalize). 
+ """ + if SparsificationTypes.structured not in self.sparsification_types: + return + + self.initialize(module, epoch, loggers, **kwargs) + + if finalize: + self.finalize(module, **kwargs) + def initialize( self, module: Module, diff --git a/src/sparseml/pytorch/optim/modifier_as.py b/src/sparseml/pytorch/optim/modifier_as.py index e3aae052a27..c3f92d607ec 100644 --- a/src/sparseml/pytorch/optim/modifier_as.py +++ b/src/sparseml/pytorch/optim/modifier_as.py @@ -23,9 +23,11 @@ from torch.nn import Module from torch.optim.optimizer import Optimizer +from sparseml.optim import BaseModifier from sparseml.pytorch.optim.modifier import ModifierProp, ScheduledModifier from sparseml.pytorch.optim.sensitivity_as import ASLayerTracker from sparseml.pytorch.utils import BaseLogger, get_layer, get_terminal_layers +from sparseml.sparsification import SparsificationTypes from sparseml.utils import ALL_TOKEN, convert_to_bool, validate_str_iterable @@ -92,6 +94,13 @@ def __init__( self.validate() + @BaseModifier.sparsification_types.getter + def sparsification_types(self) -> List[SparsificationTypes]: + """ + :return: the sparsification types this modifier instance will apply + """ + return [SparsificationTypes.activation_sparsity] + @ModifierProp() def layers(self) -> Union[str, List[str]]: """ diff --git a/src/sparseml/pytorch/optim/modifier_distillation.py b/src/sparseml/pytorch/optim/modifier_distillation.py index f4affac2086..dd3026405ba 100644 --- a/src/sparseml/pytorch/optim/modifier_distillation.py +++ b/src/sparseml/pytorch/optim/modifier_distillation.py @@ -27,9 +27,10 @@ from torch.nn import Module from torch.optim import Optimizer -from sparseml.optim import ModifierProp -from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier +from sparseml.optim import BaseModifier, ModifierProp +from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledUpdateModifier from sparseml.pytorch.utils import BaseLogger, device_of, tensors_module_forward +from sparseml.sparsification import SparsificationTypes __all__ = [ @@ -41,7 +42,7 @@ @PyTorchModifierYAML() -class DistillationModifier(ScheduledModifier): +class DistillationModifier(ScheduledUpdateModifier): """ Adds a knowledge distillation loss based on a teacher model during the loss_update phase of the SparseML lifecycle. A distillation_teacher @@ -63,9 +64,12 @@ class DistillationModifier(ScheduledModifier): Default is 0.5 :param temperature: temperature applied to teacher and student softmax for distillation - :param distill_output_keys: list of keys to of module outputs to use for - distillation if multiple outputs are present. No or empty list defaults + :param distill_output_keys: list of keys for the module outputs to use for + distillation if multiple outputs are present. None or empty list defaults to using all available outputs + :param teacher_input_keys: list of keys to filter the inputs by before + passing into the teacher. 
None or empty list defaults to using + all available inputs """ def __init__( @@ -75,6 +79,8 @@ def __init__( hardness: float = 0.5, temperature: float = 2.0, distill_output_keys: List[Any] = None, + teacher_input_keys: List[Any] = None, + update_frequency: float = -1.0, ): super().__init__( start_epoch=start_epoch, @@ -83,14 +89,18 @@ def __init__( ) self._hardness = hardness self._temperature = temperature - self._distill_output_keys = distill_output_keys or [] + self._distill_output_keys = distill_output_keys + self._teacher_input_keys = teacher_input_keys self._teacher = None self._distillation_enabled = False - self._track_student_hook = None - self._student_inputs = None # last forward inputs to student module - self._student_outputs = None # last forward outputs of student module - self._disable_distillation = False + + @BaseModifier.sparsification_types.getter + def sparsification_types(self) -> List[SparsificationTypes]: + """ + :return: the sparsification types this modifier instance will apply + """ + return [SparsificationTypes.distillation] @ModifierProp() def hardness(self) -> float: @@ -125,29 +135,47 @@ def temperature(self, value: float): self._temperature = value @ModifierProp() - def distill_output_keys(self) -> List[Any]: + def distill_output_keys(self) -> Optional[List[Any]]: """ - :return: list of keys to of module outputs to use for distillation - if multiple outputs are present. No or empty list defaults + :return: list of keys for the module outputs to use for + distillation if multiple outputs are present. None or empty list defaults to using all available outputs """ return self._distill_output_keys @distill_output_keys.setter - def distill_output_keys(self, value: List[Any]): + def distill_output_keys(self, value: Optional[List[Any]]): """ - :params value: list of keys to of module outputs to use for distillation - if multiple outputs are present. No or empty list defaults + :params value: list of keys for the module outputs to use for + distillation if multiple outputs are present. None or empty list defaults to using all available outputs """ self._distill_output_keys = value + @ModifierProp() + def teacher_input_keys(self) -> Optional[List[Any]]: + """ + :return: list of keys to filter the inputs by before + passing into the teacher. None or empty list defaults to using + all available inputs + """ + return self._teacher_input_keys + + @teacher_input_keys.setter + def teacher_input_keys(self, value: Optional[List[Any]]): + """ + :params value: list of keys to filter the inputs by before + passing into the teacher. 
None or empty list defaults to using + all available inputs + """ + self._teacher_input_keys = value + def initialize( self, module: Module, epoch: float = 0, loggers: Optional[List[BaseLogger]] = None, - distillation_teacher: Module = None, + distillation_teacher: Module = "disable", **kwargs, ): """ @@ -167,35 +195,30 @@ def initialize( """ super().initialize(module, epoch, loggers, **kwargs) - self._disable_distillation = distillation_teacher == "disable" - if distillation_teacher is not None: + if distillation_teacher == "disable": + _LOGGER.warning( + "distillation_teacher set to disable, disabling distillation modifier" + ) + self._distillation_enabled = False + elif distillation_teacher == "self": + self._distillation_enabled = True _LOGGER.info( - "Setting teacher module for distillation to distillation_teacher object" + "distillation_teacher set to self attention, " + "instantiating self distillation at start_epoch" ) + elif callable(distillation_teacher): self._teacher = distillation_teacher - - self._check_distillation_update(module, epoch, steps_per_epoch=0) - - def update( - self, module: Module, optimizer: Optimizer, epoch: float, steps_per_epoch: int - ): - """ - If start_pending(), sets a hook for tracking student module inputs and outputs - for distillation - If end_pending(), removes hook for distillation tracking - - :param module: module to modify - :param optimizer: optimizer to modify - :param epoch: current epoch and progress within the current epoch - :param steps_per_epoch: number of steps taken within each epoch - (calculate batch number using this and epoch) - """ - super().update(module, optimizer, epoch, steps_per_epoch) - self._check_distillation_update(module, epoch, steps_per_epoch) + self._distillation_enabled = True + _LOGGER.info("distillation modifier using distillation_teacher object") + else: + raise ValueError( + "unrecognized value for distillation_modifier given of " + f"{distillation_teacher}. 
" + "To disable set to 'disable' and for self attention set to 'self'" + ) def update_ready(self, epoch: float, steps_per_epoch: int) -> bool: """ - :param epoch: current epoch and progress within the current epoch :param steps_per_epoch: number of steps taken within each epoch (calculate batch number using this and epoch) @@ -204,17 +227,10 @@ def update_ready(self, epoch: float, steps_per_epoch: int) -> bool: if not self._initialized: raise RuntimeError("modifier must be initialized first") - if not self._enabled or self._disable_distillation: - return False - - pending = ( - self.start_pending(epoch, steps_per_epoch) - or self.end_pending(epoch, steps_per_epoch) - or (not self._distillation_enabled and self._is_distillation_epoch(epoch)) + return self._distillation_enabled and super().update_ready( + epoch, steps_per_epoch ) - return pending - def loss_update( self, loss: Tensor, @@ -223,11 +239,11 @@ def loss_update( epoch: float, steps_per_epoch: int, student_outputs: Union[Tensor, Dict, Iterable] = None, - teacher_inputs: Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]] = None, + student_inputs: Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]] = None, **kwargs, ) -> Tensor: """ - Updates the bass loss with the distillation loss + Updates the loss with the distillation loss :param loss: The calculated loss tensor :param module: module to modify @@ -241,19 +257,39 @@ def loss_update( loss, module, optimizer, epoch, steps_per_epoch, **kwargs ) - if not self._distillation_enabled or self._disable_distillation: + if not self.update_ready(epoch, steps_per_epoch): return loss - if student_outputs is None or teacher_inputs is None: + if student_outputs is None or student_inputs is None: raise ValueError( "Student outputs and teacher inputs are required for " "distillation loss update" ) + teacher_inputs = ( + student_inputs + if not self._teacher_input_keys + else {key: student_inputs[key] for key in self._teacher_input_keys} + ) + # copy to keep from updating student's inputs + teacher_inputs = deepcopy(teacher_inputs) + + if self._teacher == "self": + _LOGGER.info("Copying current models state for self distillation") + self._teacher = deepcopy(module) + # ensure that teacher model is in eval mode and on correct device self._teacher.eval() - target_device = device_of(teacher_inputs) - self._teacher.to(target_device) + teacher_device = next(self._teacher.parameters()).device + inputs_device = device_of(teacher_inputs) + + if teacher_device != inputs_device: + _LOGGER.info( + f"Teacher device {teacher_device} does not match " + f"inputs device {inputs_device}, moving teacher to correct device" + ) + self._teacher.to(device_of(teacher_inputs)) + with torch.no_grad(): teacher_outputs = tensors_module_forward( teacher_inputs, self._teacher, check_feat_lab_inp=False @@ -261,7 +297,8 @@ def loss_update( if type(student_outputs) != type(teacher_outputs): raise ValueError( - "Student and teacher models must have the same output type" + f"Student output type of {type(student_outputs)} must match " + f"teacher output type of {type(teacher_outputs)}" ) distill_losses = [] @@ -285,9 +322,14 @@ def loss_update( distillation_loss = ((1.0 - self._hardness) * loss) + ( self._hardness * teacher_loss ) - global_step = kwargs.get("global_step") - global_step = epoch * steps_per_epoch if global_step is None else global_step - _log_losses(self.loggers, global_step, loss, teacher_loss, distillation_loss) + _log_losses( + self.loggers, + round(epoch * steps_per_epoch), + loss, + teacher_loss, + 
distillation_loss, + ) + return distillation_loss def finalize( @@ -306,7 +348,7 @@ def finalize( """ super().finalize(module, reset_loggers, **kwargs) self._teacher = None - self._disable_student_hook() + self._distillation_enabled = False def _calc_distill_loss(self, student_val: Tensor, teacher_val: Tensor) -> Tensor: return ( @@ -318,28 +360,6 @@ def _calc_distill_loss(self, student_val: Tensor, teacher_val: Tensor) -> Tensor * (self._temperature ** 2) ) - def _check_distillation_update( - self, module: Module, epoch: float, steps_per_epoch: int - ): - if self._disable_distillation: - _LOGGER.info("Distillation disabled, using default loss") - return - if self.start_pending(epoch, steps_per_epoch) or ( - not self._distillation_enabled and self._is_distillation_epoch(epoch) - ): - if self._teacher is None: - _LOGGER.info( - "Using self distillation with copy of the module's current state" - ) - self._teacher = deepcopy(module) - self._distillation_enabled = True - - if self.end_pending(epoch, steps_per_epoch): - self._distillation_enabled = False - - def _is_distillation_epoch(self, epoch): - return self.start_epoch <= epoch < self.end_epoch - def _log_losses( loggers: List[BaseLogger], diff --git a/src/sparseml/pytorch/optim/modifier_lr.py b/src/sparseml/pytorch/optim/modifier_lr.py index bab244eaf06..7a603733a6a 100644 --- a/src/sparseml/pytorch/optim/modifier_lr.py +++ b/src/sparseml/pytorch/optim/modifier_lr.py @@ -30,6 +30,7 @@ ) from torch.optim.optimizer import Optimizer +from sparseml.optim import BaseModifier from sparseml.pytorch.optim.modifier import ( ModifierProp, PyTorchModifierYAML, @@ -45,6 +46,7 @@ from sparseml.sparsification import ( SetLearningRateModifier as BaseSetLearningRateModifier, ) +from sparseml.sparsification import SparsificationTypes from sparseml.utils import ALL_TOKEN, convert_to_bool @@ -298,6 +300,13 @@ def __init__( self._last_logged_epoch = None self.validate() + @BaseModifier.sparsification_types.getter + def sparsification_types(self) -> List[SparsificationTypes]: + """ + :return: the sparsification types this modifier instance will apply + """ + return [SparsificationTypes.learning_rate] + @ModifierProp() def lr_func(self) -> str: """ diff --git a/src/sparseml/pytorch/optim/modifier_params.py b/src/sparseml/pytorch/optim/modifier_params.py index ebbcf32ce29..ad7bb2d9058 100644 --- a/src/sparseml/pytorch/optim/modifier_params.py +++ b/src/sparseml/pytorch/optim/modifier_params.py @@ -23,13 +23,14 @@ from torch.nn import Module, Parameter from torch.optim.optimizer import Optimizer -from sparseml.optim import ModifierProp +from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import ( PyTorchModifierYAML, ScheduledModifier, ScheduledUpdateModifier, ) from sparseml.pytorch.utils import BaseLogger, get_named_layers_and_params_by_regex +from sparseml.sparsification import SparsificationTypes from sparseml.sparsification import ( TrainableParamsModifier as BaseTrainableParamsModifier, ) @@ -196,6 +197,13 @@ def __init__( self._params_strict = params_strict self._module_params = [] # type: List[Parameter] + @BaseModifier.sparsification_types.getter + def sparsification_types(self) -> List[SparsificationTypes]: + """ + :return: the sparsification types this modifier instance will apply + """ + return [SparsificationTypes.general] + @ModifierProp() def params(self) -> Union[str, List[str]]: """ @@ -384,6 +392,13 @@ def __init__( self.validate() + @BaseModifier.sparsification_types.getter + def 
sparsification_types(self) -> List[SparsificationTypes]: + """ + :return: the sparsification types this modifier instance will apply + """ + return [SparsificationTypes.general] + @ModifierProp() def params(self) -> Union[str, List[str]]: """ diff --git a/src/sparseml/pytorch/optim/modifier_pruning.py b/src/sparseml/pytorch/optim/modifier_pruning.py index 601a429ce6e..110978ccd05 100644 --- a/src/sparseml/pytorch/optim/modifier_pruning.py +++ b/src/sparseml/pytorch/optim/modifier_pruning.py @@ -25,6 +25,7 @@ from torch.nn import Module from torch.optim.optimizer import Optimizer +from sparseml.optim import BaseModifier from sparseml.pytorch.nn import Identity from sparseml.pytorch.optim.analyzer_pruning import ModulePruningAnalyzer from sparseml.pytorch.optim.mask_creator_pruning import ( @@ -54,6 +55,7 @@ ConstantPruningModifier as BaseConstantPruningModifier, ) from sparseml.sparsification import GMPruningModifier as BaseGMPruningModifier +from sparseml.sparsification import SparsificationTypes from sparseml.utils import ( ALL_PRUNABLE_TOKEN, ALL_TOKEN, @@ -1248,6 +1250,13 @@ def __init__( self._param_groups = param_groups or [] + @BaseGMPruningModifier.sparsification_types.getter + def sparsification_types(self) -> List[SparsificationTypes]: + """ + :return: the sparsification types this modifier instance will apply + """ + return [SparsificationTypes.pruning, SparsificationTypes.structured] + @ModifierProp() def param_groups(self) -> List[List[str]]: """ @@ -1618,6 +1627,13 @@ def __init__( self._last_logged_epoch = None self._last_logged_layers_replaced = None + @BaseModifier.sparsification_types.getter + def sparsification_types(self) -> List[SparsificationTypes]: + """ + :return: the sparsification types this modifier instance will apply + """ + return [SparsificationTypes.pruning, SparsificationTypes.structured] + @ModifierProp() def layers(self) -> Union[str, List[str]]: """ diff --git a/src/sparseml/pytorch/optim/modifier_quantization.py b/src/sparseml/pytorch/optim/modifier_quantization.py index 086ec12ef6d..ed69286beb6 100644 --- a/src/sparseml/pytorch/optim/modifier_quantization.py +++ b/src/sparseml/pytorch/optim/modifier_quantization.py @@ -32,7 +32,7 @@ torch_quantization = None torch_intrinsic = None -from sparseml.optim import ModifierProp +from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.utils import BaseLogger from sparseml.pytorch.utils.quantization import ( @@ -44,6 +44,7 @@ prepare_embeddings_qat, remove_activation_qat_by_layer_name, ) +from sparseml.sparsification import SparsificationTypes __all__ = [ @@ -153,6 +154,13 @@ def __init__( self._validate_params() + @BaseModifier.sparsification_types.getter + def sparsification_types(self) -> List[SparsificationTypes]: + """ + :return: the sparsification types this modifier instance will apply + """ + return [SparsificationTypes.quantization, SparsificationTypes.structured] + @ModifierProp() def submodules(self) -> Union[List[str], None]: """ diff --git a/src/sparseml/pytorch/optim/modifier_regularizer.py b/src/sparseml/pytorch/optim/modifier_regularizer.py index d8739054462..cac58d899a3 100644 --- a/src/sparseml/pytorch/optim/modifier_regularizer.py +++ b/src/sparseml/pytorch/optim/modifier_regularizer.py @@ -22,9 +22,10 @@ from torch.nn import Module from torch.optim import Optimizer -from sparseml.optim import ModifierProp +from sparseml.optim import BaseModifier, ModifierProp from 
sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier from sparseml.pytorch.utils import BaseLogger +from sparseml.sparsification import SparsificationTypes from sparseml.utils import ALL_TOKEN, convert_to_bool @@ -83,6 +84,13 @@ def __init__( self._constant_logging = convert_to_bool(constant_logging) self._update_since_last_log = False + @BaseModifier.sparsification_types.getter + def sparsification_types(self) -> List[SparsificationTypes]: + """ + :return: the sparsification types this modifier instance will apply + """ + return [SparsificationTypes.regularization] + @ModifierProp() def weight_decay(self) -> float: """ diff --git a/src/sparseml/pytorch/optim/modifier_thinning.py b/src/sparseml/pytorch/optim/modifier_thinning.py index 10407eac3c3..5e26836f58a 100644 --- a/src/sparseml/pytorch/optim/modifier_thinning.py +++ b/src/sparseml/pytorch/optim/modifier_thinning.py @@ -25,8 +25,9 @@ from torch.nn import Module, Parameter from torch.optim import Optimizer -from sparseml.optim import ModifierProp +from sparseml.optim import BaseModifier, ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier +from sparseml.sparsification import SparsificationTypes __all__ = [ @@ -95,6 +96,13 @@ def __init__( self._strict = strict self._validate() + @BaseModifier.sparsification_types.getter + def sparsification_types(self) -> List[SparsificationTypes]: + """ + :return: the sparsification types this modifier instance will apply + """ + return [SparsificationTypes.pruning, SparsificationTypes.structured] + def _validate(self): if self._structure_type not in ["filter", "channel"]: raise ValueError( diff --git a/src/sparseml/pytorch/utils/__init__.py b/src/sparseml/pytorch/utils/__init__.py index a9d6985a204..e9a8eb656c7 100644 --- a/src/sparseml/pytorch/utils/__init__.py +++ b/src/sparseml/pytorch/utils/__init__.py @@ -27,6 +27,7 @@ from .mfac_helpers import * from .model import * from .module import * +from .sparsification import * from .ssd_helpers import * from .yolo_helpers import * diff --git a/src/sparseml/pytorch/utils/helpers.py b/src/sparseml/pytorch/utils/helpers.py index 7a93832e973..ef052e4f4ca 100644 --- a/src/sparseml/pytorch/utils/helpers.py +++ b/src/sparseml/pytorch/utils/helpers.py @@ -27,15 +27,23 @@ import torch from torch import Tensor from torch.nn import Linear, Module, Parameter -from torch.nn.modules.conv import _ConvNd +from torch.nn.modules.conv import Conv2d, Conv3d, _ConvNd from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader try: + quant_err = None + from torch.nn.qat import Conv2d as QATConv2d + from torch.nn.qat import Conv3d as QATConv3d + from torch.nn.qat import Linear as QATLinear from torch.quantization import QuantWrapper -except Exception: +except Exception as _err: + quant_err = _err QuantWrapper = None + QATLinear = None + QATConv2d = None + QATConv3d = None from sparseml.utils import create_dirs, save_numpy @@ -64,6 +72,7 @@ "get_conv_layers", "get_linear_layers", "get_prunable_layers", + "get_quantizable_layers", "get_named_layers_and_params_by_regex", "any_str_or_regex_matches_param_name", "NamedLayerParam", @@ -751,13 +760,63 @@ def get_prunable_layers(module: Module) -> List[Tuple[str, Module]]: :return: a list containing the names and modules of the prunable layers (Linear, ConvNd) """ - layers = [] + return [ + (name, mod) + for (name, mod) in module.named_modules() + if ( + isinstance(mod, Linear) + or isinstance(mod, _ConvNd) + or (QATLinear and isinstance(mod, 
QATLinear)) + or (QATConv2d and isinstance(mod, QATConv2d)) + or (QATConv3d and isinstance(mod, QATConv3d)) + ) + ] + + +def get_quantizable_layers(module: Module) -> List[Tuple[str, Module]]: + """ + :param module: the module to get the quantizable layers from + :return: a list containing the names and modules of the quantizable layers + (Linear, Conv2d, Conv3d) + """ + if QATLinear is None: + raise ImportError( + "PyTorch version is not setup for Quantization. " + "Please install a QAT compatible version of PyTorch" + ) + + return [ + (name, mod) + for (name, mod) in module.named_modules() + if ( + isinstance(mod, Linear) + or isinstance(mod, Conv2d) + or isinstance(mod, Conv3d) + ) + ] - for name, mod in module.named_modules(): - if isinstance(mod, Linear) or isinstance(mod, _ConvNd): - layers.append((name, mod)) - return layers +def get_quantized_layers(module: Module) -> List[Tuple[str, Module]]: + """ + :param module: the module to get the quantized layers from + :return: a list containing the names and modules of the quantized layers + (Linear, Conv2d, Conv3d) + """ + if QATLinear is None: + raise ImportError( + "PyTorch version is not setup for Quantization. " + "Please install a QAT compatible version of PyTorch" + ) + + return [ + (name, mod) + for (name, mod) in module.named_modules() + if ( + (QATLinear and isinstance(mod, QATLinear)) + or (QATConv2d and isinstance(mod, QATConv2d)) + or (QATConv3d and isinstance(mod, QATConv3d)) + ) + ] def get_layer_param(param: str, layer: str, module: Module) -> Parameter: diff --git a/src/sparseml/pytorch/utils/sparsification.py b/src/sparseml/pytorch/utils/sparsification.py new file mode 100644 index 00000000000..d88a548ad12 --- /dev/null +++ b/src/sparseml/pytorch/utils/sparsification.py @@ -0,0 +1,169 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Helper functions for retrieving information related to model sparsification +""" + +import json +from typing import Dict + +import torch +from torch.nn import Module + +from sparseml.pytorch.utils.helpers import ( + get_prunable_layers, + get_quantizable_layers, + get_quantized_layers, + tensor_sparsity, +) + + +__all__ = ["ModuleSparsificationInfo"] + + +class ModuleSparsificationInfo: + """ + Helper class for providing information related to torch Module parameters + and the amount of sparsification applied. 
Includes information for pruning + and quantization + + :param module: torch Module to analyze + """ + + def __init__(self, module: Module): + self.module = module + self.trainable_params = list( + filter(lambda param: param.requires_grad, self.module.parameters()) + ) + + def __str__(self): + return json.dumps( + { + "params_summary": { + "total": self.params_total, + "sparse": self.params_sparse, + "sparsity_percent": self.params_sparse_percent, + "prunable": self.params_prunable_total, + "prunable_sparse": self.params_prunable_sparse, + "prunable_sparsity_percent": self.params_prunable_sparse_percent, + "quantizable": self.params_quantizable, + "quantized": self.params_quantized, + "quantized_percent": self.params_quantized_percent, + }, + "params_info": self.params_info, + } + ) + + @property + def params_total(self) -> int: + """ + :return: total number of trainable parameters in the model + """ + return sum(torch.numel(param) for param in self.trainable_params) + + @property + def params_sparse(self) -> int: + """ + :return: total number of sparse (0) trainable parameters in the model + """ + return sum( + round(tensor_sparsity(param).item() * torch.numel(param)) + for param in self.trainable_params + ) + + @property + def params_sparse_percent(self) -> float: + """ + :return: percent of sparsified parameters in the entire model + """ + return self.params_sparse / float(self.params_total) * 100 + + @property + def params_prunable_total(self) -> int: + """ + :return: total number of parameters across prunable layers + """ + return sum( + torch.numel(layer.weight) + for (name, layer) in get_prunable_layers(self.module) + ) + + @property + def params_prunable_sparse(self) -> int: + """ + :return: total number of sparse (0) parameters across prunable lauyers + """ + return sum( + round(tensor_sparsity(layer.weight).item() * torch.numel(layer.weight)) + for (name, layer) in get_prunable_layers(self.module) + ) + + @property + def params_prunable_sparse_percent(self) -> float: + """ + :return: percent of prunable parameters that have been pruned + """ + return self.params_prunable_sparse / float(self.params_prunable_total) * 100 + + @property + def params_quantizable(self) -> int: + """ + :return: number of parameters that are included in quantizable layers + """ + return sum( + torch.numel(layer.weight) + + ( + torch.numel(layer.bias) + if hasattr(layer, "bias") and layer.bias is not None + else 0 + ) + for (name, layer) in get_quantizable_layers(self.module) + ) + + @property + def params_quantized(self) -> int: + """ + :return: number of parameters across quantized layers + """ + return sum( + torch.numel(layer.weight) + + ( + torch.numel(layer.bias) + if hasattr(layer, "bias") and layer.bias is not None + else 0 + ) + for (name, layer) in get_quantized_layers(self.module) + ) + + @property + def params_quantized_percent(self) -> float: + """ + :return: percentage of parameters that have been quantized + """ + return self.params_quantized / float(self.params_quantizable) * 100 + + @property + def params_info(self) -> Dict[str, Dict]: + """ + :return: dict of parameter name to its sparsification information + """ + return { + f"{name}.weight": { + "numel": torch.numel(layer.weight), + "sparsity": tensor_sparsity(layer.weight).item(), + "quantized": hasattr(layer, "weight_fake_quant"), + } + for (name, layer) in get_prunable_layers(self.module) + } diff --git a/src/sparseml/sparsification/__init__.py b/src/sparseml/sparsification/__init__.py index a5905cb9331..488bceb07c2 100644 --- 
a/src/sparseml/sparsification/__init__.py +++ b/src/sparseml/sparsification/__init__.py @@ -29,3 +29,4 @@ from .oracle import * from .recipe_builder import * from .recipe_editor import * +from .types import * diff --git a/src/sparseml/sparsification/modifier_epoch.py b/src/sparseml/sparsification/modifier_epoch.py index 5dda02f8477..866f94a3bbd 100644 --- a/src/sparseml/sparsification/modifier_epoch.py +++ b/src/sparseml/sparsification/modifier_epoch.py @@ -17,7 +17,10 @@ model """ +from typing import List + from sparseml.optim.modifier import BaseModifier, BaseScheduled +from sparseml.sparsification.types import SparsificationTypes __all__ = ["EpochRangeModifier"] @@ -49,3 +52,10 @@ def __init__( super(EpochRangeModifier, self).__init__( start_epoch=start_epoch, end_epoch=end_epoch, **kwargs ) + + @BaseModifier.sparsification_types.getter + def sparsification_types(self) -> List[SparsificationTypes]: + """ + :return: the sparsification types this modifier instance will apply + """ + return [SparsificationTypes.general, SparsificationTypes.epoch] diff --git a/src/sparseml/sparsification/modifier_lr.py b/src/sparseml/sparsification/modifier_lr.py index f65d8eadc16..40a04bfd2db 100644 --- a/src/sparseml/sparsification/modifier_lr.py +++ b/src/sparseml/sparsification/modifier_lr.py @@ -23,6 +23,7 @@ BaseUpdate, ModifierProp, ) +from sparseml.sparsification.types import SparsificationTypes from sparseml.utils import ALL_TOKEN @@ -68,6 +69,13 @@ def __init__( self._learning_rate = learning_rate self.validate_learning_rate() + @BaseModifier.sparsification_types.getter + def sparsification_types(self) -> List[SparsificationTypes]: + """ + :return: the sparsification types this modifier instance will apply + """ + return [SparsificationTypes.learning_rate] + @ModifierProp() def learning_rate(self) -> float: """ @@ -149,6 +157,13 @@ def __init__( self._init_lr = init_lr self.validate_lr_info() + @BaseModifier.sparsification_types.getter + def sparsification_types(self) -> List[SparsificationTypes]: + """ + :return: the sparsification types this modifier instance will apply + """ + return [SparsificationTypes.learning_rate] + @ModifierProp() def lr_class(self) -> str: """ diff --git a/src/sparseml/sparsification/modifier_params.py b/src/sparseml/sparsification/modifier_params.py index 49b8cab2b20..85ca987c328 100644 --- a/src/sparseml/sparsification/modifier_params.py +++ b/src/sparseml/sparsification/modifier_params.py @@ -20,6 +20,7 @@ from typing import List, Union from sparseml.optim.modifier import BaseModifier, BaseScheduled, ModifierProp +from sparseml.sparsification.types import SparsificationTypes from sparseml.utils import convert_to_bool, validate_str_iterable @@ -74,6 +75,13 @@ def __init__( self._vars_to_trainable_orig = {} self.validate() + @BaseModifier.sparsification_types.getter + def sparsification_types(self) -> List[SparsificationTypes]: + """ + :return: the sparsification types this modifier instance will apply + """ + return [SparsificationTypes.general] + @ModifierProp() def params(self) -> Union[str, List[str]]: """ diff --git a/src/sparseml/sparsification/modifier_pruning.py b/src/sparseml/sparsification/modifier_pruning.py index 6a4dd0a16f8..c1a13c7c05d 100644 --- a/src/sparseml/sparsification/modifier_pruning.py +++ b/src/sparseml/sparsification/modifier_pruning.py @@ -26,6 +26,7 @@ BaseUpdate, ModifierProp, ) +from sparseml.sparsification.types import SparsificationTypes from sparseml.utils import ALL_TOKEN, convert_to_bool, validate_str_iterable @@ -74,6 +75,13 @@ def 
__init__( params, "{} for params".format(self.__class__.__name__) ) # type: List[str] + @BaseModifier.sparsification_types.getter + def sparsification_types(self) -> List[SparsificationTypes]: + """ + :return: the sparsification types this modifier instance will apply + """ + return [SparsificationTypes.pruning] + @ModifierProp() def params(self) -> Union[str, List[str]]: """ @@ -175,6 +183,13 @@ def __init__( self.validate() + @BaseModifier.sparsification_types.getter + def sparsification_types(self) -> List[SparsificationTypes]: + """ + :return: the sparsification types this modifier instance will apply + """ + return [SparsificationTypes.pruning] + @ModifierProp() def params(self) -> Union[str, List[str]]: """ diff --git a/src/sparseml/sparsification/types.py b/src/sparseml/sparsification/types.py new file mode 100644 index 00000000000..78d2d153341 --- /dev/null +++ b/src/sparseml/sparsification/types.py @@ -0,0 +1,40 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Base classes and implementations for types of sparsification algorithms. +""" + +from enum import Enum + + +__all__ = ["SparsificationTypes"] + + +class SparsificationTypes(Enum): + """ + SparsificationTypes to give context to what a modifier or other parts of the + system are and can do when applied to a model for sparsification. + """ + + general = "general" + epoch = "epoch" + learning_rate = "learning_rate" + activation_sparsity = "activation_sparsity" + pruning = "pruning" + quantization = "quantization" + distillation = "distillation" + regularization = "regularization" + structured = "structured" diff --git a/src/sparseml/tensorflow_v1/optim/manager.py b/src/sparseml/tensorflow_v1/optim/manager.py index fa2d00b7b4e..98f098e8a42 100644 --- a/src/sparseml/tensorflow_v1/optim/manager.py +++ b/src/sparseml/tensorflow_v1/optim/manager.py @@ -19,9 +19,14 @@ """ import itertools -from typing import Any, Callable, Dict, List, Tuple, Union - -from sparseml.optim import BaseManager, BaseScheduled, load_recipe_yaml_str +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from sparseml.optim import ( + BaseManager, + BaseScheduled, + load_recipe_yaml_str, + parse_recipe_variables, +) from sparseml.tensorflow_v1.optim.modifier import NM_RECAL, Modifier, ScheduledModifier from sparseml.tensorflow_v1.utils import tf_compat from sparsezoo.objects import Recipe @@ -77,7 +82,7 @@ class ScheduledModifierManager(BaseManager, Modifier): def from_yaml( file_path: Union[str, Recipe], add_modifiers: List[Modifier] = None, - **recipe_variables, + recipe_variables: Optional[Union[Dict[str, Any], str]] = None, ): """ Convenience function used to create the manager of multiple modifiers from a @@ -95,6 +100,7 @@ def from_yaml( with (i.e. 
num_epochs, init_lr) :return: ScheduledModifierManager() created from the recipe file """ + recipe_variables = parse_recipe_variables(recipe_variables) yaml_str = load_recipe_yaml_str(file_path, **recipe_variables) modifiers = Modifier.load_list(yaml_str) if add_modifiers: diff --git a/src/sparseml/transformers/__init__.py b/src/sparseml/transformers/__init__.py index 7d69924ca71..6a887388799 100644 --- a/src/sparseml/transformers/__init__.py +++ b/src/sparseml/transformers/__init__.py @@ -112,6 +112,4 @@ def _check_transformers_install(): _check_transformers_install() -from .utils.export import * -from .utils.helpers import * -from .utils.trainer import * +from .export import * diff --git a/src/sparseml/transformers/utils/export.py b/src/sparseml/transformers/export.py similarity index 68% rename from src/sparseml/transformers/utils/export.py rename to src/sparseml/transformers/export.py index 1c6cea2befb..2bc63ef6873 100644 --- a/src/sparseml/transformers/utils/export.py +++ b/src/sparseml/transformers/export.py @@ -56,46 +56,60 @@ import argparse import logging +import math import os -from typing import Optional - -import torch -from transformers import ( - AutoConfig, - AutoModelForMaskedLM, - AutoModelForQuestionAnswering, - AutoModelForSequenceClassification, - AutoModelForTokenClassification, - AutoTokenizer, -) -from transformers.file_utils import WEIGHTS_NAME +from typing import Any, Optional + +from torch.nn import Module +from transformers import AutoConfig, AutoTokenizer from transformers.tokenization_utils_base import PaddingStrategy -from sparseml.pytorch.optim import ScheduledModifierManager from sparseml.pytorch.utils import export_onnx -from sparseml.transformers.utils.helpers import RECIPE_NAME +from sparseml.transformers.sparsification import Trainer +from sparseml.transformers.utils import SparseAutoModel __all__ = ["export_transformer_to_onnx"] _LOGGER = logging.getLogger(__name__) -_TASK_TO_CONSTRUCTOR = { - # language modeling - "mlm": AutoModelForMaskedLM, - "masked-language-modeling": AutoModelForMaskedLM, - # question answering - "qa": AutoModelForQuestionAnswering, - "question-answering": AutoModelForQuestionAnswering, - # GLUE - "glue": AutoModelForSequenceClassification, - "sequence-classification": AutoModelForSequenceClassification, - "sentiment-analysis": AutoModelForSequenceClassification, - "text-classification": AutoModelForSequenceClassification, - # token classification - "ner": AutoModelForTokenClassification, - "token-classification": AutoModelForTokenClassification, -} + + +def _load_task_model(task: str, model_path: str, config: Any) -> Module: + if task == "masked-language-modeling" or task == "mlm": + return SparseAutoModel.masked_language_modeling_from_pretrained( + model_name_or_path=model_path, + config=config, + model_type="model", + ) + + if task == "question-answering" or task == "qa": + return SparseAutoModel.question_answering_from_pretrained( + model_name_or_path=model_path, + config=config, + model_type="model", + ) + + if ( + task == "sequence-classification" + or task == "glue" + or task == "sentiment-analysis" + or task == "text-classification" + ): + return SparseAutoModel.text_classification_from_pretrained( + model_name_or_path=model_path, + config=config, + model_type="model", + ) + + if task == "token-classification" or task == "ner": + return SparseAutoModel.token_classification_from_pretrained( + model_name_or_path=model_path, + config=config, + model_type="model", + ) + + raise ValueError(f"unrecognized task given of {task}") 
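A minimal usage sketch for the relocated export entry point (the model directory name below is a hypothetical example; it should contain the trained config, tokenizer, weights, and any recipes saved with the model):

    from sparseml.transformers.export import export_transformer_to_onnx

    onnx_path = export_transformer_to_onnx(
        task="question-answering",  # any of the task aliases handled by _load_task_model
        model_path="./models/bert_qa_pruned",  # hypothetical trained model directory
    )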
def export_transformer_to_onnx( @@ -123,50 +137,57 @@ def export_transformer_to_onnx( pipeline, it will look only for 'model.onnx' :return: path to the exported ONNX file """ - task = "-".join(task.lower().split("_")) - if task not in _TASK_TO_CONSTRUCTOR: - raise ValueError( - f"task {task} unsupported for export_transformer_to_onnx. Supported " - f"tasks include {list(_TASK_TO_CONSTRUCTOR.keys())}" - ) - auto_model_constructor = _TASK_TO_CONSTRUCTOR[task] + task = task.replace("_", "-").replace(" ", "-") - if not os.path.isdir(model_path): + if not os.path.exists(model_path) or not os.path.isdir(model_path): raise ValueError( "model_path must be a directory that contains the trained transformer " - f"files. {model_path} is not a directory" + f"files. {model_path} is not a directory or does not exist" ) - # load config and tokenizer + _LOGGER.info(f"Attempting onnx export for model at {model_path} for task {task}") config_args = {"finetuning_task": finetuning_task} if finetuning_task else {} - config = AutoConfig.from_pretrained(model_path, **config_args) + config = AutoConfig.from_pretrained( + model_path, + **config_args, + ) tokenizer = AutoTokenizer.from_pretrained( model_path, model_max_length=sequence_length ) - - # load model - model = auto_model_constructor.from_pretrained( - model_path, - from_tf=False, - config=config, + model = _load_task_model(task, model_path, config) + _LOGGER.info(f"loaded model, config, and tokenizer from {model_path}") + + trainer = Trainer( + model=model, + model_state_path=model_path, + recipe=None, + recipe_args=None, + teacher=None, ) + applied = trainer.apply_manager(epoch=math.inf, checkpoint=None) - # apply recipe if exists before loading model weights - recipe_path = os.path.join(model_path, RECIPE_NAME) - if os.path.isfile(recipe_path): - ScheduledModifierManager.from_yaml(recipe_path).apply(model) + if not applied: + _LOGGER.warning( + f"No recipes were applied for {model_path}, " + "check to make sure recipe(s) are stored in the model_path" + ) else: - _LOGGER.warning(f"recipe not found under {recipe_path}") - - # load weights - load_kwargs = {} if torch.cuda.is_available() else {"map_location": "cpu"} - state_dict = torch.load(os.path.join(model_path, WEIGHTS_NAME), **load_kwargs) - model.load_state_dict(state_dict) + trainer.finalize_manager() + total_recipes = (1 if trainer.manager else 0) + len(trainer.arch_managers) + _LOGGER.info(f"Applied {total_recipes} total recipes the model at {model_path}") # create fake model input inputs = tokenizer( "", return_tensors="pt", padding=PaddingStrategy.MAX_LENGTH.value ).data # Dict[Tensor] + inputs_shapes = { + key: ( + f"{val.dtype if hasattr(val, 'dtype') else 'unknown'}: " + f"{list(val.shape) if hasattr(val, 'shape') else 'unknown'}" + ) + for key, val in inputs.items() + } + _LOGGER.info(f"Created sample inputs for the ONNX export process: {inputs_shapes}") # run export onnx_file_path = os.path.join(model_path, onnx_file_name) @@ -176,6 +197,7 @@ def export_transformer_to_onnx( onnx_file_path, convert_qat=convert_qat, ) + _LOGGER.info(f"ONNX exported to {onnx_file_path}") return onnx_file_path diff --git a/src/sparseml/transformers/train/question_answering.py b/src/sparseml/transformers/question_answering.py similarity index 94% rename from src/sparseml/transformers/train/question_answering.py rename to src/sparseml/transformers/question_answering.py index 3109aff900d..45c6ab14fb7 100644 --- a/src/sparseml/transformers/train/question_answering.py +++ 
b/src/sparseml/transformers/question_answering.py @@ -30,12 +30,10 @@ from dataclasses import dataclass, field from typing import Optional -import numpy import transformers from datasets import load_dataset, load_metric from transformers import ( AutoConfig, - AutoModelForQuestionAnswering, AutoTokenizer, DataCollatorWithPadding, EvalPrediction, @@ -48,12 +46,11 @@ from transformers.trainer_utils import get_last_checkpoint from transformers.utils import check_min_version -from sparseml.transformers.utils import ( - SparseMLQATrainer, - load_recipe, +from sparseml.transformers.sparsification import ( + QuestionAnsweringTrainer, postprocess_qa_predictions, - preprocess_state_dict, ) +from sparseml.transformers.utils import SparseAutoModel # Will error if the minimal version of Transformers is not installed @@ -431,35 +428,21 @@ def main(): revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, ) - - # Load and preprocess the state dict if the model existed (in this case we - # continue to train or evaluate the model). The preprocessing step is to - # restore names of parameters changed by QAT process. - state_dict = preprocess_state_dict(model_args.model_name_or_path) - - model = AutoModelForQuestionAnswering.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - state_dict=state_dict, + model, teacher = SparseAutoModel.question_answering_from_pretrained_distil( + model_name_or_path=model_args.model_name_or_path, + model_kwargs={ + "config": config, + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "use_auth_token": True if model_args.use_auth_token else None, + }, + teacher_name_or_path=model_args.distill_teacher, + teacher_kwargs={ + "cache_dir": model_args.cache_dir, + "use_auth_token": True if model_args.use_auth_token else None, + }, ) - teacher_model = None - if model_args.distill_teacher is not None: - teacher_model = AutoModelForQuestionAnswering.from_pretrained( - model_args.distill_teacher, - from_tf=bool(".ckpt" in model_args.distill_teacher), - cache_dir=model_args.cache_dir, - ) - teacher_model_parameters = filter( - lambda p: p.requires_grad, teacher_model.parameters() - ) - params = sum([numpy.prod(p.size()) for p in teacher_model_parameters]) - _LOGGER.info("Teacher Model has %s parameters", params) - # Tokenizer check: this script requires a fast tokenizer. 
if not isinstance(tokenizer, PreTrainedTokenizerFast): raise ValueError( @@ -738,17 +721,13 @@ def post_processing_function(examples, features, predictions, stage="eval"): def compute_metrics(p: EvalPrediction): return metric.compute(predictions=p.predictions, references=p.label_ids) - # Load possible existing recipe and new one passed in through command argument - existing_recipe = load_recipe(model_args.model_name_or_path) - new_recipe = data_args.recipe - # Initialize our Trainer - trainer = SparseMLQATrainer( - model_args.model_name_or_path, - recipe=new_recipe, - checkpoint_recipes=[existing_recipe], - teacher=teacher_model, + trainer = QuestionAnsweringTrainer( model=model, + model_state_path=model_args.model_name_or_path, + recipe=data_args.recipe, + recipe_args=data_args.recipe_args, + teacher=teacher, args=training_args, train_dataset=train_dataset if training_args.do_train else None, eval_dataset=eval_dataset if training_args.do_eval else None, @@ -757,14 +736,8 @@ def compute_metrics(p: EvalPrediction): data_collator=data_collator, post_process_function=post_processing_function, compute_metrics=compute_metrics, - recipe_args=data_args.recipe_args, ) - # Apply recipes to the model. This is necessary given that - # sparsification methods such as QAT modified the model graph with their own - # learnable parameters. They are also restored/loaded to the model - trainer.apply_recipes() - # Training if training_args.do_train: checkpoint = None diff --git a/src/sparseml/transformers/train/__init__.py b/src/sparseml/transformers/sparsification/__init__.py similarity index 79% rename from src/sparseml/transformers/train/__init__.py rename to src/sparseml/transformers/sparsification/__init__.py index ff57870edbe..61a91e00a04 100644 --- a/src/sparseml/transformers/train/__init__.py +++ b/src/sparseml/transformers/sparsification/__init__.py @@ -13,7 +13,11 @@ # limitations under the License. 
""" -Scripts for training various transformers NLP tasks +Objects, classes, and methods for applying sparsification algorithms to +Hugging Face transformers flows """ # flake8: noqa + +from .question_answering import * +from .trainer import * diff --git a/src/sparseml/transformers/utils/question_answering.py b/src/sparseml/transformers/sparsification/question_answering.py similarity index 92% rename from src/sparseml/transformers/utils/question_answering.py rename to src/sparseml/transformers/sparsification/question_answering.py index 3921045af34..c13a0037978 100644 --- a/src/sparseml/transformers/utils/question_answering.py +++ b/src/sparseml/transformers/sparsification/question_answering.py @@ -24,15 +24,15 @@ import json import logging import os -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import numpy as np -import torch +from torch.nn import Module from tqdm.auto import tqdm from transformers import Trainer, is_torch_tpu_available from transformers.trainer_utils import PredictionOutput -from sparseml.transformers.utils.trainer import SparseMLTrainer +from sparseml.transformers.sparsification.trainer import TrainerInterface if is_torch_tpu_available(): @@ -41,7 +41,7 @@ __all__ = [ - "SparseMLQATrainer", + "QuestionAnsweringTrainer", "postprocess_qa_predictions", ] @@ -144,41 +144,38 @@ def predict(self, predict_dataset, predict_examples, ignore_keys=None): ) -class SparseMLQATrainer(SparseMLTrainer, _QuestionAnsweringTrainer): +class QuestionAnsweringTrainer(TrainerInterface, _QuestionAnsweringTrainer): """ Trainer for running sparsification recipes with Question Answering training - :param model_name_or_path: path to model directory to be trained - :param recipe: path to recipe for model sparsification - :param checkpoint_recipes: list of paths to recipes used to train the - starting checkpoint for this training run. Will be applied to the model - on call to `apply_recipes` so that model state can be reproduced for - weight loading - :param teacher: teacher model for distillation. Default is None - :param recipe_args: Dictionary of recipe variables to override or json - loadable string of those args. Default is None - :param args: arguments passed into parent class + :param model: the model to use with the trainer and apply sparsification to + :param model_state_path: the state path to the model, + used to load config and tokenizer settings + :param recipe: the recipe, if any, to apply to the modle and training + process + :param recipe_args: A json string, csv key=value string, or dictionary containing + arguments to override the root arguments within the recipe such as + learning rate or num epochs + :param teacher: teacher model for distillation. 
Set to 'self' to distill
+        from the loaded model or 'disable' to turn off distillation
     :param kwargs: key word arguments passed to the parent class
     """
 
     def __init__(
         self,
-        model_name_or_path: str,
+        model: Module,
+        model_state_path: str,
         recipe: str,
-        checkpoint_recipes: Union[str, List[str]] = None,
-        teacher: Optional[torch.nn.Module] = None,
-        recipe_args: Union[Dict[str, Any], str] = None,
-        *args,
+        recipe_args: Optional[Union[Dict[str, Any], str]] = None,
+        teacher: Optional[Module] = None,
         **kwargs,
     ):
         super().__init__(
-            model_name_or_path=model_name_or_path,
+            model=model,
+            model_state_path=model_state_path,
             recipe=recipe,
-            checkpoint_recipes=checkpoint_recipes,
-            teacher=teacher,
             recipe_args=recipe_args,
-            teacher_input_keys=["input_ids", "token_type_ids", "attention_mask"],
-            *args,
+            teacher=teacher,
             **kwargs,
         )
diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py
new file mode 100644
index 00000000000..d912c9808c5
--- /dev/null
+++ b/src/sparseml/transformers/sparsification/trainer.py
@@ -0,0 +1,671 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+SparseML transformers trainer classes and interfaces to be plugged in with existing
+or similar HF trainer flows
+"""
+
+
+import glob
+import logging
+import math
+import os
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch.nn import Module
+from transformers import Trainer as TransformersTrainer
+from transformers import TrainerCallback, TrainerControl, TrainingArguments
+from transformers.file_utils import WEIGHTS_NAME
+from transformers.trainer_callback import TrainerState
+from transformers.trainer_utils import get_last_checkpoint
+
+from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer
+from sparseml.pytorch.utils import ModuleSparsificationInfo, WANDBLogger
+from sparseml.transformers.utils import SparseAutoModel
+from sparseml.transformers.utils.helpers import RECIPE_REGEX, RECIPE_TEMPLATE
+
+
+__all__ = [
+    "RecipeManagerTrainerInterface",
+    "TrainerInterface",
+    "Trainer",
+    "DisableHalfPrecisionCallback",
+]
+
+
+_LOGGER = logging.getLogger(__name__)
+TRAINER_STATE_NAME = "trainer_state.json"
+
+
+class RecipeManagerTrainerInterface:
+    """
+    Training base interface for running sparsification recipes with transformers flows.
+    Defines its own lifecycle that is compatible with transformers flows.
+    Can additionally be used outside of transformers flows provided
+    they match reasonably closely.
+
+    Should be instantiated with multi-inheritance with a custom trainer class.
+    RecipeManagerTrainerInterface must be provided
+    before Trainer for proper class dependency.
+    i.e. class MyCustomTrainer(RecipeManagerTrainerInterface, Trainer)
+
+    Expected lifecycle:
+    1. apply_manager
+    2. create_optimizer (only for training)
+    3. create_scheduler (only for training)
+    4.
compute_loss (only for training, called before each step) + 5. save_model (only for training) + 6. finalize_manager + + :param model: the model to use with the trainer and apply sparsification to + :param model_state_path: the state path to the model, + used to load config and tokenizer settings + :param recipe: the recipe, if any, to apply to the modle and training + process + :param recipe_args: A json string, csv key=value string, or dictionary containing + arguments to override the root arguments within the recipe such as + learning rate or num epochs + :param teacher: teacher model for distillation. Set to 'self' to distill + from the loaded model or 'disable' to turn of distillation + :param kwargs: key word arguments passed to the parent class + """ + + def __init__( + self, + model: Module, + model_state_path: str, + recipe: Optional[str], + recipe_args: Optional[Union[Dict[str, Any], str]] = None, + teacher: Optional[Union[Module, str]] = None, + **kwargs, + ): + # instantiate necessary state, like managers, so we can override args + self.model = model + self.model_state_path = str(model_state_path) + self.recipe = recipe + self.recipe_args = recipe_args + self.teacher = teacher + + report_to = ( + "" + if "args" not in kwargs + or not kwargs["args"] + or not kwargs["args"].report_to + else kwargs["args"].report_to + ) + self.manager_loggers = [WANDBLogger()] if "wandb" in report_to else None + + # remove arch_managers once recipe stages are supported + self.manager, self.arch_managers = self._setup_manager(kwargs) + self.manager_applied = False + self.manager_initialized = False + self.manager_finalized = False + self.manager_steps_per_epoch = 0 + + super().__init__(model=model, **kwargs) + self.criterion = torch.nn.CrossEntropyLoss() + self.callback_disable_fp16 = DisableHalfPrecisionCallback(self) + self.callback_handler.add_callback(self.callback_disable_fp16) + + def apply_manager(self, epoch: float, checkpoint: Optional[str]) -> bool: + """ + Apply the recipe(s) to the model and training/validation process. + + :param epoch: the training epoch to apply the recipe(s) at. + If loading after training, set epoch=math.inf + :param checkpoint: the optional checkpoint to use to reload model state + from after the model's architecture has been modified. 
+ If not supplied, falls back to self.model_state_path + :return: True if recipes were applied, Flase otherwise + """ + if (not self.arch_managers and self.manager is None) or self.manager_applied: + return False + + orig_state_dict = self.model.state_dict() + + # apply architecture changes to prep for reload of weights to handle + # things like layer dropping and quantization which changes param names + if self.arch_managers: + for arch_manager in self.arch_managers: + arch_manager.apply_structure(self.model, epoch=math.inf, finalize=True) + _LOGGER.info( + f"Applied structure from {len(self.arch_managers)} " + "SparseML recipes to model and finalized " + "(recipes saved with model_path)" + ) + + if self.manager is not None: + self.manager.apply_structure(self.model, epoch=epoch) + _LOGGER.info( + "Applied structure from SparseML recipe argument to model at " + f"epoch {epoch}" + ) + + # reload the state dict for the model now that architecture matches expected + load_path = checkpoint or self.model_state_path + self._reload_model_state(load_path, orig_state_dict) + self.manager_applied = True + _LOGGER.info( + "Reloaded model state after SparseML recipe structure modifications " + f"from {load_path}" + ) + + return True + + def finalize_manager(self) -> bool: + """ + Finalize the current recipes to wrap up any held state. + + :return: True if recipes were finalized, False otherwise + """ + if ( + self.manager is None + or not self.manager_initialized + or self.manager_finalized + ): + return False + + self.manager.finalize(self.model) + self.manager_finalized = True + _LOGGER.info("Finalized SparseML recipe argument applied to the model") + + return True + + def create_optimizer(self): + """ + Override the optimizer to apply and update the recipe while training. + create_optimizer must exist in the parent class and should set + self.optimizer to the optimizer state and optionally set self.scaler + if using amp. + """ + self._check_super_defined("create_optimizer") + super().create_optimizer() + + if not self.manager: + return + + total_batch_size = ( + self.args.per_device_train_batch_size + * self.args._n_gpu + * self.args.gradient_accumulation_steps + ) + self.manager_steps_per_epoch = math.ceil( + len(self.train_dataset) / total_batch_size + ) + + if hasattr(self, "scaler"): + wrap_optim_key = "scaler" + self.scaler = self.manager.modify( + self.model, + self.optimizer, + steps_per_epoch=self.manager_steps_per_epoch, + allow_parallel_module=False, + wrap_optim=self.scaler, + loggers=self.manager_loggers, + distillation_teacher=self.teacher, + ) + else: + wrap_optim_key = "optimizer" + self.optimizer = ScheduledOptimizer( + self.optimizer, + self.model, + self.manager, + steps_per_epoch=self.manager_steps_per_epoch, + loggers=self.manager_loggers, + ) + if not self.manager.initialized: + self.manager.initialize( + self.model, + loggers=self.manager_loggers, + distillation_teacher=self.teacher, + ) + self.manager_initialized = True + _LOGGER.info( + f"Modified the {wrap_optim_key} from the recipe for training with " + f"total_batch_size: {total_batch_size} and " + f"steps_per_epoch: {self.manager_steps_per_epoch}" + ) + + def create_scheduler(self, num_training_steps: int): + """ + Create an LR scheduler to work with the applied recipes. + If the recipe specifies LR modifiers, then will set lr_scheduler + to a placeholder lr scheduler. + Expects create_scheduler to be defined in the super class. + Additionally expects self.lr_scheduler argument to be available. 
+
+        :param num_training_steps: the total number of training steps
+        """
+        self._check_super_defined("create_scheduler")
+
+        if (
+            self.lr_scheduler is not None
+            or self.manager is None
+            or not self.manager.learning_rate_modifiers
+        ):
+            super().create_scheduler(num_training_steps)
+            return
+
+        # allow SparseML to manage LR and set a dummy scheduler
+        self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+            self.optimizer, lambda _: 1.0, -1
+        )
+        _LOGGER.warning("Overrode the lr_scheduler from SparseML recipe")
+
+    def compute_loss(
+        self, model: Module, inputs: Dict[str, Any], return_outputs: bool = False
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Any]]:
+        """
+        Override for the compute_loss to factor in distillation modifiers.
+        If distillation modifiers are present in the recipe, then will
+        add the distillation loss to the normal loss function.
+        Expects compute_loss to be defined in the super class.
+
+        :param model: the model to compute the loss for
+        :param inputs: the inputs to pass through the model for calculating the loss
+        :param return_outputs: True to return the outputs with the loss,
+            False otherwise
+        :return: the resulting loss if not return_outputs, otherwise a tuple
+            containing the loss and the model's outputs
+        """
+        self._check_super_defined("compute_loss")
+
+        if (
+            self.manager is None
+            or not self.manager.initialized
+            or not self.manager.enabled
+            or not self.manager.distillation_modifiers
+        ):
+            return super().compute_loss(model, inputs, return_outputs=return_outputs)
+
+        student_outputs = model(**inputs)
+        loss = student_outputs["loss"]
+        loss = self.manager.loss_update(
+            loss,
+            model,
+            self.optimizer,
+            self.state.epoch,
+            self.manager_steps_per_epoch,
+            student_outputs=student_outputs,
+            student_inputs=inputs,
+        )
+
+        return (loss, student_outputs) if return_outputs else loss
+
+    def save_model(self, output_dir: Optional[str] = None):
+        """
+        Override of the save_model function and expects it to exist in the parent.
+        Calls into super() to save the model and additionally saves any recipes
+        that were used with the model within the model folder.
+        Modifiers that change the model architecture will also be saved.
+
+        :param output_dir: the path to save the recipes into
+        """
+        self._check_super_defined("save_model")
+        super().save_model(output_dir=output_dir)
+
+        if self.manager is None:
+            return
+
+        if output_dir is None:
+            output_dir = self.args.output_dir
+
+        index = len(self.arch_managers)
+        recipe_path = os.path.join(
+            output_dir, RECIPE_TEMPLATE.format(f"_{index:02d}" if index > 0 else "")
+        )
+        self.manager.save(recipe_path)
+        _LOGGER.info(f"Saved SparseML recipe with model state to {recipe_path}")
+
+    def log_model_sparsification(self):
+        """
+        Log the current model sparsification info including pruned and quantized states
+        """
+        sparsification_info = ModuleSparsificationInfo(self.model)
+
+        _LOGGER.info(
+            f"Sparsification info for {self.model_state_path}: "
+            f"{sparsification_info.params_total} total params. "
+            f"Of those there are {sparsification_info.params_prunable_total} prunable "
+            f"params which have {sparsification_info.params_prunable_sparse_percent} "
+            "avg sparsity."
+ ) + model_type = ( + "sparse" + if sparsification_info.params_prunable_sparse_percent > 5 + else "dense" + ) + _LOGGER.info( + f"{model_type} model detected, " + f"all sparsification info: {sparsification_info}" + ) + + def _check_super_defined(self, func: str): + if not hasattr(super(), func): + raise NotImplementedError( + f"The super class for SparseMLTrainer must define a {func} function" + ) + + def _setup_manager( + self, kwargs + ) -> Tuple[Optional[ScheduledModifierManager], List[ScheduledModifierManager]]: + manager = None + arch_managers = [] + + if self.recipe is not None: + manager = ScheduledModifierManager.from_yaml( + self.recipe, recipe_variables=self.recipe_args + ) + _LOGGER.info( + "Loaded SparseML recipe variable into manager for recipe: " + f"{self.recipe} and recipe_variables: {self.recipe_args}" + ) + + arch_recipe_paths = glob.glob(os.path.join(self.model_state_path, RECIPE_REGEX)) + if arch_recipe_paths: + arch_managers = [ + ScheduledModifierManager.from_yaml(path) for path in arch_recipe_paths + ] + _LOGGER.info( + f"Loaded SparseML {len(arch_recipe_paths)} recipes into architecture " + f"managers from {arch_recipe_paths}" + ) + + if manager is not None and manager in arch_managers: + # new recipe and the one stored with model are the same, + # keep manager and remove from arch_managers to keep from applying twice. + # remove this logic once recipe stages land + arch_managers.remove(manager) + _LOGGER.info( + "Removed duplicate SparseML recipe from arch_managers that matched " + "the recipe variable to prevent double application" + ) + + if ( + manager is not None + and manager.max_epochs + and "args" in kwargs + and (hasattr(kwargs["args"], "num_train_epochs")) + ): + _LOGGER.warning( + f"Overriding num_train_epochs from Recipe to {manager.max_epochs}" + ) + kwargs["args"].num_train_epochs = manager.max_epochs + + return manager, arch_managers + + def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]): + if ( + not load_path + or not os.path.isdir(load_path) + or not os.path.isfile(os.path.join(load_path, WEIGHTS_NAME)) + ): + _LOGGER.warning( + "Model state was not reloaded for SparseML: " + f"could not find model wieghts for model_path {load_path}" + ) + return + + current_state_dict = self.model.state_dict() + + if set(orig_state_dict.keys()) == set(current_state_dict): + # no change in keys, ignore reload + return + + # change in keys due to architecture changes, reload statedict + load_state_dict = torch.load( + os.path.join(load_path, WEIGHTS_NAME), map_location="cpu" + ) + _, missing, unexpected, __ = self.model._load_state_dict_into_model( + self.model, load_state_dict, load_path, _fast_init=False + ) + + if missing: + _LOGGER.warning( + "Missing keys found when reloading model state for SparseML recipe:" + f"{missing}" + ) + + if unexpected: + _LOGGER.warning( + f"Unexpected keys found when reloading model state for SparseML recipe:" + f"{unexpected}" + ) + + total_loaded = len(current_state_dict) - (len(missing) if len(missing) else 0) + _LOGGER.info( + f"Reloaded {total_loaded} model params for SparseML Recipe from {load_path}" + ) + SparseAutoModel.log_model_load( + self.model, + self.model_state_path, + model_type="student" if self.teacher else "model", + delayed_load=False, + ) + + +class TrainerInterface(RecipeManagerTrainerInterface): + """ + Training interface for running sparsification recipes with transformers flows. + Mimics the lifecycle of transformers Trainer classes. 
+
+    Should be instantiated with multi-inheritance with a custom trainer class.
+    TrainerInterface must be provided before Trainer for proper class dependency.
+    i.e. class MyCustomTrainer(TrainerInterface, Trainer)
+
+    :param model: the model to use with the trainer and apply sparsification to
+    :param model_state_path: the state path to the model,
+        used to load config and tokenizer settings
+    :param recipe: the recipe, if any, to apply to the model and training
+        process
+    :param recipe_args: A json string, csv key=value string, or dictionary containing
+        arguments to override the root arguments within the recipe such as
+        learning rate or num epochs
+    :param teacher: teacher model for distillation. Set to 'self' to distill
+        from the loaded model or 'disable' to turn off distillation
+    :param kwargs: key word arguments passed to the parent class
+    """
+
+    def __init__(
+        self,
+        model: Module,
+        model_state_path: str,
+        recipe: Optional[str],
+        recipe_args: Optional[Union[Dict[str, Any], str]] = None,
+        teacher: Optional[Union[Module, str]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            model=model,
+            model_state_path=model_state_path,
+            recipe=recipe,
+            recipe_args=recipe_args,
+            teacher=teacher,
+            **kwargs,
+        )
+
+    def train(self, *args, **kwargs):
+        """
+        Run a sparsification training cycle.
+        Calls into apply_manager before super().train()
+        and calls finalize_manager, if applied, after super().train().
+
+        :param args: positional args to pass to super().train()
+        :param kwargs: keyword args to pass to super().train()
+        :return: the output from super().train()
+        """
+        checkpoint, epoch = self._generate_apply_manager_params(kwargs)
+        applied = self.apply_manager(epoch=epoch, checkpoint=checkpoint)
+        self.callback_disable_fp16.check_disable(epoch, force=True)
+        output = super().train(*args, **kwargs)
+        if applied:
+            self.finalize_manager()
+        self.log_model_sparsification()
+
+        return output
+
+    def evaluate(self, *args, **kwargs):
+        """
+        Run a sparsification evaluation cycle.
+        Calls into apply_manager before super().evaluate()
+        and calls finalize_manager, if applied, after super().evaluate().
+
+        :param args: positional args to pass to super().evaluate()
+        :param kwargs: keyword args to pass to super().evaluate()
+        :return: the output from super().evaluate()
+        """
+        applied = self.apply_manager(epoch=math.inf, checkpoint=None)
+        output = super().evaluate(*args, **kwargs)
+        if applied:
+            self.finalize_manager()
+
+        return output
+
+    def predict(self, *args, **kwargs):
+        """
+        Run a sparsification prediction cycle.
+        Calls into apply_manager before super().predict()
+        and calls finalize_manager, if applied, after super().predict().
+
+        :param args: positional args to pass to super().predict()
+        :param kwargs: keyword args to pass to super().predict()
+        :return: the output from super().predict()
+        """
+        applied = self.apply_manager(epoch=math.inf, checkpoint=None)
+        output = super().predict(*args, **kwargs)
+        if applied:
+            self.finalize_manager()
+
+        return output
+
+    def _generate_apply_manager_params(self, kwargs) -> Tuple[Optional[str], float]:
+        checkpoint = None
+        epoch = 0.0
+
+        if not kwargs or "resume_from_checkpoint" not in kwargs:
+            _LOGGER.warning(
+                "resume_from_checkpoint not passed into SparseMLTrainer.train. "
+                "This will cause issues with restoring recipes when "
+                "running from a checkpoint."
+            )
+        elif kwargs["resume_from_checkpoint"]:
+            if (
+                isinstance(kwargs["resume_from_checkpoint"], bool)
+                and kwargs["resume_from_checkpoint"]
+            ):
+                checkpoint = get_last_checkpoint(self.args.output_dir)
+            else:
+                checkpoint = kwargs["resume_from_checkpoint"]
+            epoch = TrainerState.load_from_json(
+                os.path.join(checkpoint, TRAINER_STATE_NAME)
+            ).epoch
+
+        return checkpoint, epoch
+
+
+class Trainer(TrainerInterface, TransformersTrainer):
+    """
+    Training implementation for running sparsification recipes with transformers flows.
+
+    :param model: the model to use with the trainer and apply sparsification to
+    :param model_state_path: the state path to the model,
+        used to load config and tokenizer settings
+    :param recipe: the recipe, if any, to apply to the model and training
+        process
+    :param recipe_args: A json string, csv key=value string, or dictionary containing
+        arguments to override the root arguments within the recipe such as
+        learning rate or num epochs
+    :param teacher: teacher model for distillation. Set to 'self' to distill
+        from the loaded model or 'disable' to turn off distillation
+    :param kwargs: key word arguments passed to the parent class
+    """
+
+    def __init__(
+        self,
+        model: Module,
+        model_state_path: str,
+        recipe: Optional[str],
+        recipe_args: Optional[Union[Dict[str, Any], str]] = None,
+        teacher: Optional[Union[Module, str]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            model=model,
+            model_state_path=model_state_path,
+            recipe=recipe,
+            recipe_args=recipe_args,
+            teacher=teacher,
+            **kwargs,
+        )
+
+
+class DisableHalfPrecisionCallback(TrainerCallback):
+    """
+    TrainerCallback for disabling FP16 training before QAT training begins
+
+    :param trainer: SparseML trainer that will call back into this object
+    :param args: args to be passed to base TrainerCallback
+    :param kwargs: key word arguments to be passed to base TrainerCallback
+    """
+
+    def __init__(self, trainer: RecipeManagerTrainerInterface, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.trainer = trainer
+        self.on_begin_called = False
+        self.quant_start_epoch = math.inf
+
+    def check_disable(self, epoch: float, force: bool = False):
+        if (
+            force or hasattr(self.trainer, "scaler") and self.trainer.scaler._enabled
+        ) and self.qat_active(epoch):
+            self.disable_amp(epoch)
+
+    def qat_active(self, epoch: float) -> bool:
+        return (self.trainer.manager and self.trainer.manager.qat_active(epoch)) or any(
+            bool(man.quantization_modifiers) for man in self.trainer.arch_managers
+        )
+
+    def disable_amp(self, epoch: float):
+        if not self.on_begin_called:
+            # disable if training loops haven't started so we don't load
+            # the empty scaler state dict and instead disable it from the start
+            self.trainer.use_amp = False
+
+        if hasattr(self.trainer, "scaler"):
+            self.trainer.scaler._enabled = False
+
+        self.quant_start_epoch = epoch
+        _LOGGER.info(f"entering QAT phase at epoch {epoch}, disabling FP16 training")
+
+    def on_epoch_begin(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        """
+        Event called at the beginning of an epoch. Disables FP16 training if QAT is active for the epoch.
Disables + """ + super().on_epoch_begin(args, state, control, **kwargs) + self.on_begin_called = True + self.check_disable(state.epoch) + + if state.epoch > self.quant_start_epoch: + _LOGGER.info(self.trainer.model) diff --git a/src/sparseml/transformers/train/text_classification.py b/src/sparseml/transformers/text_classification.py similarity index 93% rename from src/sparseml/transformers/train/text_classification.py rename to src/sparseml/transformers/text_classification.py index b8bca8e3818..edafb5e2f45 100644 --- a/src/sparseml/transformers/train/text_classification.py +++ b/src/sparseml/transformers/text_classification.py @@ -36,7 +36,6 @@ from datasets import load_dataset, load_metric from transformers import ( AutoConfig, - AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, EvalPrediction, @@ -49,11 +48,8 @@ from transformers.trainer_utils import get_last_checkpoint from transformers.utils import check_min_version -from sparseml.transformers.utils import ( - SparseMLGLUETrainer, - load_recipe, - preprocess_state_dict, -) +from sparseml.transformers.sparsification import Trainer +from sparseml.transformers.utils import SparseAutoModel # Will error if the minimal version of Transformers is not installed. @@ -416,11 +412,6 @@ def main(): # In distributed training, the .from_pretrained methods guarantee that only one # local process can concurrently download model & vocab. - # Load and preprocess the state dict if the model existed (in this case we continue - # to train or evaluate the model). The preprocessing step is to restore names of - # parameters changed by QAT process - state_dict = preprocess_state_dict(model_args.model_name_or_path) - config = AutoConfig.from_pretrained( model_args.config_name if model_args.config_name @@ -440,28 +431,25 @@ def main(): revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, ) - model = AutoModelForSequenceClassification.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - state_dict=state_dict, + model, teacher = SparseAutoModel.text_classification_from_pretrained_distil( + model_name_or_path=( + model_args.tokenizer_name + if model_args.tokenizer_name + else model_args.model_name_or_path + ), + model_kwargs={ + "config": config, + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "use_auth_token": True if model_args.use_auth_token else None, + }, + teacher_name_or_path=model_args.distill_teacher, + teacher_kwargs={ + "cache_dir": model_args.cache_dir, + "use_auth_token": True if model_args.use_auth_token else None, + }, ) - teacher_model = None - if model_args.distill_teacher is not None: - teacher_model = AutoModelForSequenceClassification.from_pretrained( - model_args.distill_teacher, - from_tf=bool(".ckpt" in model_args.distill_teacher), - cache_dir=model_args.cache_dir, - ) - teacher_model_parameters = filter( - lambda p: p.requires_grad, teacher_model.parameters() - ) - params = sum([np.prod(p.size()) for p in teacher_model_parameters]) - _LOGGER.info("Teacher Model has %s parameters", params) # Preprocessing the datasets if data_args.task_name is not None: sentence1_key, sentence2_key = _TASK_TO_KEYS[data_args.task_name] @@ -617,29 +605,20 @@ def compute_metrics(p: EvalPrediction): else: data_collator = None - # Load possible existing 
recipe and new one passed in through command argument - existing_recipe = load_recipe(model_args.model_name_or_path) - new_recipe = data_args.recipe - # Initialize our Trainer - trainer = SparseMLGLUETrainer( - model_args.model_name_or_path, - new_recipe, - checkpoint_recipes=[existing_recipe], - teacher=teacher_model, + trainer = Trainer( model=model, + model_state_path=model_args.model_name_or_path, + recipe=data_args.recipe, + recipe_args=data_args.recipe_args, + teacher=teacher, args=training_args, train_dataset=train_dataset if training_args.do_train else None, eval_dataset=eval_dataset if training_args.do_eval else None, tokenizer=tokenizer, data_collator=data_collator, compute_metrics=compute_metrics, - recipe_args=data_args.recipe_args, ) - # Apply recipes to the model. This is necessary given that - # sparsification methods such as QAT modified the model graph with their own - # learnable parameters. They are also restored/loaded to the model. - trainer.apply_recipes() # Training if training_args.do_train: diff --git a/src/sparseml/transformers/train/token_classification.py b/src/sparseml/transformers/token_classification.py similarity index 92% rename from src/sparseml/transformers/train/token_classification.py rename to src/sparseml/transformers/token_classification.py index 3323ea33b17..d4e6b64d16e 100644 --- a/src/sparseml/transformers/train/token_classification.py +++ b/src/sparseml/transformers/token_classification.py @@ -34,7 +34,6 @@ from datasets import ClassLabel, load_dataset, load_metric from transformers import ( AutoConfig, - AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification, HfArgumentParser, @@ -45,11 +44,8 @@ from transformers.trainer_utils import get_last_checkpoint from transformers.utils import check_min_version -from sparseml.transformers.utils import ( - SparseMLNERTrainer, - load_recipe, - preprocess_state_dict, -) +from sparseml.transformers.sparsification import Trainer +from sparseml.transformers.utils import SparseAutoModel # Will error if the minimal version of Transformers is not installed. @@ -370,11 +366,6 @@ def get_label_list(labels): # The .from_pretrained methods guarantee that only one local process can # concurrently download model & vocab. - # Load and preprocess the state dict if the model existed (in this case we continue - # to train or evaluate the model). The preprocessing step is to restore names of - # parameters changed by QAT process. 
- state_dict = preprocess_state_dict(model_args.model_name_or_path) - config = AutoConfig.from_pretrained( model_args.config_name if model_args.config_name @@ -394,29 +385,25 @@ def get_label_list(labels): revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, ) - model = AutoModelForTokenClassification.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - state_dict=state_dict, + model, teacher = SparseAutoModel.token_classification_from_pretrained_distil( + model_name_or_path=( + model_args.tokenizer_name + if model_args.tokenizer_name + else model_args.model_name_or_path + ), + model_kwargs={ + "config": config, + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "use_auth_token": True if model_args.use_auth_token else None, + }, + teacher_name_or_path=model_args.distill_teacher, + teacher_kwargs={ + "cache_dir": model_args.cache_dir, + "use_auth_token": True if model_args.use_auth_token else None, + }, ) - teacher_model = None - if model_args.distill_teacher is not None: - teacher_model = AutoModelForTokenClassification.from_pretrained( - model_args.distill_teacher, - from_tf=bool(".ckpt" in model_args.distill_teacher), - cache_dir=model_args.cache_dir, - ) - teacher_model_parameters = filter( - lambda p: p.requires_grad, teacher_model.parameters() - ) - params = sum([np.prod(p.size()) for p in teacher_model_parameters]) - _LOGGER.info("Teacher Model has %s parameters", params) - # Tokenizer check: this script requires a fast tokenizer. if not isinstance(tokenizer, PreTrainedTokenizerFast): raise ValueError( @@ -549,31 +536,21 @@ def compute_metrics(p): "accuracy": results["overall_accuracy"], } - # Load possible existing recipe and new one passed in through command argument - existing_recipe = load_recipe(model_args.model_name_or_path) - new_recipe = data_args.recipe - # Initialize our Trainer - trainer = SparseMLNERTrainer( - model_args.model_name_or_path, - new_recipe, - checkpoint_recipes=[existing_recipe], - teacher=teacher_model, + trainer = Trainer( model=model, + model_state_path=model_args.model_name_or_path, + recipe=data_args.recipe, + recipe_args=data_args.recipe_args, + teacher=teacher, args=training_args, train_dataset=train_dataset if training_args.do_train else None, eval_dataset=eval_dataset if training_args.do_eval else None, tokenizer=tokenizer, data_collator=data_collator, compute_metrics=compute_metrics, - recipe_args=data_args.recipe_args, ) - # Apply recipes to the model. This is necessary given that - # sparsification methods such as QAT modified the model graph with their own - # learnable parameters. They are also restored/loaded to the model. - trainer.apply_recipes() - # Training if training_args.do_train: checkpoint = None diff --git a/src/sparseml/transformers/train/language_modeling.py b/src/sparseml/transformers/train/language_modeling.py deleted file mode 100644 index e59112fe904..00000000000 --- a/src/sparseml/transformers/train/language_modeling.py +++ /dev/null @@ -1,700 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2020 The HuggingFace Team All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Adapted from https://github.com/huggingface/transformers -# neuralmagic: no copyright - -""" -Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) -on a text file or a dataset - -Here is the full list of checkpoints on the hub that can be fine-tuned by this script: -https://huggingface.co/models?filter=masked-lm -""" - -# You can also adapt this script on your own masked language modeling task. -# Pointers for this are left as comments - -import logging -import math -import os -import sys -from dataclasses import dataclass, field -from typing import Optional - -import numpy -import transformers -from datasets import concatenate_datasets, load_dataset -from transformers import ( - CONFIG_MAPPING, - MODEL_FOR_MASKED_LM_MAPPING, - AutoConfig, - AutoModelForMaskedLM, - AutoTokenizer, - DataCollatorForLanguageModeling, - HfArgumentParser, - TrainingArguments, - set_seed, -) -from transformers.trainer_utils import get_last_checkpoint -from transformers.utils import check_min_version - -from sparseml.transformers.utils import SparseMLMLMTrainer, load_recipe - - -# Will error if the minimal version of Transformers is not installed. -# Remove at your own risks -check_min_version("4.7.0.dev0") - -_LOGGER = logging.getLogger(__name__) -MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) -MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune, - or train from scratch - """ - - model_name_or_path: Optional[str] = field( - default=None, - metadata={ - "help": "The model checkpoint for weights initialization." - "Don't set if you want to train a model from scratch." - }, - ) - model_type: Optional[str] = field( - default=None, - metadata={ - "help": "If training from scratch, pass a model type from the list: " - + ", ".join(MODEL_TYPES) - }, - ) - distill_teacher: Optional[str] = field( - default=None, - metadata={"help": "Teacher model which needs to be a trained QA model"}, - ) - config_name: Optional[str] = field( - default=None, - metadata={ - "help": "Pretrained config name or path if not the same as model_name" - }, - ) - tokenizer_name: Optional[str] = field( - default=None, - metadata={ - "help": "Pretrained tokenizer name or path if not the same as model_name" - }, - ) - cache_dir: Optional[str] = field( - default=None, - metadata={"help": "Where to store the pretrained models from huggingface.co"}, - ) - use_fast_tokenizer: bool = field( - default=True, - metadata={"help": "Whether to use one of the fast tokenizers. 
Default True"}, - ) - model_revision: str = field( - default="main", - metadata={ - "help": "The specific model version to use " - "(can be a branch name, tag name or commit id)" - }, - ) - use_auth_token: bool = field( - default=False, - metadata={ - "help": "Will use token generated when running `transformers-cli login` " - "(necessary to use this script with private models)" - }, - ) - - -@dataclass -class DataTrainingArguments: - """ - Arguments pertaining to what data we are going to input our model for - training and eval - """ - - recipe: Optional[str] = field( - default=None, - metadata={ - "help": ( - "Path to a SparseML sparsification recipe, see " - "https://github.com/neuralmagic/sparseml for more information" - ), - }, - ) - recipe_args: Optional[str] = field( - default=None, - metadata={"help": "Recipe arguments to be overwritten"}, - ) - dataset_name: Optional[str] = field( - default=None, - metadata={"help": "The name of the dataset to use (via the datasets library)"}, - ) - dataset_config_name: Optional[str] = field( - default=None, - metadata={ - "help": ("The configuration name of the dataset to use"), - }, - ) - - # An extra second dataset - dataset_name_2: Optional[str] = field( - default=None, - metadata={"help": "The name of the dataset to use (via the datasets library)"}, - ) - dataset_config_name_2: Optional[str] = field( - default=None, - metadata={ - "help": ("The configuration name of the dataset to use"), - }, - ) - - train_file: Optional[str] = field( - default=None, metadata={"help": "The input training data file (a text file)."} - ) - validation_file: Optional[str] = field( - default=None, - metadata={ - "help": ( - "An optional input evaluation data file to evaluate the perplexity on" - "(a text file)." - ), - }, - ) - overwrite_cache: bool = field( - default=False, - metadata={"help": "Overwrite the cached training and evaluation sets"}, - ) - validation_split_percentage: Optional[int] = field( - default=5, - metadata={ - "help": ( - "The percentage of the train set used as validation set in case " - "there's no validation split" - ) - }, - ) - max_seq_length: Optional[int] = field( - default=None, - metadata={ - "help": "The maximum total input sequence length after tokenization. " - "Sequences longer than this will be truncated." - }, - ) - preprocessing_num_workers: Optional[int] = field( - default=None, - metadata={"help": "The number of processes to use for the preprocessing."}, - ) - mlm_probability: float = field( - default=0.15, - metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}, - ) - line_by_line: bool = field( - default=False, - metadata={ - "help": ( - "Whether distinct lines of text in the dataset are to be handled as " - "distinct sequences." - ), - }, - ) - pad_to_max_length: bool = field( - default=False, - metadata={ - "help": "Whether to pad all samples to `max_seq_length`. " - "If False, will pad the samples dynamically when batching to " - "the maximum length in the batch." - }, - ) - max_train_samples: Optional[int] = field( - default=None, - metadata={ - "help": "For debugging purposes or quicker training, truncate the number " - "of training examples to this value if set." - }, - ) - max_eval_samples: Optional[int] = field( - default=None, - metadata={ - "help": "For debugging purposes or quicker training, truncate the number " - "of evaluation examples to this value if set." 
- }, - ) - - def __post_init__(self): - if ( - self.dataset_name is None - and self.train_file is None - and self.validation_file is None - ): - raise ValueError( - "Need either a dataset name or a training/validation file." - ) - else: - if self.train_file is not None: - extension = self.train_file.split(".")[-1] - assert extension in [ - "csv", - "json", - "txt", - ], "`train_file` should be a csv, a json or a txt file." - if self.validation_file is not None: - extension = self.validation_file.split(".")[-1] - assert extension in [ - "csv", - "json", - "txt", - ], "`validation_file` should be a csv, a json or a txt file." - - -def main(): - # See all possible arguments in src/transformers/training_args.py - # or by passing the --help flag to this script. - # We now keep distinct sets of args, for a cleaner separation of concerns. - - parser = HfArgumentParser( - (ModelArguments, DataTrainingArguments, TrainingArguments) - ) - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - # If we pass only one argument to the script and it's the path to a json file, - # let's parse it to get our arguments. - model_args, data_args, training_args = parser.parse_json_file( - json_file=os.path.abspath(sys.argv[1]) - ) - else: - model_args, data_args, training_args = parser.parse_args_into_dataclasses() - - # Detecting last checkpoint. - last_checkpoint = None - if ( - os.path.isdir(training_args.output_dir) - and training_args.do_train - and not training_args.overwrite_output_dir - ): - last_checkpoint = get_last_checkpoint(training_args.output_dir) - if last_checkpoint is None and (len(os.listdir(training_args.output_dir)) > 0): - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and " - "is not empty. Use --overwrite_output_dir to overcome." - ) - elif ( - last_checkpoint is not None and training_args.resume_from_checkpoint is None - ): - _LOGGER.info( - f"Checkpoint detected, resuming training at {last_checkpoint}. " - "To avoid this behavior, change the `--output_dir` or add " - "`--overwrite_output_dir` to train from scratch." - ) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - handlers=[logging.StreamHandler(sys.stdout)], - ) - _LOGGER.setLevel(logging.INFO if training_args.should_log else logging.WARN) - - # Log on each process the small summary: - _LOGGER.warning( - f"Process rank: {training_args.local_rank}, device: {training_args.device}" - f", n_gpu: {training_args.n_gpu} distributed training: " - f"{bool(training_args.local_rank != -1)}, 16-bits training: " - f"{training_args.fp16}" - ) - # Set the verbosity to info of the Transformers _LOGGER (on main process only): - if training_args.should_log: - transformers.utils.logging.set_verbosity_info() - transformers.utils.logging.enable_default_handler() - transformers.utils.logging.enable_explicit_format() - _LOGGER.info(f"Training/evaluation parameters {training_args}") - - # Set seed before initializing model. - set_seed(training_args.seed) - - # Get the datasets: you can either provide your own CSV/JSON/TXT training and - # evaluation files (see below) or just provide the name of one of the public - # datasets available on the hub at https://huggingface.co/datasets/ - # (the dataset will be downloaded automatically from the datasets Hub - # - # For CSV/JSON files, this script will use the column called 'text' or the - # first column. 
You can easily tweak this behavior (see below) - # - # In distributed training, the load_dataset function guarantee that only one - # local process can concurrently download the dataset. - if data_args.dataset_name is not None: - # Downloading and loading a dataset from the hub. - datasets = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - cache_dir=model_args.cache_dir, - ) - if "validation" not in datasets.keys(): - datasets["validation"] = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - split=f"train[:{data_args.validation_split_percentage}%]", - cache_dir=model_args.cache_dir, - ) - datasets["train"] = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - split=f"train[{data_args.validation_split_percentage}%:]", - cache_dir=model_args.cache_dir, - ) - else: - data_files = {} - if data_args.train_file is not None: - data_files["train"] = data_args.train_file - if data_args.validation_file is not None: - data_files["validation"] = data_args.validation_file - extension = data_args.train_file.split(".")[-1] - if extension == "txt": - extension = "text" - datasets = load_dataset( - extension, data_files=data_files, cache_dir=model_args.cache_dir - ) - # See more about loading any type of standard or custom dataset (from files, - # python dict, pandas DataFrame, etc) at - # https://huggingface.co/docs/datasets/loading_datasets.html. - - # Load extra dataset if specified, and concatenate with the original one - if data_args.dataset_name_2 is not None: - # Downloading and loading a dataset from the hub. - datasets_2 = load_dataset( - data_args.dataset_name_2, - data_args.dataset_config_name_2, - cache_dir=model_args.cache_dir, - ) - if "validation" not in datasets_2.keys(): - datasets_2["validation"] = load_dataset( - data_args.dataset_name_2, - data_args.dataset_config_name_2, - split=f"train[:{data_args.validation_split_percentage}%]", - cache_dir=model_args.cache_dir, - ) - datasets_2["train"] = load_dataset( - data_args.dataset_name_2, - data_args.dataset_config_name_2, - split=f"train[{data_args.validation_split_percentage}%:]", - cache_dir=model_args.cache_dir, - ) - # Concatenate two datasets - if datasets is not None: - for split in ["validation", "train"]: - datasets[split] = concatenate_datasets( - [datasets[split], datasets_2[split]] - ) - - # Load pretrained model and tokenizer - # - # Distributed training: - # The .from_pretrained methods guarantee that only one local process can - # concurrently download model & vocab. 
- config_kwargs = { - "cache_dir": model_args.cache_dir, - "revision": model_args.model_revision, - "use_auth_token": True if model_args.use_auth_token else None, - } - if model_args.config_name: - config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) - elif model_args.model_name_or_path: - config = AutoConfig.from_pretrained( - model_args.model_name_or_path, **config_kwargs - ) - else: - config = CONFIG_MAPPING[model_args.model_type]() - _LOGGER.warning("You are instantiating a new config instance from scratch.") - - tokenizer_kwargs = { - "cache_dir": model_args.cache_dir, - "use_fast": model_args.use_fast_tokenizer, - "revision": model_args.model_revision, - "use_auth_token": True if model_args.use_auth_token else None, - } - if model_args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name, **tokenizer_kwargs - ) - elif model_args.model_name_or_path: - tokenizer = AutoTokenizer.from_pretrained( - model_args.model_name_or_path, **tokenizer_kwargs - ) - else: - raise ValueError( - "You are instantiating a new tokenizer from scratch. This is not " - "supported by this script. You can do it from another script, save " - "it, and load it from here, using --tokenizer_name." - ) - if model_args.model_name_or_path: - model = AutoModelForMaskedLM.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - ) - else: - _LOGGER.info("Training new model from scratch") - model = AutoModelForMaskedLM.from_config(config) - - model.resize_token_embeddings(len(tokenizer)) - - teacher_model = None - if model_args.distill_teacher is not None: - teacher_model = AutoModelForMaskedLM.from_pretrained( - model_args.distill_teacher, - from_tf=bool(".ckpt" in model_args.distill_teacher), - cache_dir=model_args.cache_dir, - ) - teacher_model_parameters = filter( - lambda p: p.requires_grad, teacher_model.parameters() - ) - params = sum([numpy.prod(p.size()) for p in teacher_model_parameters]) - _LOGGER.info("Teacher Model has %s parameters", params) - - # Preprocessing the datasets. - # First we tokenize all the texts. - if training_args.do_train: - column_names = datasets["train"].column_names - else: - column_names = datasets["validation"].column_names - text_column_name = "text" if "text" in column_names else column_names[0] - - if data_args.max_seq_length is None: - max_seq_length = tokenizer.model_max_length - if max_seq_length > 1024: - _LOGGER.warning( - "The tokenizer picked seems to have a very large `model_max_length`" - f"({tokenizer.model_max_length}). Picking 1024 instead. You can " - "change that default value by passing --max_seq_length xxx." - ) - max_seq_length = 1024 - else: - if data_args.max_seq_length > tokenizer.model_max_length: - _LOGGER.warning( - f"The max_seq_length passed ({data_args.max_seq_length}) " - "is larger than the maximum length for the model " - f"({tokenizer.model_max_length}). Using " - f"max_seq_length={tokenizer.model_max_length}." - ) - max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) - - if data_args.line_by_line: - # When using line_by_line, we just tokenize each nonempty line. 
- padding = "max_length" if data_args.pad_to_max_length else False - - def tokenize_function(examples): - # Remove empty lines - examples["text"] = [ - line - for line in examples["text"] - if len(line) > 0 and not line.isspace() - ] - return tokenizer( - examples["text"], - padding=padding, - truncation=True, - max_length=max_seq_length, - # We use this option because DataCollatorForLanguageModeling (see - # below) is more efficient when it receives the `special_tokens_mask`. - return_special_tokens_mask=True, - ) - - tokenized_datasets = datasets.map( - tokenize_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=[text_column_name], - load_from_cache_file=not data_args.overwrite_cache, - ) - else: - # Otherwise, we tokenize every text, then concatenate them together before - # splitting them in smaller parts. We use `return_special_tokens_mask=True` - # because DataCollatorForLanguageModeling (see below) is more - # efficient when it receives the `special_tokens_mask`. - def tokenize_function(examples): - return tokenizer( - examples[text_column_name], return_special_tokens_mask=True - ) - - tokenized_datasets = datasets.map( - tokenize_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=column_names, - load_from_cache_file=not data_args.overwrite_cache, - ) - - # Main data processing function that will concatenate all texts from our - # dataset and generate chunks of max_seq_length. - def group_texts(examples): - # Concatenate all texts. - concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} - total_length = len(concatenated_examples[list(examples.keys())[0]]) - # We drop the small remainder, we could add padding if the model supported - # it instead of this drop, you can customize this part to your needs - total_length = (total_length // max_seq_length) * max_seq_length - # Split by chunks of max_len. - result = { - k: [ - t[i : i + max_seq_length] - for i in range(0, total_length, max_seq_length) - ] - for k, t in concatenated_examples.items() - } - return result - - # Note that with `batched=True`, this map processes 1,000 texts together, - # so group_texts throws away a remainder for each of those groups of 1,000 - # texts. You can adjust that batch_size here but a higher value - # might be slower to preprocess. - # - # To speed up this part, we use multiprocessing. See the documentation of - # the map method for more information: - # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map - tokenized_datasets = tokenized_datasets.map( - group_texts, - batched=True, - num_proc=data_args.preprocessing_num_workers, - load_from_cache_file=not data_args.overwrite_cache, - ) - - if training_args.do_train: - if "train" not in tokenized_datasets: - raise ValueError("--do_train requires a train dataset") - train_dataset = tokenized_datasets["train"] - if data_args.max_train_samples is not None: - train_dataset = train_dataset.select(range(data_args.max_train_samples)) - - if training_args.do_eval: - if "validation" not in tokenized_datasets: - raise ValueError("--do_eval requires a validation dataset") - eval_dataset = tokenized_datasets["validation"] - if data_args.max_eval_samples is not None: - eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) - - # Data collator - # This one will take care of randomly masking the tokens. 
- pad_to_multiple_of_8 = ( - data_args.line_by_line - and training_args.fp16 - and not data_args.pad_to_max_length - ) - data_collator = DataCollatorForLanguageModeling( - tokenizer=tokenizer, - mlm_probability=data_args.mlm_probability, - pad_to_multiple_of=8 if pad_to_multiple_of_8 else None, - ) - - # Load possible existing recipe and new one passed in through command argument - existing_recipe = load_recipe(model_args.model_name_or_path) - new_recipe = data_args.recipe - - compute_metrics = None - # Initialize our Trainer - trainer = SparseMLMLMTrainer( - model_args.model_name_or_path, - new_recipe, - checkpoint_recipes=[existing_recipe], - teacher=teacher_model, - model=model, - args=training_args, - train_dataset=train_dataset if training_args.do_train else None, - eval_dataset=eval_dataset if training_args.do_eval else None, - tokenizer=tokenizer, - data_collator=data_collator, - compute_metrics=compute_metrics, - recipe_args=data_args.recipe_args, - ) - - # Apply recipes to the model. This is necessary given that - # sparsification methods such as QAT modified the model graph with their own - # learnable parameters. They are also restored/loaded to the model. - trainer.apply_recipes() - - # Training - if training_args.do_train: - checkpoint = None - if training_args.resume_from_checkpoint is not None: - checkpoint = training_args.resume_from_checkpoint - elif last_checkpoint is not None: - checkpoint = last_checkpoint - train_result = trainer.train(resume_from_checkpoint=checkpoint) - trainer.save_model() # Saves the tokenizer too for easy upload - metrics = train_result.metrics - - max_train_samples = ( - data_args.max_train_samples - if data_args.max_train_samples is not None - else len(train_dataset) - ) - metrics["train_samples"] = min(max_train_samples, len(train_dataset)) - - trainer.log_metrics("train", metrics) - trainer.save_metrics("train", metrics) - trainer.save_state() - - # Evaluation - if training_args.do_eval: - _LOGGER.info("*** Evaluate ***") - - metrics = trainer.evaluate() - - max_eval_samples = ( - data_args.max_eval_samples - if data_args.max_eval_samples is not None - else len(eval_dataset) - ) - metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) - try: - perplexity = math.exp(metrics["eval_loss"]) - except OverflowError: - perplexity = float("inf") - metrics["perplexity"] = perplexity - - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) - - if training_args.push_to_hub: - kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "fill-mask"} - if data_args.dataset_name is not None: - kwargs["dataset_tags"] = data_args.dataset_name - if data_args.dataset_config_name is not None: - kwargs["dataset_args"] = data_args.dataset_config_name - kwargs[ - "dataset" - ] = f"{data_args.dataset_name} {data_args.dataset_config_name}" - else: - kwargs["dataset"] = data_args.dataset_name - - trainer.push_to_hub(**kwargs) - - -def _mp_fn(index): - # For xla_spawn (TPUs) - main() - - -if __name__ == "__main__": - main() diff --git a/src/sparseml/transformers/utils/__init__.py b/src/sparseml/transformers/utils/__init__.py index bb23d6248fc..36001d2d711 100644 --- a/src/sparseml/transformers/utils/__init__.py +++ b/src/sparseml/transformers/utils/__init__.py @@ -13,15 +13,10 @@ # limitations under the License. 
""" -Tools for integrating SparseML with transformers training flows +Utilities for applying sparsification algorithms to Hugging Face transformers flows """ # flake8: noqa -from .export import * from .helpers import * -from .language_modeling import * -from .question_answering import * -from .text_classification import * -from .token_classification import * -from .trainer import * +from .model import * diff --git a/src/sparseml/transformers/utils/helpers.py b/src/sparseml/transformers/utils/helpers.py index 32aa3af5d82..17d8ea7a725 100644 --- a/src/sparseml/transformers/utils/helpers.py +++ b/src/sparseml/transformers/utils/helpers.py @@ -17,78 +17,13 @@ flows """ - -import os -from typing import Any, Dict - -import torch -from transformers.file_utils import WEIGHTS_NAME - -from sparseml.pytorch.optim.manager import ScheduledModifierManager - - __all__ = [ "RECIPE_NAME", - "preprocess_state_dict", - "load_recipe", + "RECIPE_REGEX", + "RECIPE_TEMPLATE", ] RECIPE_NAME = "recipe.yaml" - - -def load_recipe(pretrained_model_name_or_path: str) -> str: - """ - Get path to recipe from the model directory - - :param pretrained_model_name_or_path: path to model directory - :return: path to recipe - """ - recipe = None - if pretrained_model_name_or_path is not None: - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)): - recipe = os.path.join(pretrained_model_name_or_path, RECIPE_NAME) - return recipe - - -def preprocess_state_dict(pretrained_model_name_or_path: str) -> Dict[str, Any]: - """ - Restore original parameter names that were changed by QAT process - - :param pretrained_model_name_or_path: name or path to model - """ - state_dict = None - if pretrained_model_name_or_path is not None: - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)): - recipe = os.path.join(pretrained_model_name_or_path, RECIPE_NAME) - manager = ScheduledModifierManager.from_yaml(recipe) - modifiers = [m.__class__.__name__ for m in manager.modifiers] - is_qat_recipe = "QuantizationModifier" in modifiers - else: - is_qat_recipe = False - if os.path.isfile( - os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) - ): - archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) - state_dict = torch.load(archive_file, map_location="cpu") - removed_keys = ( - [ - key - for key in state_dict - if ( - key.endswith(".module.weight") - or key.endswith(".module.bias") - ) - ] - if is_qat_recipe - else [] - ) - for key in removed_keys: - new_key = key.replace(".module", "") - state_dict[new_key] = state_dict[key] - state_dict.pop(key) - return state_dict +RECIPE_REGEX = r"recipe*.yaml" +RECIPE_TEMPLATE = "recipe{}.yaml" diff --git a/src/sparseml/transformers/utils/language_modeling.py b/src/sparseml/transformers/utils/language_modeling.py deleted file mode 100644 index a3247932386..00000000000 --- a/src/sparseml/transformers/utils/language_modeling.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Training utilities for text classification / GLUE tasks -""" - -from typing import Any, Dict, List, Optional, Union - -import torch -from transformers import Trainer - -from sparseml.transformers.utils.trainer import SparseMLTrainer - - -__all__ = ["SparseMLMLMTrainer"] - - -class SparseMLMLMTrainer(SparseMLTrainer, Trainer): - """ - Trainer for running sparsification recipes with MLM training - - :param model_name_or_path: path to model directory to be trained - :param recipe: path to recipe for model sparsification - :param checkpoint_recipes: list of paths to recipes used to train the - starting checkpoint for this training run. Will be applied to the model - on call to `apply_recipes` so that model state can be reproduced for - weight loading - :param teacher: teacher model for distillation. Default is None - :param recipe_args: Dictionary of recipe variables to override or json - loadable string of those args. Default is None - :param args: arguments passed into parent class - :param kwargs: key word arguments passed to the parent class - """ - - def __init__( - self, - model_name_or_path: str, - recipe: str, - checkpoint_recipes: Union[str, List[str]] = None, - teacher: Optional[torch.nn.Module] = None, - recipe_args: Union[Dict[str, Any], str] = None, - *args, - **kwargs, - ): - super().__init__( - model_name_or_path=model_name_or_path, - recipe=recipe, - checkpoint_recipes=checkpoint_recipes, - teacher=teacher, - recipe_args=recipe_args, - teacher_input_keys=None, - *args, - **kwargs, - ) diff --git a/src/sparseml/transformers/utils/model.py b/src/sparseml/transformers/utils/model.py new file mode 100644 index 00000000000..28aa8429c58 --- /dev/null +++ b/src/sparseml/transformers/utils/model.py @@ -0,0 +1,383 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
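+
+# Illustrative usage sketch for the SparseAutoModel factory defined below
+# (the checkpoint paths and kwargs here are hypothetical, shown only to
+# demonstrate the distillation loading flow):
+#
+#   from sparseml.transformers.utils.model import SparseAutoModel
+#
+#   model, teacher = SparseAutoModel.question_answering_from_pretrained_distil(
+#       model_name_or_path="./sparse_qa_student",
+#       teacher_name_or_path="./dense_qa_teacher",
+#       model_kwargs={"cache_dir": None},
+#       teacher_kwargs={"cache_dir": None},
+#   )
+#
+# Passing None, "self", or "disable" as teacher_name_or_path skips loading a
+# teacher module and returns that value unchanged alongside the student model.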
+ +import logging +import os +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from torch.nn import Module +from transformers import ( + AutoModelForMaskedLM, + AutoModelForQuestionAnswering, + AutoModelForSequenceClassification, + AutoModelForTokenClassification, +) +from transformers.file_utils import WEIGHTS_NAME + +from sparseml.pytorch.utils import ModuleSparsificationInfo + + +__all__ = ["SparseAutoModel"] + + +_LOGGER = logging.getLogger(__name__) + + +class SparseAutoModel: + """ + Factory class for creating sparse models using transformers AutoModel classes + """ + + @staticmethod + def masked_language_modeling_from_pretrained( + model_name_or_path: str, + model_type: str, + **kwargs, + ) -> Module: + """ + :param model_name_or_path: the name of or path to the model to load + :param model_type: specify the type of model loaded for logging; + ex one of [model, student, teacher] + :param kwargs: keyword arguments to pass through to the AutoModel call + :return: the created model for masked language modeling + """ + delayed = False + if not model_name_or_path: + _LOGGER.info("Training new model from scratch") + config = kwargs["config"] + model = AutoModelForMaskedLM.from_config(config) + else: + SparseAutoModel._check_tf(model_name_or_path) + if not kwargs: + kwargs = {} + kwargs["from_tf"] = False + if "state_dict" not in kwargs: + kwargs["state_dict"], delayed = SparseAutoModel._loadable_state_dict( + model_name_or_path + ) + model = AutoModelForSequenceClassification.from_pretrained( + model_name_or_path, + **kwargs, + ) + + SparseAutoModel.log_model_load(model, model_name_or_path, model_type, delayed) + + return model + + @staticmethod + def masked_language_modeling_from_pretrained_distil( + model_name_or_path: str, + teacher_name_or_path: Optional[str], + model_kwargs: Dict[str, Any], + teacher_kwargs: Dict[str, Any], + ) -> Tuple[Module, Optional[Union[Module, str]]]: + """ + :param model_name_or_path: the name of or path to the model to load + :param teacher_name_or_path: the name of or path to the teacher to load, + None or one of ['self', 'disable'] will not create a teacher and + instead return the value passed in + :param model_kwargs: the keyword args to pass into the AutoModel for model + :param teacher_kwargs: the keyword args to pass into the AutoModel for teacher + :return: a tuple containing the model and distillation teacher (optional) + for masked language modeling + """ + model = SparseAutoModel.masked_language_modeling_from_pretrained( + model_name_or_path, + model_type="student" if teacher_name_or_path else "model", + **model_kwargs, + ) + teacher = ( + SparseAutoModel.masked_language_modeling_from_pretrained( + teacher_name_or_path, + model_type="teacher", + **teacher_kwargs, + ) + if teacher_name_or_path and teacher_name_or_path not in ["self", "disable"] + else teacher_name_or_path + ) + + return model, teacher + + @staticmethod + def question_answering_from_pretrained( + model_name_or_path: str, + model_type: str, + **kwargs, + ) -> Module: + """ + :param model_name_or_path: the name of or path to the model to load + :param model_type: specify the type of model loaded for logging; + ex one of [model, student, teacher] + :param kwargs: keyword arguments to pass through to the AutoModel call + :return: the created model for question answering + """ + SparseAutoModel._check_tf(model_name_or_path) + if not kwargs: + kwargs = {} + kwargs["from_tf"] = False + delayed = False + if "state_dict" not in kwargs: + kwargs["state_dict"], delayed = 
SparseAutoModel._loadable_state_dict( + model_name_or_path + ) + model = AutoModelForQuestionAnswering.from_pretrained( + model_name_or_path, + **kwargs, + ) + SparseAutoModel.log_model_load(model, model_name_or_path, model_type, delayed) + + return model + + @staticmethod + def question_answering_from_pretrained_distil( + model_name_or_path: str, + teacher_name_or_path: Optional[str], + model_kwargs: Dict[str, Any], + teacher_kwargs: Dict[str, Any], + ) -> Tuple[Module, Optional[Union[Module, str]]]: + """ + :param model_name_or_path: the name of or path to the model to load + :param teacher_name_or_path: the name of or path to the teacher to load, + None or one of ['self', 'disable'] will not create a teacher and + instead return the value passed in + :param model_kwargs: the keyword args to pass into the AutoModel for model + :param teacher_kwargs: the keyword args to pass into the AutoModel for teacher + :return: a tuple containing the model and distillation teacher (optional) + for question answering + """ + model = SparseAutoModel.question_answering_from_pretrained( + model_name_or_path, + model_type="student" if teacher_name_or_path else "model", + **model_kwargs, + ) + teacher = ( + SparseAutoModel.question_answering_from_pretrained( + teacher_name_or_path, model_type="teacher", **teacher_kwargs + ) + if teacher_name_or_path and teacher_name_or_path not in ["self", "disable"] + else teacher_name_or_path + ) + + return model, teacher + + @staticmethod + def text_classification_from_pretrained( + model_name_or_path: str, + model_type: str, + **kwargs, + ) -> Module: + """ + :param model_name_or_path: the name of or path to the model to load + :param model_type: specify the type of model loaded for logging; + ex one of [model, student, teacher] + :param kwargs: keyword arguments to pass through to the AutoModel call + :return: the created model for text classification + """ + SparseAutoModel._check_tf(model_name_or_path) + if not kwargs: + kwargs = {} + kwargs["from_tf"] = False + delayed = False + if "state_dict" not in kwargs: + kwargs["state_dict"], delayed = SparseAutoModel._loadable_state_dict( + model_name_or_path + ) + model = AutoModelForSequenceClassification.from_pretrained( + model_name_or_path, + **kwargs, + ) + SparseAutoModel.log_model_load(model, model_name_or_path, model_type, delayed) + + return model + + @staticmethod + def text_classification_from_pretrained_distil( + model_name_or_path: str, + teacher_name_or_path: Optional[str], + model_kwargs: Dict[str, Any], + teacher_kwargs: Dict[str, Any], + ) -> Tuple[Module, Optional[Module]]: + """ + :param model_name_or_path: the name of or path to the model to load + :param teacher_name_or_path: the name of or path to the teacher to load, + None or one of ['self', 'disable'] will not create a teacher and + instead return the value passed in + :param model_kwargs: the keyword args to pass into the AutoModel for model + :param teacher_kwargs: the keyword args to pass into the AutoModel for teacher + :return: a tuple containing the model and distillation teacher (optional) + for sequence/text classification + """ + model = SparseAutoModel.text_classification_from_pretrained( + model_name_or_path, + model_type="student" if teacher_name_or_path else "model", + **model_kwargs, + ) + teacher = ( + SparseAutoModel.text_classification_from_pretrained( + teacher_name_or_path, model_type="teacher", **teacher_kwargs + ) + if teacher_name_or_path and teacher_name_or_path not in ["self", "disable"] + else teacher_name_or_path + ) + + 
return model, teacher + + @staticmethod + def token_classification_from_pretrained( + model_name_or_path: str, + model_type: str, + **kwargs, + ) -> Module: + """ + :param model_name_or_path: the name of or path to the model to load + :param model_type: specify the type of model loaded for logging; + ex one of [model, student, teacher] + :param kwargs: keyword arguments to pass through to the AutoModel call + :return: the created model for token classification + """ + SparseAutoModel._check_tf(model_name_or_path) + if not kwargs: + kwargs = {} + kwargs["from_tf"] = False + delayed = False + if "state_dict" not in kwargs: + kwargs["state_dict"], delayed = SparseAutoModel._loadable_state_dict( + model_name_or_path + ) + model = AutoModelForTokenClassification.from_pretrained( + model_name_or_path, + **kwargs, + ) + SparseAutoModel.log_model_load(model, model_name_or_path, model_type, delayed) + + return model + + @staticmethod + def token_classification_from_pretrained_distil( + model_name_or_path: str, + teacher_name_or_path: Optional[str], + model_kwargs: Dict[str, Any], + teacher_kwargs: Dict[str, Any], + ) -> Tuple[Module, Optional[Module]]: + """ + :param model_name_or_path: the name of or path to the model to load + :param teacher_name_or_path: the name of or path to the teacher to load, + None or one of ['self', 'disable'] will not create a teacher and + instead return the value passed in + :param model_kwargs: the keyword args to pass into the AutoModel for model + :param teacher_kwargs: the keyword args to pass into the AutoModel for teacher + :return: a tuple containing the model and distillation teacher (optional) + for token classification + """ + model = SparseAutoModel.token_classification_from_pretrained( + model_name_or_path, + model_type="student" if teacher_name_or_path else "model", + **model_kwargs, + ) + teacher = ( + SparseAutoModel.token_classification_from_pretrained( + teacher_name_or_path, model_type="teacher", **teacher_kwargs + ) + if teacher_name_or_path and teacher_name_or_path not in ["self", "disable"] + else teacher_name_or_path + ) + + return model, teacher + + @staticmethod + def log_model_load( + model: Module, model_name_or_path: str, model_type: str, delayed_load: bool + ): + """ + Log the state of a loaded model including sparsity and + prunable params information. + + :param model: the loaded model + :param model_name_or_path: the original name of or path to the model that loaded + :param model_type: specify the type of model loaded for logging; + ex one of [model, student, teacher] + :param delayed_load: True if this model load was delayed until after + recipe instantiation due to QAT or other architectural state changes + """ + if delayed_load: + _LOGGER.info( + f"Delayed load of model {model_name_or_path} detected. " + f"Will print out model information once SparseML recipes have loaded" + ) + return + + sparsification_info = ModuleSparsificationInfo(model) + + _LOGGER.info( + f"Loaded {model_type} from {model_name_or_path} " + f"with {sparsification_info.params_total} total params. " + f"Of those there are {sparsification_info.params_prunable_total} prunable " + f"params which have {sparsification_info.params_prunable_sparse_percent} " + "avg sparsity." 
+ ) + model_type = ( + "sparse" + if sparsification_info.params_prunable_sparse_percent > 5 + else "dense" + ) + _LOGGER.info( + f"{model_type} model detected, " + f"all sparsification info: {sparsification_info}" + ) + + @staticmethod + def _loadable_state_dict( + model_name_or_path: str, + ) -> Tuple[Optional[Dict[str, Any]], bool]: + """ + :param model_name_or_path: name of or path to model + :return: (loaded state dict, True if overriding state dict for delayed load) + delayed load happens when a QAT graph is detected since a recipe + must be applied first + """ + if not model_name_or_path or not os.path.isfile( + os.path.join(model_name_or_path, WEIGHTS_NAME) + ): + return None, False + + state_dict = torch.load( + os.path.join(model_name_or_path, WEIGHTS_NAME), map_location="cpu" + ) + is_qat_state = any( + [ + key.endswith(".zero_point") or key.endswith(".observer_enabled") + for key in state_dict.keys() + ] + ) + + if not is_qat_state: + return None, False + + _LOGGER.warning( + "QAT state detected, ignore any loading errors, weights will reload " + f"after SparseML recipes have been applied {model_name_or_path}" + ) + + return {}, True + + @staticmethod + def _check_tf(model_name_or_path: str): + if ".ckpt" in model_name_or_path: + raise ValueError( + "PyTorch is the only supported model type currently for SparseML " + "HuggingFace Transformers integration. " + "Detected a TensorFlow model from model_name_or_path: " + f"{model_name_or_path}" + ) diff --git a/src/sparseml/transformers/utils/text_classification.py b/src/sparseml/transformers/utils/text_classification.py deleted file mode 100644 index 579955f7a5f..00000000000 --- a/src/sparseml/transformers/utils/text_classification.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Training utilities for text classification / GLUE tasks -""" - -from typing import Any, Dict, List, Optional, Union - -import torch -from transformers import Trainer - -from sparseml.transformers.utils.trainer import SparseMLTrainer - - -__all__ = ["SparseMLGLUETrainer"] - - -class SparseMLGLUETrainer(SparseMLTrainer, Trainer): - """ - Trainer for running sparsification recipes with GLUE training - - :param model_name_or_path: path to model directory to be trained - :param recipe: path to recipe for model sparsification - :param checkpoint_recipes: list of paths to recipes used to train the - starting checkpoint for this training run. Will be applied to the model - on call to `apply_recipes` so that model state can be reproduced for - weight loading - :param teacher: teacher model for distillation. Default is None - :param recipe_args: Dictionary of recipe variables to override or json - loadable string of those args. 
Default is None - :param args: arguments passed into parent class - :param kwargs: key word arguments passed to the parent class - """ - - def __init__( - self, - model_name_or_path: str, - recipe: str, - checkpoint_recipes: Union[str, List[str]] = None, - teacher: Optional[torch.nn.Module] = None, - recipe_args: Union[Dict[str, Any], str] = None, - *args, - **kwargs, - ): - super().__init__( - model_name_or_path=model_name_or_path, - recipe=recipe, - checkpoint_recipes=checkpoint_recipes, - teacher=teacher, - recipe_args=recipe_args, - teacher_input_keys=["input_ids", "token_type_ids", "attention_mask"], - *args, - **kwargs, - ) diff --git a/src/sparseml/transformers/utils/token_classification.py b/src/sparseml/transformers/utils/token_classification.py deleted file mode 100644 index f613338d210..00000000000 --- a/src/sparseml/transformers/utils/token_classification.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Training utilities for text classification / GLUE tasks -""" - -from typing import Any, Dict, List, Optional, Union - -import torch -from transformers import Trainer - -from sparseml.transformers.utils.trainer import SparseMLTrainer - - -__all__ = ["SparseMLNERTrainer"] - - -class SparseMLNERTrainer(SparseMLTrainer, Trainer): - """ - Trainer for running sparsification recipes with NER training - - :param model_name_or_path: path to model directory to be trained - :param recipe: path to recipe for model sparsification - :param checkpoint_recipes: list of paths to recipes used to train the - starting checkpoint for this training run. Will be applied to the model - on call to `apply_recipes` so that model state can be reproduced for - weight loading - :param teacher: teacher model for distillation. Default is None - :param recipe_args: Dictionary of recipe variables to override or json - loadable string of those args. Default is None - :param args: arguments passed into parent class - :param kwargs: key word arguments passed to the parent class - """ - - def __init__( - self, - model_name_or_path: str, - recipe: str, - checkpoint_recipes: Union[str, List[str]] = None, - teacher: Optional[torch.nn.Module] = None, - recipe_args: Union[Dict[str, Any], str] = None, - *args, - **kwargs, - ): - super().__init__( - model_name_or_path=model_name_or_path, - recipe=recipe, - checkpoint_recipes=checkpoint_recipes, - teacher=teacher, - recipe_args=recipe_args, - teacher_input_keys=None, - *args, - **kwargs, - ) diff --git a/src/sparseml/transformers/utils/trainer.py b/src/sparseml/transformers/utils/trainer.py deleted file mode 100644 index b824e6821df..00000000000 --- a/src/sparseml/transformers/utils/trainer.py +++ /dev/null @@ -1,333 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -SparseML transformers trainer class to be plugged in with existing HF trainer flows -""" - - -import json -import logging -import math -import os -from typing import Any, Dict, List, Optional, Union - -import torch -from transformers import ( - TrainerCallback, - TrainerControl, - TrainerState, - TrainingArguments, -) -from transformers.file_utils import WEIGHTS_NAME - -from sparseml.pytorch.optim import LayerPruningModifier, QuantizationModifier -from sparseml.pytorch.optim.manager import ScheduledModifierManager -from sparseml.pytorch.optim.optimizer import ScheduledOptimizer -from sparseml.pytorch.utils import logger -from sparseml.transformers.utils.helpers import RECIPE_NAME - - -__all__ = [ - "SparseMLTrainer", - "DisableHalfPrecisionCallback", -] - - -_LOGGER = logging.getLogger(__name__) - - -class SparseMLTrainer: - """ - Trainer for running sparsification recipes with transformers Trainer flows. - - Should either be used in place of standard transformers Trainer class - or instantiated with multi-inheretance with a custom trainer class. SparesMLTrainer - must be provided before Trainer for proper class dependency resolution - - i.e. class MyCustomTrainer(SparseMLTrainer, Trainer) - - :param model_name_or_path: path to model directory to be trained - :param recipe: path to recipe for model sparsification - :param checkpoint_recipes: list of paths to recipes used to train the - starting checkpoint for this training run. Will be applied to the model - on call to `apply_recipes` so that model state can be reproduced for - weight loading - :param teacher: teacher model for distillation. Default is None - :param recipe_args: Dictionary of recipe variables to override or json - loadable string of those args. Default is None - :param teacher_input_keys: keywords of inputs to select from student inputs dict - to also be passed to a the teacher model. Can be useful to avoid extra - computation in forward pass that is not necessary for distillation. 
Defaults - to passing all student inputs to teacher - :param args: arguments passed into parent class - :param kwargs: key word arguments passed to the parent class - """ - - def __init__( - self, - model_name_or_path: str, - recipe: str, - checkpoint_recipes: Union[str, List[str]] = None, - teacher: Optional[torch.nn.Module] = None, - recipe_args: Union[Dict[str, Any], str] = None, - teacher_input_keys: Optional[List[str]] = None, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.model_name_or_path = str(model_name_or_path) - self.recipe = recipe - self.checkpoint_recipes = list( - [checkpoint_recipes] - if isinstance(checkpoint_recipes, str) - else checkpoint_recipes or [] - ) # List[str] - self.teacher = teacher - if self.teacher is not None: - self.teacher.eval() - self.teacher_input_keys = teacher_input_keys - self.criterion = torch.nn.CrossEntropyLoss() - - if recipe_args is not None: - if isinstance(recipe_args, str): - recipe_args = json.loads(recipe_args) - if not isinstance(recipe_args, Dict): - raise ValueError("Cannot convert recipe arguments into dictionary") - else: - recipe_args = {} - - # initialize manager and override num epochs if available - self.manager = ( - ScheduledModifierManager.from_yaml(recipe, **recipe_args) - if recipe - else None - ) - if ( - self.manager - and self.manager.max_epochs - and "args" in kwargs - and (hasattr(kwargs["args"], "num_train_epochs")) - ): - kwargs["args"].num_train_epochs = self.manager.max_epochs - - self.loggers = None - if self.recipe is not None: - loggers = [] - if "wandb" in self.args.report_to: - loggers.append(logger.WANDBLogger()) - self.loggers = loggers - - # add disable FP16 callback - self.callback_handler.add_callback(DisableHalfPrecisionCallback(self)) - - def apply_recipes(self, epoch=0.0): - """ - Applies all recipes from checkpoint_recipes. 
Runs architecture changing - modifiers to prepare model for state dict loading - """ - # get state dict before recipe application - org_state_dict = self.model.state_dict() - - # apply any checkpoint recipes - for checkpoint_recipe in self.checkpoint_recipes: - if checkpoint_recipe is not None: - ScheduledModifierManager.from_yaml(checkpoint_recipe).apply(self.model) - - # init current training recipe - if self.manager is not None: - self.manager.initialize( - self.model, - epoch=epoch, - distillation_teacher=self.teacher, - loggers=self.loggers, - ) - - # if model structure changed, load in new params from state dict - new_state_dict = self.model.state_dict() - new_params = [p for p in new_state_dict.keys() if p not in org_state_dict] - - if os.path.isdir(self.model_name_or_path): - if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)): - archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME) - state_dict = torch.load(archive_file, map_location="cpu") - new_params_to_init = [p for p in new_params if p in state_dict.keys()] - if new_params_to_init: - # parameters from dict are dependent on recipe - ( - _, - missing_keys, - unexpected_keys, - _, - ) = self.model._load_state_dict_into_model( - self.model, - state_dict, - self.model_name_or_path, - _fast_init=False, - ) - if missing_keys or unexpected_keys: - raise RuntimeError( - "Unexpected or missing keys detected when applying " - f"recipes to models\nMissing keys: {missing_keys}\n" - f"Unexpected keys: {unexpected_keys}\n" - ) - - def create_optimizer(self): - """ - Create optimizer customized using SparseML - """ - super().create_optimizer() - if not self.recipe or not self.manager: - return - total_batch_size = ( - self.args.per_device_train_batch_size - * self.args._n_gpu - * self.args.gradient_accumulation_steps - ) - steps_per_epoch = math.ceil(len(self.train_dataset) / total_batch_size) - if hasattr(self, "scaler"): - self.scaler = self.manager.modify( - self.model, - self.optimizer, - steps_per_epoch=steps_per_epoch, - wrap_optim=self.scaler, - ) - else: - self.optimizer = ScheduledOptimizer( - self.optimizer, - self.model, - self.manager, - steps_per_epoch=steps_per_epoch, - loggers=self.loggers, - ) - - def create_scheduler(self, num_training_steps: int): - """ - Override LR scheduler if the SparseML manager has LR modifiers, otherwise - set default scheduler - """ - if self.lr_scheduler is not None: - # scheduler already set - return - - if self.manager is not None and self.manager.learning_rate_modifiers: - # allow SparseML to manage LR and set a dummy scheduler - self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR( - self.optimizer, lambda _: 1.0, -1 - ) - else: - # default scheduler - super().create_scheduler(num_training_steps) - - def qat_active(self, epoch: int): - if self.manager is None or not self.manager.quantization_modifiers: - return False - - qat_start = min( - [mod.start_epoch for mod in self.manager.quantization_modifiers] - ) - - return qat_start < epoch + 1 - - def compute_loss(self, model, inputs, return_outputs=False): - """ - Computing loss using teacher/student distillation - """ - if not self.recipe or self.manager is None or self.teacher is None: - return super().compute_loss(model, inputs, return_outputs=return_outputs) - - student_outputs = model(**inputs) - loss = student_outputs["loss"] - - teacher_inputs = ( - inputs - if not self.teacher_input_keys - else {k: inputs[k] for k in self.teacher_input_keys} - ) - - steps_in_epoch = -1 # Unused - loss = 
self.manager.loss_update( - loss, - model, - self.optimizer, - self.state.epoch, - steps_in_epoch, - global_step=self.state.global_step, - student_outputs=student_outputs, - teacher_inputs=teacher_inputs, - ) - return (loss, student_outputs) if return_outputs else loss - - def save_model(self, output_dir: Optional[str] = None): - """ - Save model during or after training. Modifiers that change the model - architecture will also be saved - """ - super().save_model(output_dir=output_dir) - if self.manager is not None: - self._save_arch_modifiers(output_dir=output_dir) - - def _save_arch_modifiers(self, output_dir: Optional[str] = None): - """ - Save modifiers that change the model's architecture, which is to be applied - later on whenever the model is loaded - """ - if not self.manager: - return - - output_dir = output_dir if output_dir is not None else self.args.output_dir - output_recipe_file = os.path.join(output_dir, RECIPE_NAME) - saved_mods = [ - mod - for mod in self.manager.modifiers - if isinstance(mod, QuantizationModifier) - or isinstance(mod, LayerPruningModifier) - ] - if saved_mods and os.path.exists(output_recipe_file): - with open(output_recipe_file, "a") as yaml_file: - for mod in saved_mods: - yaml_file.write(str(mod) + "\n\n") - - -class DisableHalfPrecisionCallback(TrainerCallback): - """ - TrainerCallback for disabling FP16 training when QAT training begins - - :param sparseml_trainer: SparseML trainer that will call back into this object - :param args: args to be passed to base TrainerCallback - :param kwargs: key word arguments to be passed to base TrainerCallback - """ - - def __init__(self, sparseml_trainer: SparseMLTrainer, *args, **kwargs): - super().__init__(*args, **kwargs) - self._trainer = sparseml_trainer - - def on_epoch_begin( - self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - **kwargs, - ): - """ - Event called at the beginning of an epoch. 
Disables - """ - super().on_epoch_begin(args, state, control, **kwargs) - if ( - hasattr(self._trainer, "scaler") - and self._trainer.scaler._enabled - and (self._trainer.qat_active(state.epoch)) - ): - _LOGGER.info("entering QAT phase, disabling FP16 training") - self.scaler._enabled = False diff --git a/tests/sparseml/pytorch/optim/test_modifier.py b/tests/sparseml/pytorch/optim/test_modifier.py index 2ceb55c7dd0..c8486ef445a 100644 --- a/tests/sparseml/pytorch/optim/test_modifier.py +++ b/tests/sparseml/pytorch/optim/test_modifier.py @@ -67,8 +67,9 @@ def initialize_helper( model: Module = None, epoch: float = 0.0, log_initialize: bool = True, + **kwargs, ): - modifier.initialize(model, epoch) + modifier.initialize(model, epoch, **kwargs) if log_initialize: modifier.initialize_loggers([PythonLogger()]) @@ -400,6 +401,7 @@ def test_update_ready( optim_lambda: Callable[[Module], Optimizer], test_epoch: float, # noqa: F811 test_steps_per_epoch: int, # noqa: F811 + **initialize_kwargs, ): modifier = modifier_lambda() model = model_lambda() @@ -408,7 +410,7 @@ def test_update_ready( with pytest.raises(RuntimeError): modifier.update_ready(0.0, test_steps_per_epoch) - self.initialize_helper(modifier, model) + self.initialize_helper(modifier, model, **initialize_kwargs) modifier.enabled = False assert not modifier.update_ready(modifier.start_epoch, test_steps_per_epoch) modifier.enabled = True @@ -430,6 +432,7 @@ def test_scheduled_update( optim_lambda: Callable[[Module], Optimizer], test_epoch: float, # noqa: F811 test_steps_per_epoch: int, # noqa: F811 + **initialize_kwargs, ): modifier = modifier_lambda() model = model_lambda() @@ -438,7 +441,7 @@ def test_scheduled_update( with pytest.raises(RuntimeError): modifier.scheduled_update(model, optimizer, 0.0, test_steps_per_epoch) - self.initialize_helper(modifier, model) + self.initialize_helper(modifier, model, **initialize_kwargs) if modifier.start_epoch <= 0.0: modifier.scheduled_update(model, optimizer, 0.0, test_steps_per_epoch) diff --git a/tests/sparseml/pytorch/optim/test_modifier_distillation.py b/tests/sparseml/pytorch/optim/test_modifier_distillation.py index a9111360680..c88026cdb00 100644 --- a/tests/sparseml/pytorch/optim/test_modifier_distillation.py +++ b/tests/sparseml/pytorch/optim/test_modifier_distillation.py @@ -21,7 +21,7 @@ from torch.nn import Module from torch.optim import Optimizer -from sparseml.pytorch.optim import DistillationModifier, Modifier +from sparseml.pytorch.optim import DistillationModifier, Modifier, ScheduledModifier from tests.sparseml.pytorch.helpers import LinearNet, create_optim_sgd from tests.sparseml.pytorch.optim.test_modifier import ScheduledModifierTest @@ -38,6 +38,12 @@ ] +def _get_fake_batch(model_lambda): + batch_size = 5 + input_shape = model_lambda.layer_descs()[0].input_size + return torch.randn(batch_size, *input_shape) + + @pytest.mark.skipif( os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False), reason="Skipping pytorch tests", @@ -46,6 +52,40 @@ @pytest.mark.parametrize("model_lambda", [LinearNet], scope="function") @pytest.mark.parametrize("optim_lambda", [create_optim_sgd], scope="function") class TestDistillationModifierImpl(ScheduledModifierTest): + def test_update_ready( + self, + modifier_lambda: Callable[[], ScheduledModifier], + model_lambda: Callable[[], Module], + optim_lambda: Callable[[Module], Optimizer], + test_epoch: float, # noqa: F811 + test_steps_per_epoch: int, # noqa: F811 + ): + super().test_update_ready( + modifier_lambda, + model_lambda, + optim_lambda, + test_epoch, + 
test_steps_per_epoch, + distillation_teacher=model_lambda(), + ) + + def test_scheduled_update( + self, + modifier_lambda: Callable[[], ScheduledModifier], + model_lambda: Callable[[], Module], + optim_lambda: Callable[[Module], Optimizer], + test_epoch: float, # noqa: F811 + test_steps_per_epoch: int, # noqa: F811 + ): + super().test_scheduled_update( + modifier_lambda, + model_lambda, + optim_lambda, + test_epoch, + test_steps_per_epoch, + distillation_teacher=model_lambda(), + ) + def test_lifecycle( self, modifier_lambda, @@ -57,37 +97,16 @@ def test_lifecycle( model = model_lambda() optimizer = optim_lambda(model) - self.initialize_helper(modifier, model) + self.initialize_helper(modifier, model, distillation_teacher=model_lambda()) for epoch in range(int(modifier.start_epoch)): assert not modifier.update_ready(epoch, test_steps_per_epoch) assert modifier.update_ready(modifier.start_epoch, test_steps_per_epoch) - modifier.scheduled_update( model, optimizer, modifier.start_epoch, test_steps_per_epoch ) - # test distillation has been applied - # fake forward pass - student_inputs = self._get_fake_batch(model_lambda) - student_outputs = model(student_inputs) - teacher_outputs = student_outputs + 0.5 # fake teacher model's outputs - fake_loss = student_outputs.mean() - updated_loss = modifier.loss_update( - fake_loss, - model, - optimizer, - -1, - test_steps_per_epoch, - student_outputs, - teacher_outputs, - ) - - assert isinstance(updated_loss, torch.Tensor) - assert updated_loss.shape == fake_loss.shape - assert fake_loss.item() != updated_loss.item() - if modifier.end_epoch > modifier.start_epoch: assert not modifier.update_ready( (modifier.start_epoch + modifier.end_epoch) / 2, test_steps_per_epoch @@ -107,27 +126,26 @@ def test_loss_update( model = model_lambda() optimizer = optim_lambda(model) - self.initialize_helper(modifier, model) + self.initialize_helper(modifier, model, distillation_teacher=model_lambda()) - # run fake forward pass and try updating the loss - inputs = self._get_fake_batch(model_lambda) - student_outputs = model(inputs) - new_loss = modifier.loss_update( - test_loss, + # test distillation has been applied + # fake forward pass + student_inputs = _get_fake_batch(model_lambda) + student_outputs = model(student_inputs) + fake_loss = student_outputs.mean() + updated_loss = modifier.loss_update( + fake_loss, model, optimizer, - test_epoch, + modifier.start_epoch, test_steps_per_epoch, student_outputs, - inputs, + student_inputs, ) - assert isinstance(new_loss, Tensor) - - def _get_fake_batch(self, model_lambda): - batch_size = 5 - input_shape = model_lambda.layer_descs()[0].input_size - return torch.randn(batch_size, *input_shape) + assert isinstance(updated_loss, torch.Tensor) + assert updated_loss.shape == fake_loss.shape + assert fake_loss.item() != updated_loss.item() @pytest.mark.skipif(