Skip to content

Commit

Permalink
[Feature Branch] Quant modifier UX (#2263)
Browse files Browse the repository at this point in the history
* Split WandaPruningModifier and SparseGPTModifier
Make sparsegpt not inherit from wanda modifier
Decouple SparseGPTModifierPyTorch from WandaPruningModifier
Fix docstrings

* Split SparseGPT and GPTQ modifiers (#2272)

* Update OBCQ

* Extract GPTQ Modifier

* [GPTQ Modifier UX] Update tests to use GPTQModifier for obcq style quantization (#2294)

* Update OBCQ

* Extract GPTQ Modifier

* Update test recipes

* GPTQ UX config groups support (#2273)

* Update OBCQ

* Extract GPTQ Modifier

* Update test recipes

* Add config_groups support to GPTQModifier

* mask_structure preservation test (#2284)

* test

* Preserve weight sparsity if greater than threshold

* Add argument to preserve sparsity mask in SPARSEGPT

* fix case when mask is none

* Add test to check mask_structure
- initial mask structure should be preserved
b/w consecutive runs; added test to check this

* Update tensor_follows_mask_structure to check for atleast n zeros

---------

Co-authored-by: Sara Adkins <[email protected]>

* PR comments

---------

Co-authored-by: Sara Adkins <[email protected]>

* Fix default case

* Update test to use new vLLMQuantizationModifier

* Style

---------

Co-authored-by: Sara Adkins <[email protected]>
  • Loading branch information
rahul-tuli and Sara Adkins authored May 22, 2024
1 parent 53541f3 commit c24e97f
Show file tree
Hide file tree
Showing 35 changed files with 1,367 additions and 233 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ initial_sparsity_stage:
sparsity: 0.5
block_size: 128
sequential_update: False
quantize: False
percdamp: 0.01
mask_structure: "0:0"
targets: [
Expand All @@ -24,7 +23,6 @@ next_sparsity_stage:
sparsity: 0.7
block_size: 128
sequential_update: False
quantize: False
percdamp: 0.01
mask_structure: "0:0"
targets: [
Expand Down
154 changes: 79 additions & 75 deletions src/sparseml/modifiers/obcq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Any, Dict, List, Optional, Union
from typing import Dict, List, Optional, Union

from sparseml.core.factory import ModifierFactory
from sparseml.core import Modifier
from sparseml.core.model.base import ModifiableModel
from sparseml.core.state import State
from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier


__all__ = ["SparseGPTModifier"]

_LOGGER = logging.getLogger(__name__)


class SparseGPTModifier(WandaPruningModifier):
class SparseGPTModifier(Modifier):
"""
Modifier for applying the one-shot OBCQ algorithm to a model
Expand All @@ -41,84 +38,91 @@ class SparseGPTModifier(WandaPruningModifier):
- on_finalize
- LayerCompressor.revert_layer_wrappers()
:param block_size: Used to determine number of columns to compress in one pass
:param quantize: Whether or not to quantize weights during SparseGPT. Set to
True to quantize using an existing quantization modifier, or pass in the
configuration for a quantization modifier if one does not already exist
in the recipe
:param sparsity: Sparsity to compress model to
:param sparsity_profile: Can be set to 'owl' to use Outlier Weighed
Layerwise Sparsity (OWL), more information can be found
in the paper https://arxiv.org/pdf/2310.05175
:param owl_m: Number of outliers to use for OWL
:param owl_lmbda: Lambda value to use for OWL
:param mask_structure: String to define the structure of the mask to apply.
Must be of the form N:M where N, M are integers that define a custom block
shape. Defaults to 0:0 which represents an unstructured mask.
:param sequential_update: Whether or not to update weights sequentially by layer,
True saves on GPU memory
:param targets: list of layer names to compress during OBCQ, or '__ALL__'
to compress every layer in the model
:param block_size: Used to determine number of columns to compress in one pass
:param dampening_frac: Amount of dampening to apply to H, as a fraction of the
diagonal norm
:param preserve_sparsity_mask: Whether or not to preserve the sparsity mask
during when applying sparsegpt, this becomes useful when starting from a
previously pruned model, defaults to False.
"""

block_size: int = 128
quantize: Union[bool, Dict] = False
sparsity: Union[float, List[float]] = 0.0
sparsity_profile: Optional[str] = None
owl_m: Optional[int] = None
owl_lmbda: Optional[float] = None
mask_structure: str = "0:0"
sequential_update: Optional[bool] = False
targets: Union[str, List[str], None] = None
block_size: int = 128
dampening_frac: Optional[float] = 0.01
quantization_modifier_: Any = None
preserve_sparsity_mask: bool = False
prunen_: Optional[int] = None
prunem_: Optional[int] = None
compressible_layers_: Optional[List] = None

def on_initialize_structure(self, state: State, **kwargs):
"""
Check the model's quantization state matches that expected by this modifier,
adding a default quantization scheme if needed
Initialize the structure of the model for compression.
This modifier does not modifiy the model structure, so this method
is a no-op.
:param state: session state storing input model and calibration data
"""
return True

def compressible_layers(self) -> Dict:
"""
Retrieves the modules corresponding to a list of
compressible layer names
:precondition: self.model is set and is a `ModifiableModel`
:precondition: The `ModifiableModel` implements a `get_layers`
method
:return: dictionary of modules to compress
"""
if not isinstance(self.model, ModifiableModel):
raise ValueError(
"`self.model` must be a ModifiableModel to use "
f"the {self.__class__.__qualname__} modifier but got "
f"{type(self.model)} instead"
)

return self.model.get_layers(self.targets)

def _validate_layerwise_sparsity(self):
if isinstance(self.sparsity, float):
# single sparsity will be applied to all layers
return

target_layers = list(self.compressible_layers_.keys())

if len(target_layers) != len(self.sparsity):
raise ValueError(
"Number of layer targets must match the number of "
f"sparsities. Got {len(target_layers)} layers and "
f"{len(self.sparsity)} sparsities"
)

def on_finalize(self, state: State, **kwargs):
"""
Nothing to do on finalize, on this level.
Quantization Modifier if any will be finalized in the subclass
:param state: session state storing input model and calibration data
:param kwargs: additional arguments
:return: True
"""
quantization_already_active = state.model.qat_active()
if isinstance(self.quantize, bool):
if not self.quantize and quantization_already_active:
_LOGGER.warning(
"SparseGPT quantization is set to False, but a "
"quantization modifier is already active on the model "
"resetting quantize to True"
)
self.quantize = True
elif self.quantize and not quantization_already_active:
_LOGGER.warning(
"SparseGPT quantization is set to True without an "
"active quantization modifier. Creating a default "
"8-bit quantization modifier"
)
default_quant_config = {"QuantizationModifier": {}}
self._build_quant_modifier_from_dict(
default_quant_config, state.framework
)
return # use existing quantization modifier if there is one
else:
if not isinstance(self.quantize, Dict):
raise ValueError(
"SparseGPTModifier.quantize accepts only a single "
"quantization modifier or a boolean. Found "
f"type {type(self.quantize)}"
)
if len(self.quantize) != 1:
raise ValueError(
"SparseGPTModifier.quantize accepts only a single "
"quantization modifier or a boolean. Found "
f"{len(self.quantize)} modifiers"
)
if quantization_already_active:
_LOGGER.warning(
"Attempting to initialize quantization for SparseGPT "
"but a quantization modifier has already been applied. "
"The quantization configuration defined under the "
"SparseGPT modifier will be ignored."
)
self.quantize = True
return
self._build_quant_modifier_from_dict(self.quantize, state.framework)
self.quantize = True

if self.quantization_modifier_:
self.quantization_modifier_.on_initialize_structure(state, **kwargs)

def _build_quant_modifier_from_dict(self, quant_config, framework):
modifier_type = list(quant_config.keys())[0]
modifier_args = quant_config[modifier_type]
self.quantization_modifier_ = ModifierFactory.create(
modifier_type,
framework=framework,
allow_registered=True,
allow_experimental=True,
**modifier_args,
)
return True
Loading

0 comments on commit c24e97f

Please sign in to comment.