magnitude pruner view to reshape and static sparsifier
cemuyuk committed Jul 17, 2024
1 parent ee0a18e commit 70d7039
Showing 4 changed files with 120 additions and 4 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -34,6 +34,7 @@
black = "^24.2.0"
pre-commit = "^3.6.2"
pytest = "^8.1.1"
pytest-cov = "^5.0.0"
ipykernel = "^6.29.5"


[tool.poetry.urls]
23 changes: 23 additions & 0 deletions sparsimony/api.py
@@ -10,6 +10,7 @@
)
from sparsimony.dst.rigl import RigL
from sparsimony.dst.set import SET
from sparsimony.dst.static import StaticMagnitudeSparsifier


def rigl(
@@ -80,3 +81,25 @@ def set(
        optimizer=optimizer,
        sparsity=sparsity,
    )

def static(
    optimizer: torch.optim.Optimizer,
    sparsity: float,
) -> StaticMagnitudeSparsifier:
    """Return StaticMagnitude sparsifier.

    Args:
        optimizer (torch.optim.Optimizer): Previously initialized optimizer
            for training. Used to override the dense gradient buffers for
            sparse weights.
        sparsity (float): Sparsity level to prune network to.

    Returns:
        StaticMagnitudeSparsifier: Initialized StaticMagnitude sparsifier.
    """
    return StaticMagnitudeSparsifier(
        optimizer=optimizer,
        distribution=UniformDistribution(),
        sparsity=sparsity,
        init_method="sparse_torch",
    )
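
For orientation, here is a minimal usage sketch of the new static() entry point. It is not part of this commit: it assumes the sparsifier follows the prepare()/step() workflow of torch.ao.pruning's BaseSparsifier, which it subclasses, and the model, tensor_fqn entries, and training loop are illustrative only.

# Usage sketch (illustrative, not part of this commit). Assumes the
# torch.ao.pruning BaseSparsifier prepare()/step() convention; the model,
# tensor_fqn entries, and loop below are made up for demonstration.
import torch
import torch.nn as nn

from sparsimony.api import static

model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

sparsifier = static(optimizer=optimizer, sparsity=0.9)
sparsifier.prepare(model, [{"tensor_fqn": "0.weight"}, {"tensor_fqn": "2.weight"}])

data, target = torch.randn(32, 784), torch.randint(0, 10, (32,))
for _ in range(2):  # tiny illustrative training loop
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(data), target)
    loss.backward()
    optimizer.step()
    sparsifier.step()  # static sparsity: masks stay fixed after prepare()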
92 changes: 92 additions & 0 deletions sparsimony/dst/static.py
@@ -0,0 +1,92 @@
import torch
from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier

from sparsimony.dst.base import DSTMixin
from sparsimony.distributions.base import BaseDistribution
from sparsimony.parametrization.fake_sparsity import FakeSparsity
from sparsimony.pruners.unstructured import UnstructuredMagnitudePruner
from sparsimony.utils import get_mask


# TODO: Double-check that the default init_method is appropriate for
# static sparsity.
class StaticMagnitudeSparsifier(DSTMixin, BaseSparsifier):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        distribution: BaseDistribution,
        sparsity: float,
        init_method: str = "sparse_torch",
    ):
        self.distribution = distribution
        self.sparsity = sparsity
        self.init_method = init_method
        defaults = dict(parametrization=FakeSparsity)
        # The optimizer is handed straight to the mixin/base initializer.
        super().__init__(optimizer=optimizer, defaults=defaults)

    def _initialize_masks(self):
        self._distribute_sparsity(self.sparsity)
        for config in self.groups:
            # Prune to target sparsity for this step
            mask = get_mask(config["module"], config["tensor_name"])
            original_weights = getattr(
                config["module"].parametrizations, config["tensor_name"]
            ).original

            print(f"Original weights shape: {original_weights.shape}")
            print(f"Mask shape: {mask.shape}")

            mask.data = UnstructuredMagnitudePruner.calculate_mask(
                config["sparsity"], mask, original_weights
            )
            print(f"Mask 1s after pruning: {mask.sum()}")
            self._assert_sparsity_level(mask.data, self.sparsity)

    def _step(self):
        self._step_count += 1
        # Static sparsity: the mask is never modified after initialization.

    def grow_mask(self):
        pass

    def prune_mask(self):
        pass

    def update_mask(self):
        pass

    def __str__(self) -> str:
        # TODO: Errors if sparsifier has not been prepared. Fix me
        def neuron_is_active(neuron):
            return neuron.any()

        global_sparsity = self.calculate_global_sparsity().item()
        layerwise_sparsity_target = []
        layerwise_sparsity_actual = []
        active_neurons = []
        total_neurons = []
        for config in self.groups:
            layerwise_sparsity_target.append(config["sparsity"])
            mask = get_mask(**config)
            layerwise_sparsity_actual.append(
                self.calculate_mask_sparsity(mask).item()
            )
            active_neurons.append(
                torch.vmap(neuron_is_active)(mask).sum().item()
            )
            total_neurons.append(len(mask))
        active_vs_total_neurons = []
        for a, t in zip(active_neurons, total_neurons):
            active_vs_total_neurons.append(f"{a}/{t}")
        # TODO: Should list ignored_layers from distribution
        return (
            f"{self.__class__.__name__}\n"
            f"Step No.: {self._step_count}\n"
            f"Distribution: {self.distribution.__class__.__name__}\n"
            f"Global Sparsity Target: {self.sparsity}\n"
            f"Global Sparsity Actual: {global_sparsity}\n"
            f"Layerwise Sparsity Targets: {layerwise_sparsity_target}\n"
            f"Layerwise Sparsity Actual: {layerwise_sparsity_actual}\n"
            f"Active/Total Neurons: {active_vs_total_neurons}"
        )
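
For reference, a self-contained sketch (not from the repository) of the one-shot magnitude criterion that _initialize_masks delegates to UnstructuredMagnitudePruner.calculate_mask: keep the largest-magnitude weights and zero out the rest. The helper name magnitude_mask and its signature are hypothetical.

# Standalone sketch of one-shot magnitude pruning (hypothetical helper,
# not part of sparsimony): drop the smallest-|w| entries to hit a target
# sparsity and return a {0, 1} mask.
import torch

def magnitude_mask(weights: torch.Tensor, sparsity: float) -> torch.Tensor:
    n_drop = int(weights.numel() * sparsity)
    mask = torch.ones_like(weights)
    if n_drop == 0:
        return mask
    # Indices of the n_drop smallest-magnitude weights over the flat tensor.
    _, drop_idx = torch.topk(weights.abs().reshape(-1), k=n_drop, largest=False)
    mask.reshape(-1)[drop_idx] = 0.0
    return mask

w = torch.randn(128, 64)
m = magnitude_mask(w, sparsity=0.9)
print((m == 0).float().mean())  # ~0.9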
8 changes: 4 additions & 4 deletions sparsimony/pruners/unstructured.py
@@ -55,16 +55,16 @@ def calculate_mask(
         mask: torch.Tensor,
         weights: torch.Tensor,
     ) -> torch.Tensor:
-        n_drop = int(mask.sum(dtype=torch.int) * prune_ratio)
+        n_drop = int(mask.sum() * prune_ratio)
         scores = torch.where(
             mask == 1, torch.abs(weights), torch.full_like(weights, np.inf)
         )
         if dist.is_initialized():
             dist.all_reduce(scores, dist.ReduceOp.AVG, async_op=False)
-        _, indices = torch.topk(scores.view(-1), k=n_drop, largest=False)
+        _, indices = torch.topk(scores.reshape(-1), k=n_drop, largest=False)
         mask = (
-            mask.view(-1)
-            .scatter(dim=0, index=indices, src=torch.zeros_like(mask.view(-1)))
+            mask.reshape(-1)
+            .scatter(dim=0, index=indices, src=torch.zeros_like(mask.reshape(-1)))
             .reshape(mask.shape)
         )
         return mask
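
The view-to-reshape change matters when the tensor being flattened is non-contiguous (for example, a transposed weight): Tensor.view requires stride-compatible memory and raises a RuntimeError, whereas Tensor.reshape falls back to a copy. A small standalone illustration, independent of this repository:

# Why reshape(-1) is safer than view(-1) for possibly non-contiguous tensors.
import torch

x = torch.randn(4, 8).t()      # transpose -> non-contiguous
print(x.is_contiguous())       # False

try:
    x.view(-1)                 # view needs stride-compatible memory
except RuntimeError as err:
    print(f"view failed: {err}")

flat = x.reshape(-1)           # reshape copies when a view is impossible
print(flat.shape)              # torch.Size([32])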
