Clean up global RigL sparsities to be exact (#2159)
* clean up rigl sparsities

* Remove flaky from test_modifier_pruning_rigl.py

---------

Co-authored-by: Alexandre Marques <[email protected]>
Co-authored-by: abhinavnmagic <[email protected]>
3 people authored Mar 20, 2024
1 parent dead8b5 commit 5107108
Showing 2 changed files with 53 additions and 47 deletions.
@@ -62,19 +62,27 @@ def cosine_schedule(t: float, t_max: float, init_value: float, end_value: float)
     )


-def threshold_fraction(tensor: Tensor, fraction: float) -> None:
+def threshold_fraction(
+    tensor: Tensor, fraction: float, set_to_max: bool = False
+) -> None:
     """
     A function returning the tensor with all but topk fraction
     elements set to 0.
     :param tensor: the input tensor
-    :param fraction: fraction of nonzero elements
+    :param fraction: fraction of zero elements
     """
-    lookup_idx = round(fraction * tensor.numel())
+    lookup_idx = round((1 - fraction) * tensor.numel())
     if lookup_idx == 0:
-        return tensor
-    threshold, _ = torch.kthvalue(tensor.reshape(-1), k=lookup_idx)
-    return torch.where(tensor > threshold, 1.0, 0.0)
+        return torch.zeros_like(tensor)
+    tensor_shape = tensor.shape
+    vals, idx = tensor.reshape(-1).topk(lookup_idx, largest=True)
+    topk = torch.zeros_like(tensor.reshape(-1))
+    if set_to_max:
+        topk[idx] = torch.finfo(tensor.dtype).max
+    else:
+        topk[idx] = vals
+    return topk.reshape(tensor_shape)


 @PyTorchModifierYAML()
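
For orientation, a minimal usage sketch of the reworked helper, assuming threshold_fraction from this module is in scope and using made-up values: it keeps the top 1 - fraction of elements, zeroes out the rest, and with set_to_max=True replaces the survivors with the dtype's maximum so that they dominate any score added to them later.

import torch

scores = torch.tensor([[0.1, 0.9], [0.4, 0.2]])

# fraction is now the fraction of elements to zero out:
# round((1 - 0.5) * 4) = 2 entries survive, the two largest.
kept = threshold_fraction(scores, fraction=0.5)
# tensor([[0.0000, 0.9000],
#         [0.4000, 0.0000]])

# With set_to_max=True the surviving entries become the largest finite float,
# so they always win when this score is later summed with another score.
kept_max = threshold_fraction(scores, fraction=0.5, set_to_max=True)
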
@@ -104,8 +112,8 @@ class RigLPruningModifier(BaseGradualPruningModifier):
     | update_frequency: 4.0
     | num_grads: 100
     | params: ["re:.*weight"]
+    | global_sparsity: False
     | leave_enabled: True
-    | global_sparsity: True
     | mask_type: unstructured
     | sparsity_strategy: "erdos_renyi_kernel"
     | init_update_fraction: 0.3
@@ -124,6 +132,8 @@ class RigLPruningModifier(BaseGradualPruningModifier):
         will match to all parameters. __ALL_PRUNABLE__ will match to all ConvNd
         and Linear layers' weights. If a sparsity to param mapping is defined by
         final_sparsity, then params should be set to []
+    :param global_sparsity: set True to enable global pruning. If False, pruning will
+        be layer-wise. Must be set to False, as global sparsity is not supported yet.
     :param momentum_buffer_reset: set True to reset momentum buffer
         for pruned weights at every optimizer step, so that reintroduced
         weights start with an empty momentum buffer.
@@ -134,8 +144,6 @@ class RigLPruningModifier(BaseGradualPruningModifier):
         RigL modifier supports only 'unstructured'
     :param num_grads: Number of grads to be collected by the grad sampler for
         recomputing the mask.
-    :param global_sparsity: set True to enable global pruning. If False, pruning will
-        be layer-wise. Default is True
     :param sparsity_strategy: String to define the sparsity distribution. Following
         the original paper one can select one of the 3 options:
         [uniform, erdos_renyi, erdos_renyi_kernel].
@@ -157,28 +165,30 @@ def __init__(
         params: Union[str, List[str]],
         num_grads: int = 1,
         leave_enabled: bool = True,
+        global_sparsity: bool = False,
         momentum_buffer_reset: bool = True,
-        global_sparsity: bool = True,
         mask_type: str = "unstructured",
         sparsity_strategy: str = "erdos_renyi_kernel",
         init_update_fraction: float = 0.3,
         grad_sampler_kwargs: Optional[Dict[str, Any]] = {},
         **kwargs,
     ):
+        self._sparsity_strategy = sparsity_strategy
         super().__init__(
             params=params,
             final_sparsity=final_sparsity,
             init_sparsity=final_sparsity,
             start_epoch=start_epoch,
             end_epoch=end_epoch,
-            global_sparsity=global_sparsity,
+            global_sparsity=False,
             update_frequency=update_frequency,
             leave_enabled=leave_enabled,
             parent_class_kwarg_names=[],
             **kwargs,
         )
+        # self._sparsity_distribution = self._scorer.get_sparsity_distribution()
         self._mask_type = mask_type
-        self._sparsity_strategy = sparsity_strategy
+        # self._sparsity_strategy = sparsity_strategy
         self._momentum_buffer_reset = momentum_buffer_reset
         self._init_update_fraction = init_update_fraction
         self._grad_sampler_kwargs = grad_sampler_kwargs
@@ -201,14 +211,13 @@ def _validate(self):
         ), f"{self._mask_type} mask_type not supported"

         if self._global_sparsity:
+            raise ValueError("global sparsity is not supported for RigL.")
+        else:
             assert self._sparsity_strategy in (
                 "erdos_renyi",
                 "erdos_renyi_kernel",
-            ), "Global sparsity supports only `erdos_renyi`, `erdos_renyi_kernel`"
-        else:
-            assert (
-                self._sparsity_strategy == "uniform"
-            ), "Uniform sparsity supports only `uniform`"
+                "uniform",
+            ), "erdos_renyi, erdos_renyi_kernel, and uniform sparsity are supported."

     # Override te optimizer_post_step method to reset the momentum of pruned weights.
     # TODO: this implementation has some dependencies that may be better handled
@@ -331,8 +340,10 @@ def get_applied_sparsity_for_epoch(
         :param steps_per_epoch: number of steps in each epoch
         :return: sparsity level that should be applied (always final_sparsity)
         """
-        _LOGGER.info(f"RigL applied sparsity {self._final_sparsity}")
-        return [self._final_sparsity for _ in range(len(self.module_masks.layers))]
+
+        self._sparsity_distribution = self._scorer.get_sparsity_distribution()
+        _LOGGER.info(f"RigL applied sparsity {self._sparsity_distribution}")
+        return self._sparsity_distribution

     def initialize(
         self,
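
get_applied_sparsity_for_epoch now returns the scorer's per-layer distribution instead of repeating final_sparsity for every layer. As a rough standalone sketch of the idea behind the erdos_renyi_kernel strategy, and not the scorer's actual implementation, per-layer density can be made proportional to sum(dims) / prod(dims) and then rescaled so that the parameter-weighted average sparsity hits the global target exactly; erdos_renyi_kernel_sparsities below is a hypothetical helper written only for illustration.

import numpy as np


def erdos_renyi_kernel_sparsities(shapes, target_sparsity):
    # Illustrative ERK-style allocation: density_l proportional to
    # sum(dims) / prod(dims), rescaled so the parameter-weighted average
    # sparsity equals the target; layers whose density would exceed 1.0
    # are clipped to dense and the remaining budget is redistributed.
    sizes = np.array([np.prod(s) for s in shapes], dtype=float)
    raw = np.array([np.sum(s) / np.prod(s) for s in shapes], dtype=float)
    dense = np.zeros(len(shapes), dtype=bool)
    target_nonzero = (1.0 - target_sparsity) * sizes.sum()
    while True:
        free = ~dense
        scale = (target_nonzero - sizes[dense].sum()) / (raw[free] * sizes[free]).sum()
        density = np.where(dense, 1.0, scale * raw)
        if (density <= 1.0 + 1e-12).all():
            return 1.0 - np.clip(density, 0.0, 1.0)
        dense |= density > 1.0


# Two conv layers and a small classifier head; the head stays denser.
shapes = [(64, 3, 3, 3), (128, 64, 3, 3), (10, 128)]
per_layer = erdos_renyi_kernel_sparsities(shapes, target_sparsity=0.9)
sizes = np.array([np.prod(s) for s in shapes], dtype=float)
# Parameter-weighted density comes out at the 0.1 target.
print(per_layer, ((1.0 - per_layer) * sizes).sum() / sizes.sum())
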
@@ -563,13 +574,12 @@ def get_param_score(self, param: Tensor, param_grad: Tensor, param_sparsity: flo
         # Of the existing mask, we keep the top 1-param_sparsity-mask_update_fraction
         # elements by magnitude.
         magn_score = threshold_fraction(
-            magn_score, param_sparsity + mask_update_fraction
+            magn_score, (param_sparsity + mask_update_fraction), set_to_max=True
         )
-        # We fill in the mask by also adding the mask_update_fraction weights
-        # of the ones that are currently masked, using the gradient magnitude
-        # as the criterion.
-        grad_score = param_grad.abs() * magn_score.eq(0)
-        grad_score = threshold_fraction(grad_score, 1 - mask_update_fraction)
+        # For the rest of the unmasked weights, we use the gradient magnitude
+        # as the criterion. These cannot be larger than the magnitude scores,
+        # since those are set to the largest possible value.
+        grad_score = (param_grad.abs()) * magn_score.eq(0)
         score = magn_score + grad_score
         return score

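
To make the set_to_max trick in get_param_score concrete, here is a toy rerun of the scoring logic, assuming threshold_fraction from this module is in scope; the tensors, the update fraction, and the final top-k step are made up for illustration and are not the modifier's actual mask-update code path. Surviving large-magnitude weights are scored at the float maximum, currently masked weights compete on gradient magnitude, and a single threshold of the combined score yields the new connectivity.

import torch

param = torch.tensor([0.00, -0.80, 0.05, 0.30, 0.00, -0.02])
param_grad = torch.tensor([0.90, 0.10, 0.01, 0.05, 0.20, 0.03])
param_sparsity, mask_update_fraction = 0.5, 1 / 6

# Keep the top 1 - 0.5 - 1/6 = 2 of 6 weights by magnitude, scored at float max.
magn_score = threshold_fraction(
    param.abs(), param_sparsity + mask_update_fraction, set_to_max=True
)
# Zero-scored positions (currently pruned or just dropped) compete on gradients.
grad_score = param_grad.abs() * magn_score.eq(0)
score = magn_score + grad_score
# A final top-k at the target sparsity regrows the weight with the largest gradient.
new_mask = threshold_fraction(score, param_sparsity).ne(0).float()
# new_mask -> tensor([1., 1., 0., 1., 0., 0.])
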
test_modifier_pruning_rigl.py
@@ -14,6 +14,7 @@
 import os
 from typing import Any, Dict, Optional

+import numpy as np
 import pytest
 import torch
 from torch.utils.data import DataLoader
@@ -75,7 +76,6 @@ def dataloader_builder(**kwargs: Optional[Dict[str, Any]]):
     [
         lambda: RigLPruningModifier(
             final_sparsity=0.9,
-            global_sparsity=True,
             sparsity_strategy="erdos_renyi_kernel",
             start_epoch=2.0,
             end_epoch=5.0,
@@ -87,7 +87,6 @@ def dataloader_builder(**kwargs: Optional[Dict[str, Any]]):
         ),
         lambda: RigLPruningModifier(
             final_sparsity=0.7,
-            global_sparsity=False,
             sparsity_strategy="uniform",
             start_epoch=2.0,
             end_epoch=5.0,
@@ -101,7 +100,6 @@ def dataloader_builder(**kwargs: Optional[Dict[str, Any]]):
         lambda: RigLPruningModifier(
             params=["seq.fc1.weight", "seq.fc2.weight"],
             final_sparsity=0.5,
-            global_sparsity=True,
             sparsity_strategy="erdos_renyi",
             momentum_buffer_reset=False,
             start_epoch=2.0,
@@ -176,15 +174,22 @@ def test_lifecycle(
         if not isinstance(applied_sparsities, list):
             applied_sparsities = [applied_sparsities]

-        if not isinstance(modifier.init_sparsity, str):
+        if (
+            not isinstance(modifier.init_sparsity, str)
+        ) and modifier._sparsity_strategy == "uniform":
             assert all(
                 applied_sparsity == modifier.init_sparsity
                 for applied_sparsity in applied_sparsities
             )
         else:
-            assert len(modifier._init_sparsity) == len(modifier.module_masks.layers)
+            total_zeroes = 0
+            total_params = 0
             for idx, param in enumerate(modifier.module_masks.params_data):
-                assert modifier._init_sparsity[idx] == tensor_sparsity(param).item()
+                total_zeroes += tensor_sparsity(param).item() * param.numel()
+                total_params += param.numel()
+            assert np.isclose(
+                modifier._init_sparsity, total_zeroes / total_params, atol=1e-4
+            )

         last_sparsities = applied_sparsities

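
The lifecycle check above now verifies the parameter-weighted average sparsity rather than a per-layer match against init_sparsity. The same bookkeeping can be restated outside the test harness in a few lines; overall_sparsity is a hypothetical helper, not part of the test suite.

import torch


def overall_sparsity(tensors):
    # Parameter-weighted average sparsity: total zeros over total elements.
    zeros = sum((t == 0).sum().item() for t in tensors)
    total = sum(t.numel() for t in tensors)
    return zeros / total


masked = [torch.zeros(100), torch.ones(300)]
assert abs(overall_sparsity(masked) - 0.25) < 1e-6  # 100 zeros out of 400
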
@@ -217,14 +222,13 @@ def test_lifecycle(
             modifier.scheduled_update(model, optimizer, epoch, test_steps_per_epoch)

         def _test_final_sparsity_applied():
-            final_sparsities = (
-                [modifier.final_sparsity]
-                if isinstance(modifier.final_sparsity, float)
-                else modifier.final_sparsity
-            )
-            assert all(
-                sparsity in final_sparsities for sparsity in modifier.applied_sparsity
-            )
+            total_zeroes = 0
+            total_params = 0
+            for idx, param in enumerate(modifier.module_masks.params_data):
+                total_zeroes += tensor_sparsity(param).item() * param.numel()
+                total_params += param.numel()
+            # RigL can induce additional sparsity from repeated training.
+            assert total_zeroes / total_params >= modifier.final_sparsity

         _test_final_sparsity_applied()

@@ -293,7 +297,6 @@ def test_rigl_pruning_yaml(params, init_sparsity, final_sparsity):
     start_epoch = 5.0
     end_epoch = 15.0
     update_frequency = 1.0
-    global_sparsity = True
     momentum_buffer_reset = False
     sparsity_strategy = "erdos_renyi"
     num_grads = 64
@@ -307,7 +310,6 @@ def test_rigl_pruning_yaml(params, init_sparsity, final_sparsity):
         update_frequency: {update_frequency}
         params: {params}
         momentum_buffer_reset: {momentum_buffer_reset}
-        global_sparsity: {global_sparsity}
         sparsity_strategy: {sparsity_strategy}
         mask_type: {mask_type}
         num_grads: {num_grads}
@@ -324,7 +326,6 @@ def test_rigl_pruning_yaml(params, init_sparsity, final_sparsity):
         end_epoch=end_epoch,
         update_frequency=update_frequency,
         params=params,
-        global_sparsity=global_sparsity,
         sparsity_strategy=sparsity_strategy,
         momentum_buffer_reset=momentum_buffer_reset,
         mask_type=mask_type,
@@ -342,11 +343,6 @@ def test_rigl_pruning_yaml(params, init_sparsity, final_sparsity):
         == str(serialized_modifier.final_sparsity)
         == str(obj_modifier.final_sparsity)
     )
-    assert (
-        str(yaml_modifier.global_sparsity)
-        == str(serialized_modifier.global_sparsity)
-        == str(obj_modifier.global_sparsity)
-    )
     assert (
         str(yaml_modifier.sparsity_strategy)
         == str(serialized_modifier.sparsity_strategy)
         == str(obj_modifier.sparsity_strategy)
