-
Notifications
You must be signed in to change notification settings - Fork 149
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor Quantization Modifer and Reloading (#2246)
* initial commit * update setup.py * Update setup.py * fix setup.py * move all config to sparsetensors * cleanup class name and comments * initial implementation untested * fixing issues * add test script * update perplexity test * refactor to compressed-tensors * rename sparsetensors * update setup * Sa/model reload (#2250) * working reload * sparsegpt * cleanup * refactor tests * only run oneshot once * all tests passing * remove unused config * reset models on each parameterize * style * bring back SparsityConfigMetadata * Update setup.py Co-authored-by: Rahul Tuli <[email protected]> * add more comparisons, tighten threshold * use wikitext for perplexity * update setup * fix import problem * fix clearml test * compressed-tensors are transformers dep * address PR comments * can't repeat freeze * UX pr comments * quality * shape consistency * address PR comments --------- Co-authored-by: dbogunowicz <[email protected]> Co-authored-by: dbogunowicz <[email protected]> Co-authored-by: Rahul Tuli <[email protected]> Co-authored-by: George Ohashi <[email protected]>
- Loading branch information
1 parent
1bad1fb
commit f7cb678
Showing
16 changed files
with
732 additions
and
24 deletions.
There are no files selected for viewing
15 changes: 15 additions & 0 deletions
15
integrations/huggingface-transformers/finetuning/example_single_gpu_config.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
compute_environment: LOCAL_MACHINE | ||
debug: false | ||
distributed_type: 'NO' | ||
enable_cpu_affinity: false | ||
gpu_ids: 0 | ||
machine_rank: 0 | ||
main_training_function: main | ||
num_machines: 1 | ||
num_processes: 1 | ||
rdzv_backend: static | ||
same_network: true | ||
tpu_env: [] | ||
tpu_use_cluster: false | ||
tpu_use_sudo: false | ||
use_cpu: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# flake8: noqa | ||
|
||
from .base import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Dict, List, Optional | ||
|
||
from pydantic import Field | ||
|
||
from compressed_tensors.quantization import ( | ||
QuantizationConfig, | ||
QuantizationScheme, | ||
QuantizationStatus, | ||
) | ||
from sparseml.core import Event, Modifier | ||
|
||
|
||
__all__ = ["vLLMQuantizationModifier"] | ||
|
||
|
||
class vLLMQuantizationModifier(Modifier): | ||
""" | ||
Enables post training quantization (PTQ) and quantization aware training (QAT) for a | ||
given module or its submodules. After calibration (PTQ) or the start epoch (QAT), | ||
the specified module(s) forward pass will emulate quantized execution and the | ||
modifier will be enabled until training is completed. | ||
:param config_groups: dictionary specifying quantization schemes to apply to target | ||
modules. Modules not matching a scheme target will NOT be quantized. | ||
:param ignore: optional list of module class names or submodule names to not | ||
quantize even if they match a target in config_groups. Defaults to empty list. | ||
:param disable_quantization_observer_epoch: Epoch to disable updates to the module | ||
quantization observers. At this point, quantized weights and zero points will | ||
not be updated. Leave None to not disable observers during QAT. Default is None | ||
:param num_calibration_steps: Number of steps to run post training calibration for. | ||
When None, the entire calibration_dataloader is used | ||
""" | ||
|
||
config_groups: Dict[str, QuantizationScheme] | ||
ignore: List[str] = Field(default_factory=list) | ||
disable_quantization_observer_epoch: Optional[float] = None | ||
num_calibration_steps: Optional[int] = None | ||
|
||
def create_init_config(self) -> QuantizationConfig: | ||
return QuantizationConfig( | ||
config_groups=self.config_groups, | ||
quantization_status=QuantizationStatus.INITIALIZED, | ||
ignore=self.ignore, | ||
) | ||
|
||
def calculate_disable_observer_epoch(self) -> float: | ||
""" | ||
Get the epoch at which we want to disable to quantization observer | ||
:return epoch to disable at, or -1 if it is not set | ||
""" | ||
return ( | ||
self.disable_quantization_observer_epoch | ||
if self.disable_quantization_observer_epoch is not None | ||
else -1 | ||
) | ||
|
||
def check_should_disable_observer(self, event: Event) -> bool: | ||
""" | ||
Given the current index, determine if we should disable the observer | ||
:param event: Event to get index from | ||
:return: True if observer should be disabled, False otherwise | ||
""" | ||
disable_epoch = self.calculate_disable_observer_epoch() | ||
if disable_epoch == -1: | ||
return False | ||
if event.current_index >= disable_epoch: | ||
return True | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import logging | ||
from typing import Any | ||
|
||
from torch.nn import Module | ||
|
||
from compressed_tensors.quantization import ( | ||
apply_quantization_config, | ||
freeze_module_quantization, | ||
set_module_for_calibration, | ||
) | ||
from sparseml.core import Event, EventType, State | ||
from sparseml.modifiers.quantization_vllm.base import vLLMQuantizationModifier | ||
from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward | ||
|
||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class vLLMQuantizationModifierPyTorch(vLLMQuantizationModifier): | ||
""" | ||
PyTorch specific implementation of vLLMQuantizationModifier | ||
Enables post training quantization (PTQ) and quantization aware training (QAT) for a | ||
given module or its submodules. After calibration (PTQ) or the start epoch (QAT), | ||
the specified module(s) forward pass will emulate quantized execution and the | ||
modifier will be enabled until training is completed. | ||
:param config_groups: dictionary specifying quantization schemes to apply to target | ||
modules. Modules not matching a scheme target will NOT be quantized. | ||
:param ignore: optional list of module class names or submodule names to not | ||
quantize even if they match a target in config_groups. Defaults to empty list. | ||
:param disable_quantization_observer_epoch: Epoch to disable updates to the module | ||
quantization observers. At this point, quantized weights and zero points will | ||
not be updated. Leave None to not disable observers during QAT. Default is None | ||
:param num_calibration_steps: Number of steps to run post training calibration for. | ||
When None, the entire calibration_dataloader is used | ||
""" | ||
|
||
calibration_dataloader_: Any = None | ||
calibration_function_: Any = None | ||
|
||
def on_initialize_structure(self, state: State, **kwargs): | ||
module = state.model.model | ||
self._apply_modifier_to_model(module) | ||
module.apply(freeze_module_quantization) | ||
|
||
def on_initialize(self, state: State, **kwargs) -> bool: | ||
if self.end and self.end != -1: | ||
raise ValueError( | ||
"end_epoch is disabled for QuantizationModifier and can only be set to" | ||
" -1 or None. Given {}".format(self.end) | ||
) | ||
|
||
self.calibration_dataloader_ = state.data.calib | ||
module = state.model.model | ||
|
||
# intialize quantization in appropriate modules | ||
self._apply_modifier_to_model(module) | ||
|
||
if self.calculate_start() == -1: # one-shot | ||
module.apply(set_module_for_calibration) | ||
self._calibrate_if_possible(module) | ||
module.apply(freeze_module_quantization) | ||
|
||
return True | ||
|
||
def on_finalize(self, state: State, **kwargs) -> bool: | ||
return True | ||
|
||
def on_start(self, state: State, event: Event, **kwargs): | ||
module = state.model.model | ||
module.apply(set_module_for_calibration) | ||
|
||
def on_update(self, state: State, event: Event, **kwargs): | ||
if event.type_ == EventType.BATCH_START: | ||
if self.check_should_disable_observer(event): | ||
module = state.model.model | ||
module.apply(freeze_module_quantization) | ||
|
||
def on_end(self, state: State, event: Event, **kwargs): | ||
module = state.model.model | ||
module.apply(freeze_module_quantization) | ||
|
||
def on_event(self, state: State, event: Event, **kwargs): | ||
pass | ||
|
||
def _apply_modifier_to_model(self, model: Module): | ||
modifier_as_config = self.create_init_config() | ||
apply_quantization_config(model, modifier_as_config) | ||
|
||
def _calibrate_if_possible(self, module: Module): | ||
if self.num_calibration_steps == 0 and self.calibration_dataloader_: | ||
_LOGGER.warning( | ||
f"num_calibration_steps is {self.num_calibration_steps}." | ||
f"Calibration data loader will not be used." | ||
) | ||
elif self.num_calibration_steps and not self.calibration_dataloader_: | ||
raise ValueError( | ||
f"num_calibration_steps is {self.num_calibration_steps}. " | ||
"Calibration data loader is not set. Pass a " | ||
"calibration_data_loader with initialize(...) method." | ||
) | ||
|
||
elif not self.calibration_dataloader_: | ||
return | ||
|
||
self._calibrate(module) | ||
|
||
def _calibrate(self, module: Module): | ||
class_name = self.__class__.__name__.replace("PyTorch", "") | ||
_LOGGER.info( | ||
f"Running {class_name} calibration with " | ||
f"{len(self.calibration_dataloader_)} samples..." | ||
) | ||
|
||
module_training = module.training | ||
module.eval() | ||
|
||
run_calibration_forward( | ||
module, | ||
self.calibration_dataloader_, | ||
self.num_calibration_steps, | ||
self.calibration_function_, | ||
) | ||
|
||
if module_training: | ||
module.train() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.