Skip to content

Commit

Permalink
[cherry-pick] transformers refactor (#538)
Browse files Browse the repository at this point in the history
* Refactor of Transformers SparseML CLI and integrations (#536)

* Refactor of Transformers SparseML CLI and integrations

* Refactor export.py to use new pathways, fix make quality

* Update src/sparseml/optim/manager.py

Co-authored-by: Rahul Tuli <[email protected]>

* Update src/sparseml/transformers/utils/model.py

Co-authored-by: Rahul Tuli <[email protected]>

* fixes from review

* fixes from review and testing

* bug fixes and logging

* bug fixes for export and distillation

* review fixes, quality fixes, style fixes

* fix dependency issue

* fix distillation tests

* fix distillation tests

* fix distillation tests

* fill in docs and update style

* fix issue with distillation improperly updating students inputs

* fix quality

* Update src/sparseml/pytorch/optim/modifier_distillation.py

* add in better logging for missing and unexpected keys in model reload for transformers trainer

* fix logging for transformers export

Co-authored-by: Rahul Tuli <[email protected]>

* Fix model load bug and add logging to catch potential future issues (#537)

* Fix model load bug and add logging to catch potential future issues

* initial migration to generalize module sparsification information

* propagate ModuleSparsificationInfo

* report type of input tensors in export.py

* minor bug fixes

* ModuleSparsificationInfo docs

* export onnx bugfix

* bug fixes

* make style

* bug fix for quantization

* revert to use ScheduledOptimizer due to bug with torch LambdaLR

* remove language_modeling script

* add end model sparsification log

Co-authored-by: Benjamin <[email protected]>

Co-authored-by: Mark Kurtz <[email protected]>
Co-authored-by: Rahul Tuli <[email protected]>
  • Loading branch information
3 people authored Feb 2, 2022
1 parent d1b0622 commit b09c6d0
Show file tree
Hide file tree
Showing 43 changed files with 2,059 additions and 1,721 deletions.
58 changes: 34 additions & 24 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,22 @@

_dev_deps = [
"beautifulsoup4==4.9.3",
"black>=20.8b1",
"flake8>=3.8.3",
"isort>=5.7.0",
"black==21.5b2",
"flake8==3.9.2",
"isort==5.8.0",
"m2r2~=0.2.7",
"myst-parser~=0.14.0",
"rinohtype>=0.4.2",
"sphinx>=3.4.0",
"sphinx-copybutton>=0.3.0",
"sphinx-markdown-tables>=0.0.15",
"sphinx-multiversion==0.2.4",
"sphinx-pydantic>=0.1.0",
"sphinx-rtd-theme>=0.5.0",
"rinohtype~=0.4.2",
"sphinx~=3.5.0",
"sphinx-copybutton~=0.3.0",
"sphinx-markdown-tables~=0.0.15",
"sphinx-multiversion~=0.2.4",
"sphinx-pydantic~=0.1.0",
"sphinx-rtd-theme~=0.5.0",
"wheel>=0.36.2",
"pytest>=6.0.0",
"pytest-mock>=3.6.1",
"flaky>=3.0.0",
"pytest~=6.2.0",
"pytest-mock~=3.6.0",
"flaky~=3.7.0",
"sphinx-rtd-theme",
]

Expand Down Expand Up @@ -112,25 +112,35 @@ def _setup_extras() -> Dict:
}


_transformers_entry_point_template = (
"sparseml.transformers.train.{task}=sparseml.transformers.train.{task}:main"
)


def _setup_entry_points() -> Dict:
return {
entry_points = {
"console_scripts": [
# sparsification
"sparseml.benchmark=sparseml.benchmark.info:_main",
"sparseml.framework=sparseml.framework.info:_main",
"sparseml.sparsification=sparseml.sparsification.info:_main",
_transformers_entry_point_template.format(task="question_answering"),
_transformers_entry_point_template.format(task="text_classification"),
_transformers_entry_point_template.format(task="token_classification"),
_transformers_entry_point_template.format(task="language_modeling"),
"sparseml.transformers.export_onnx=sparseml.transformers.utils.export:main",
]
}

# transformers integration
for task in [
"question_answering",
"text_classification",
"token_classification",
]:
entry_points["console_scripts"].extend(
[
f"sparseml.transformers.{task}=sparseml.transformers.{task}:main",
f"sparseml.transformers.train.{task}=sparseml.transformers.{task}:main",
]
)

entry_points["console_scripts"].append(
"sparseml.transformers.export_onnx=sparseml.transformers.export:main"
)

return entry_points


def _setup_long_description() -> Tuple[str, str]:
return open("README.md", "r", encoding="utf-8").read(), "text/markdown"
Expand Down
7 changes: 4 additions & 3 deletions src/sparseml/keras/optim/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
Also handles loading modifiers from yaml files
"""

from typing import List, Union
from typing import Any, Dict, List, Optional, Union

from tensorflow import Tensor

from sparseml.keras.optim.modifier import Modifier, ScheduledModifier
from sparseml.keras.utils.compat import keras
from sparseml.keras.utils.logger import KerasLogger
from sparseml.optim import BaseManager, load_recipe_yaml_str
from sparseml.optim import BaseManager, load_recipe_yaml_str, parse_recipe_variables
from sparsezoo.objects import Recipe


Expand All @@ -41,7 +41,7 @@ class ScheduledModifierManager(BaseManager, Modifier):
def from_yaml(
file_path: Union[str, Recipe],
add_modifiers: List[Modifier] = None,
**recipe_variables,
recipe_variables: Optional[Union[Dict[str, Any], str]] = None,
):
"""
Convenience function used to create the manager of multiple modifiers from a
Expand All @@ -59,6 +59,7 @@ def from_yaml(
with (i.e. num_epochs, init_lr)
:return: ScheduledModifierManager() created from the recipe file
"""
recipe_variables = parse_recipe_variables(recipe_variables)
yaml_str = load_recipe_yaml_str(file_path, **recipe_variables)
modifiers = Modifier.load_list(yaml_str)
if add_modifiers:
Expand Down
60 changes: 59 additions & 1 deletion src/sparseml/optim/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
Helper functions for base Modifier and Manger utilities
"""

import json
import re
from typing import Any, Dict, Tuple, Union
from contextlib import suppress
from typing import Any, Dict, Optional, Tuple, Union

import yaml

Expand All @@ -32,6 +34,7 @@
"rewrite_recipe_yaml_string_with_classes",
"update_recipe_variables",
"evaluate_recipe_yaml_str_equations",
"parse_recipe_variables",
]


Expand Down Expand Up @@ -137,6 +140,61 @@ def rewrite_recipe_yaml_string_with_classes(recipe_contianer: Any) -> str:
return pattern.sub(r"!\g<class_name>", updated_yaml_str)


def parse_recipe_variables(
recipe_variables: Optional[Union[Dict[str, Any], str]] = None
) -> Dict[str, Any]:
"""
Parse input recipe_variables into a dictionary that can be used to overload
variables at the root of a recipe.
Supports dictionaries as well as parsing a string in either json or
csv key=value format
:param recipe_variables: the recipe_variables string or dictionary to parse
for variables used with overloading recipes
:return: the parsed recipe variables
"""
if not recipe_variables:
return {}

if isinstance(recipe_variables, Dict):
return recipe_variables

if not isinstance(recipe_variables, str):
raise ValueError(
f"recipe_args must be a string for parsing, given {recipe_variables}"
)

# assume json first, try and parse
with suppress(Exception):
recipe_variables = json.loads(recipe_variables)
return recipe_variables

# assume csv, and standardize to format key=val
orig_recipe_variables = recipe_variables
recipe_vars_str = recipe_variables.replace(":", "=")
recipe_variables = {}
for arg_val in recipe_vars_str.split(","):
vals = arg_val.split("=")
if len(vals) != 2:
raise ValueError(
"Improper key=val given in csv for recipe variables with value "
f"{arg_val} in {orig_recipe_variables}"
)
key = vals[0].strip()
if any(char in key for char in ["{", "!", "=", "}"]):
raise ValueError(
"Improper key given in csv for recipe variables with value "
f"{key} in {orig_recipe_variables}"
)
val = vals[1].strip()
with suppress(Exception):
# check if val should be a number, otherwise fall back on string
val = float(val)
recipe_variables[key] = val

return recipe_variables


def update_recipe_variables(recipe_yaml_str: str, variables: Dict[str, Any]) -> str:
"""
:param recipe_yaml_str: YAML string of a SparseML recipe
Expand Down
89 changes: 72 additions & 17 deletions src/sparseml/optim/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,8 @@
from functools import cmp_to_key
from typing import List

from sparseml.optim.modifier import (
BaseModifier,
BaseObject,
BaseScheduled,
ModifierProp,
)
from sparseml.optim.modifier import BaseModifier, BaseObject, ModifierProp
from sparseml.sparsification.types import SparsificationTypes
from sparseml.utils import clean_path, create_parent_dirs


Expand All @@ -42,7 +38,7 @@ class BaseManager(BaseObject):
:param modifiers: the modifiers to wrap
"""

def __init__(self, modifiers: List[BaseScheduled], **kwargs):
def __init__(self, modifiers: List[BaseModifier], **kwargs):
super().__init__(**kwargs)
# sort modifiers by when they start and end so that later modifiers
# can overwrite in a deterministic order such as when initializing
Expand All @@ -57,44 +53,88 @@ def __del__(self):
def __str__(self) -> str:
return "\n".join(self.to_string_lines())

def __eq__(self, compare: object) -> bool:
return str(self) == str(compare)

@ModifierProp(serializable=False)
def modifiers(self) -> List[BaseScheduled]:
def modifiers(self) -> List[BaseModifier]:
"""
:return: list of all SparseML modifiers in the managed recipe
"""
return self._modifiers

@ModifierProp(serializable=False)
def epoch_modifiers(self) -> List[BaseScheduled]:
def epoch_modifiers(self) -> List[BaseModifier]:
"""
:return: list of all SparseML modifiers in the managed recipe that modify the
epoch range
"""
return [mod for mod in self._modifiers if "Epoch" in str(type(mod))]
return [
mod
for mod in self._modifiers
if SparsificationTypes.epoch in mod.sparsification_types
]

@ModifierProp(serializable=False)
def learning_rate_modifiers(self) -> List[BaseScheduled]:
def learning_rate_modifiers(self) -> List[BaseModifier]:
"""
:return: list of all SparseML modifiers in the managed recipe that modify the
LearningRate schedule
"""
return [mod for mod in self._modifiers if "LearningRate" in str(type(mod))]
return [
mod
for mod in self._modifiers
if SparsificationTypes.learning_rate in mod.sparsification_types
]

@ModifierProp(serializable=False)
def pruning_modifiers(self) -> List[BaseScheduled]:
def pruning_modifiers(self) -> List[BaseModifier]:
"""
:return: list of all SparseML modifiers in the managed recipe that manage
model sparsity
"""
return [mod for mod in self._modifiers if "Pruning" in str(type(mod))]
return [
mod
for mod in self._modifiers
if SparsificationTypes.pruning in mod.sparsification_types
]

@ModifierProp(serializable=False)
def quantization_modifiers(self) -> List[BaseScheduled]:
def quantization_modifiers(self) -> List[BaseModifier]:
"""
:return: list of all SparseML modifiers in the managed recipe that manage
model quantization
"""
return [mod for mod in self._modifiers if "Quantization" in str(type(mod))]
return [
mod
for mod in self._modifiers
if SparsificationTypes.quantization in mod.sparsification_types
]

@ModifierProp(serializable=False)
def distillation_modifiers(self) -> List[BaseModifier]:
"""
:return: list of all SparseML modifiers in the managed recipe that manage
Distillation
"""
return [
mod
for mod in self._modifiers
if SparsificationTypes.distillation in mod.sparsification_types
]

@ModifierProp(serializable=False)
def structured_modifiers(self) -> List[BaseModifier]:
"""
:return: list of all SparseML modifiers in the managed recipe that manage
structure changes to a model such as layer pruning, fitler pruning,
and quantization
"""
return [
mod
for mod in self._modifiers
if SparsificationTypes.structured in mod.sparsification_types
]

@ModifierProp(serializable=False)
def min_epochs(self) -> int:
Expand Down Expand Up @@ -154,7 +194,7 @@ def to_string_lines(self) -> List[str]:

return yaml_str_lines

def modifiers_to_string_lines(self, modifiers: List[BaseScheduled]) -> List[str]:
def modifiers_to_string_lines(self, modifiers: List[BaseModifier]) -> List[str]:
"""
:param modifiers: the modifiers to convert into string / yaml representation
for within the manage
Expand All @@ -176,3 +216,18 @@ def modifiers_to_string_lines(self, modifiers: List[BaseScheduled]) -> List[str]
yaml_str_lines.append("")

return yaml_str_lines

def qat_active(self, epoch: float) -> bool:
"""
:param epoch: the epoch to check if quantization aware training will be
active during
:return: True if quantization aware training will be active at the start
of or within the given epoch, False otherwise
"""
quant_modifiers = self.quantization_modifiers

return (
min(mod.start_epoch for mod in quant_modifiers) < epoch + 1
if quant_modifiers
else False
)
8 changes: 8 additions & 0 deletions src/sparseml/optim/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import yaml

from sparseml.optim.helpers import evaluate_recipe_yaml_str_equations
from sparseml.sparsification.types import SparsificationTypes
from sparseml.utils import ALL_TOKEN, validate_str_iterable


Expand Down Expand Up @@ -466,6 +467,13 @@ def __repr__(self):
self.props(only_serializable=False, format_repr=True),
)

@ModifierProp(serializable=False)
def sparsification_types(self) -> List[SparsificationTypes]:
"""
:return: the sparsification types this modifier instance will apply
"""
return []

@ModifierProp(serializable=True)
def log_types(self) -> Union[None, str, List[str]]:
"""
Expand Down
Loading

0 comments on commit b09c6d0

Please sign in to comment.