Skip to content

Commit

Permalink
refactor completed, x100 speedup
Browse files Browse the repository at this point in the history
  • Loading branch information
mklasby committed Oct 30, 2024
1 parent 730e56f commit 71be271
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 44 deletions.
4 changes: 2 additions & 2 deletions sparsimony/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
BaseScheduler,
ConstantScheduler,
CosineDecayScheduler,
StaticScheduler,
AlwaysTrueScheduler,
)
from sparsimony.dst.rigl import RigL
from sparsimony.dst.srigl import SRigL, NMSRigL
Expand Down Expand Up @@ -309,7 +309,7 @@ def srste(
**kwargs,
) -> SRSTESparsifier:
if scheduler is None:
scheduler = StaticScheduler()
scheduler = AlwaysTrueScheduler()
if distribution is None:
distribution = UniformNMDistribution(n=n, m=m)
return SRSTESparsifier(scheduler, distribution, n, m, decay)
4 changes: 4 additions & 0 deletions sparsimony/distributions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ def __call__(
) -> List[Dict[str, Any]]:
if sparsity in self._cache:
return self._cache_loader(sparsity, groups)
self._logger.debug(
f"Sparsity {sparsity} not found in distribution cache."
" Calculating..."
)
if not self.excluded_modules_in_param_count:
for layer_idx, layer_config in enumerate(groups):
if self._should_exclude(
Expand Down
12 changes: 8 additions & 4 deletions sparsimony/mask_calculators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

import torch

from sparsimony.utils import calculate_per_tile_n_ones, view_tensors_as # noqa
from sparsimony.utils import ( # noqa
calculate_per_tile_n_ones,
view_tensors_as,
timing,
)
from .scorers import ABCScorer


Expand Down Expand Up @@ -108,10 +112,10 @@ def _calculate_mask(
*args,
**kwargs,
) -> torch.Tensor:

n_ones_per_tile_target = calculate_per_tile_n_ones(mask, sparsity)
n_drop_per_tile = torch.tensor(
[tile.sum().item() - n_ones_per_tile_target for tile in mask],
dtype=torch.int,
n_drop_per_tile = (mask.sum(dim=-1) - n_ones_per_tile_target).to(
torch.int
)
if not self._verify_mask_update(
n, n_ones_per_tile_target, n_drop_per_tile
Expand Down
44 changes: 19 additions & 25 deletions sparsimony/parametrization/ste_parametrization.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Tuple, Any

import torch
import torch.nn as nn
from torch import autograd

from sparsimony.mask_calculators import MagnitudeScorer, NMPruner
from sparsimony.parametrization.fake_sparsity import FakeSparsity


class STE(autograd.Function):
Expand Down Expand Up @@ -46,20 +45,16 @@ def backward(ctx, grad_outputs: Tuple[torch.Tensor]) -> torch.Tensor:
)


class FakeSparsitySTE(nn.Module):
def __init__(self, n: int = 2, m: int = 4, *args, **kwargs):
super().__init__()
self.sparsity = 1 - n / m
class FakeSparsitySTE(FakeSparsity):
def __init__(
self, mask: torch.Tensor, n: int = 2, m: int = 4, *args, **kwargs
):
super().__init__(mask)
self.n = n
self.m = m

def forward(self, weights):
pruner = NMPruner(MagnitudeScorer, n=self.n, m=self.m)
mask = pruner.calculate_mask(
1 - (self.n / self.m),
torch.ones_like(weights, dtype=torch.bool),
values=weights,
)
self.mask = mask
return STE.apply(weights, mask)
return STE.apply(weights, self.mask)

def __name__(self):
return "FakeSparsitySTE"
Expand All @@ -69,24 +64,23 @@ def sparsity(self):
return 1 - (self.n / self.m)


class FakeSparsitySRSTE(nn.Module):
class FakeSparsitySRSTE(FakeSparsity):
def __init__(
self, n: int = 2, m: int = 4, decay: float = 2e-4, *args, **kwargs
self,
mask: torch.Tensor,
n: int = 2,
m: int = 4,
decay: float = 2e-4,
*args,
**kwargs
):
super().__init__()
super().__init__(mask)
self.n = n
self.m = m
self.decay = decay

def forward(self, weights):
pruner = NMPruner(MagnitudeScorer, n=self.n, m=self.m)
mask = pruner.calculate_mask(
1 - self.n / self.m,
torch.ones_like(weights, dtype=torch.bool),
values=weights,
)
self.mask = mask
return SRSTE.apply(weights, mask, self.decay)
return SRSTE.apply(weights, self.mask, self.decay)

def __name__(self):
return "FakeSparsitySRSTE"
Expand Down
86 changes: 74 additions & 12 deletions sparsimony/pruners/ste.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import math
import time
from typing import Optional, Dict, Any
import logging
Expand All @@ -11,9 +10,9 @@
from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier

from sparsimony.distributions.base import BaseDistribution
from sparsimony.utils import get_mask, get_parametrization
from sparsimony.utils import get_mask
from sparsimony.schedulers.base import BaseScheduler

from sparsimony.mask_calculators import MagnitudeScorer, NMPruner
from sparsimony.parametrization.ste_parametrization import (
FakeSparsitySRSTE,
FakeSparsitySTE,
Expand All @@ -36,16 +35,17 @@ def __init__(
self.m = m
self.sparsity = 1 - n / m
self.decay = decay
self._logger = logging.getLogger(__name__)
self.prepared_ = False
self._step_count = 0
self.pruner = NMPruner(MagnitudeScorer, n=self.n, m=self.m)
if defaults is None:
defaults = {}
ste_parametrization = (
FakeSparsitySTE if decay is None else FakeSparsitySRSTE
)
defaults["parametrization"] = ste_parametrization
super().__init__(defaults)
self._logger = logging.getLogger(__name__)
self.prepared_ = False
self._step_count = 0

# @overide
@torch.no_grad()
Expand All @@ -66,23 +66,44 @@ def prepare(

# @override
def step(self) -> bool:
self._logger.debug("SR-STE step() in prog...")
start = time.time()
_topo_updated = False
self._step_count += 1
prune_ratio = self.scheduler(self._step_count)
if prune_ratio is not None:
_topo_updated = True
self.distribution((1 - self.n / self.m), self.groups)
for config in self.groups:
self._update_mask(**config)
self.update_mask(**config)
self._logger.debug(
f"SR-STE step completd in {time.time() - start} seconds"
)
return _topo_updated

# @override
def update_mask(
self, module: nn.Module, tensor_name: str, sparsity: float, **kwargs
self,
module: nn.Module,
tensor_name: str,
sparsity: float,
tensor_fqn: str,
**kwargs,
):
parametrization = get_parametrization(module, tensor_name)
new_n = math.floor((1 - sparsity) * self.m)
parametrization.n = new_n
self._logger.debug(f"Updating mask for {tensor_fqn}...")
mask = get_mask(module, tensor_name)
# set all elements to active after optim step and reprune
mask.data = torch.ones_like(mask, dtype=torch.bool)
if sparsity == 0:
return
original_weights = getattr(
module.parametrizations, tensor_name
).original
mask.data = self.pruner.calculate_mask(
self.sparsity, mask, values=original_weights
)
self._assert_sparsity_level(mask, sparsity)
self._assert_structure(mask, tensor_fqn)

# @override
def _prepare(self, *args, **kwargs):
Expand All @@ -91,10 +112,13 @@ def _prepare(self, *args, **kwargs):
module = config["module"]
tensor_name = config["tensor_name"]
parametrization = config.get("parametrization")
mask = torch.ones_like(
getattr(module, tensor_name), dtype=torch.bool
)
register_parametrization(
module,
tensor_name,
parametrization(n=self.n, m=self.m, decay=self.decay),
parametrization(mask, n=self.n, m=self.m, decay=self.decay),
)

def calculate_global_sparsity(self):
Expand Down Expand Up @@ -141,3 +165,41 @@ def __str__(self) -> str:

def __repr__(self) -> str:
return self.__str__()

def _assert_sparsity_level(self, mask: torch.Tensor, sparsity_level: float):
n_ones = mask.count_nonzero()
target_n_ones = int(mask.numel() * (1 - sparsity_level))
# We ignore off-by-one errors as these will be due to floor ops
if n_ones != target_n_ones and abs(n_ones - target_n_ones) > 1:
# With very large mask tensors, we may have some precision errors
# with exact n_ones. Therefore, we simply log the warning instead of
# raising.
# Also naturally occurs in structured pruning
# TODO: For structured pruning we may wish to calculate
# actual_n_ones based on network topology
self._logger.warning(
f"n_ones actual {n_ones} != n_one target {target_n_ones}"
)

def _assert_structure(self, mask, fqn: str) -> None:
if self.n == 2 and self.m == 4:
if mask.shape[1] % 64 != 0:
self._logger.warning(
f"Mask shape is not a multiple of 64, this weight tensor "
"may not work with torch 2:4 semi-structured kernels!\n"
f"Mask shape: {mask.shape} found at {fqn}"
)
try:
mask_view = mask.view(-1, self.m)
except RuntimeError as e:
self._logger.error(f"fqn: {fqn}")
raise e
ones = torch.count_nonzero(mask_view, dim=-1)
if (ones != self.n).all():
self._logger.warning(
f"{fqn} mask is not {self.n}:{self.m} pruned! Ones Tensor:\n"
f"{ones}"
)
raise RuntimeError(
f"N:M Violation found: {ones.unique()} n's in layer {fqn}"
)
8 changes: 8 additions & 0 deletions sparsimony/schedulers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ def __call__(self, *args, **kwargs):
return None


class AlwaysTrueScheduler(BaseScheduler):
def __init__(self, *args, **kwargs):
return

def __call__(self, *args, **kwargs):
return True


class ConstantScheduler(BaseScheduler):

def __init__(
Expand Down
39 changes: 38 additions & 1 deletion sparsimony/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from math import prod
from contextlib import contextmanager # noqa
import math
from functools import wraps
from time import time


import torch
import torch.nn as nn
Expand All @@ -15,6 +18,26 @@
)

_logger = logging.getLogger(__name__)
global __view_tensors_as_warning_logged
__view_tensors_as_warning_logged = False


def timing(f):
_logger = logging.getLogger(__name__)

@wraps(f)
def wrap(*args, **kw):
ts = time()
result = f(*args, **kw)
te = time()
_logger.debug(f"func: {f.__name__} took {te-ts:.4f} sec")
# _logger.debug(
# "func:%r args:[%r, %r] took: %2.4f sec"
# % (f.__name__, args, kw, te - ts)
# )
return result

return wrap


def get_mask(
Expand Down Expand Up @@ -187,7 +210,21 @@ def wrapped_fn(*args, **kwargs) -> torch.Tensor:
out = out.view(-1)
indx = torch.argwhere(~torch.isnan(out))[:, 0]
out = out[indx]
return out.reshape(original_size)
try:
return out.view(original_size)
except Exception as e:
global __view_tensors_as_warning_logged
if not __view_tensors_as_warning_logged:
_logger.warning(
"Had to reshape output from view_tensor_as. Please "
"check if your input tensors to calculate_mask() are "
"contiguous. If so please report this on GitHub. "
"Exception:"
)
_logger.warning(e)
__view_tensors_as_warning_logged = True
pass
return out.reshape(original_size)

return wrapped_fn

Expand Down

0 comments on commit 71be271

Please sign in to comment.