diff --git a/sparsimony/api.py b/sparsimony/api.py index 971852b..f5a6240 100644 --- a/sparsimony/api.py +++ b/sparsimony/api.py @@ -12,7 +12,7 @@ BaseScheduler, ConstantScheduler, CosineDecayScheduler, - StaticScheduler, + AlwaysTrueScheduler, ) from sparsimony.dst.rigl import RigL from sparsimony.dst.srigl import SRigL, NMSRigL @@ -309,7 +309,7 @@ def srste( **kwargs, ) -> SRSTESparsifier: if scheduler is None: - scheduler = StaticScheduler() + scheduler = AlwaysTrueScheduler() if distribution is None: distribution = UniformNMDistribution(n=n, m=m) return SRSTESparsifier(scheduler, distribution, n, m, decay) diff --git a/sparsimony/distributions/base.py b/sparsimony/distributions/base.py index 00b1e90..e518d5a 100644 --- a/sparsimony/distributions/base.py +++ b/sparsimony/distributions/base.py @@ -141,6 +141,10 @@ def __call__( ) -> List[Dict[str, Any]]: if sparsity in self._cache: return self._cache_loader(sparsity, groups) + self._logger.debug( + f"Sparsity {sparsity} not found in distribution cache." + " Calculating..." + ) if not self.excluded_modules_in_param_count: for layer_idx, layer_config in enumerate(groups): if self._should_exclude( diff --git a/sparsimony/mask_calculators/base.py b/sparsimony/mask_calculators/base.py index 3f075da..11ffe69 100644 --- a/sparsimony/mask_calculators/base.py +++ b/sparsimony/mask_calculators/base.py @@ -5,7 +5,11 @@ import torch -from sparsimony.utils import calculate_per_tile_n_ones, view_tensors_as # noqa +from sparsimony.utils import ( # noqa + calculate_per_tile_n_ones, + view_tensors_as, + timing, +) from .scorers import ABCScorer @@ -108,10 +112,10 @@ def _calculate_mask( *args, **kwargs, ) -> torch.Tensor: + n_ones_per_tile_target = calculate_per_tile_n_ones(mask, sparsity) - n_drop_per_tile = torch.tensor( - [tile.sum().item() - n_ones_per_tile_target for tile in mask], - dtype=torch.int, + n_drop_per_tile = (mask.sum(dim=-1) - n_ones_per_tile_target).to( + torch.int ) if not self._verify_mask_update( n, n_ones_per_tile_target, n_drop_per_tile diff --git a/sparsimony/parametrization/ste_parametrization.py b/sparsimony/parametrization/ste_parametrization.py index c6bf6e0..233041a 100644 --- a/sparsimony/parametrization/ste_parametrization.py +++ b/sparsimony/parametrization/ste_parametrization.py @@ -1,10 +1,9 @@ from typing import Tuple, Any import torch -import torch.nn as nn from torch import autograd -from sparsimony.mask_calculators import MagnitudeScorer, NMPruner +from sparsimony.parametrization.fake_sparsity import FakeSparsity class STE(autograd.Function): @@ -46,20 +45,16 @@ def backward(ctx, grad_outputs: Tuple[torch.Tensor]) -> torch.Tensor: ) -class FakeSparsitySTE(nn.Module): - def __init__(self, n: int = 2, m: int = 4, *args, **kwargs): - super().__init__() - self.sparsity = 1 - n / m +class FakeSparsitySTE(FakeSparsity): + def __init__( + self, mask: torch.Tensor, n: int = 2, m: int = 4, *args, **kwargs + ): + super().__init__(mask) + self.n = n + self.m = m def forward(self, weights): - pruner = NMPruner(MagnitudeScorer, n=self.n, m=self.m) - mask = pruner.calculate_mask( - 1 - (self.n / self.m), - torch.ones_like(weights, dtype=torch.bool), - values=weights, - ) - self.mask = mask - return STE.apply(weights, mask) + return STE.apply(weights, self.mask) def __name__(self): return "FakeSparsitySTE" @@ -69,24 +64,23 @@ def sparsity(self): return 1 - (self.n / self.m) -class FakeSparsitySRSTE(nn.Module): +class FakeSparsitySRSTE(FakeSparsity): def __init__( - self, n: int = 2, m: int = 4, decay: float = 2e-4, *args, **kwargs + self, + mask: torch.Tensor, + n: int = 2, + m: int = 4, + decay: float = 2e-4, + *args, + **kwargs ): - super().__init__() + super().__init__(mask) self.n = n self.m = m self.decay = decay def forward(self, weights): - pruner = NMPruner(MagnitudeScorer, n=self.n, m=self.m) - mask = pruner.calculate_mask( - 1 - self.n / self.m, - torch.ones_like(weights, dtype=torch.bool), - values=weights, - ) - self.mask = mask - return SRSTE.apply(weights, mask, self.decay) + return SRSTE.apply(weights, self.mask, self.decay) def __name__(self): return "FakeSparsitySRSTE" diff --git a/sparsimony/pruners/ste.py b/sparsimony/pruners/ste.py index 688a0d0..ee243bd 100644 --- a/sparsimony/pruners/ste.py +++ b/sparsimony/pruners/ste.py @@ -1,4 +1,3 @@ -import math import time from typing import Optional, Dict, Any import logging @@ -11,9 +10,9 @@ from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier from sparsimony.distributions.base import BaseDistribution -from sparsimony.utils import get_mask, get_parametrization +from sparsimony.utils import get_mask from sparsimony.schedulers.base import BaseScheduler - +from sparsimony.mask_calculators import MagnitudeScorer, NMPruner from sparsimony.parametrization.ste_parametrization import ( FakeSparsitySRSTE, FakeSparsitySTE, @@ -36,6 +35,10 @@ def __init__( self.m = m self.sparsity = 1 - n / m self.decay = decay + self._logger = logging.getLogger(__name__) + self.prepared_ = False + self._step_count = 0 + self.pruner = NMPruner(MagnitudeScorer, n=self.n, m=self.m) if defaults is None: defaults = {} ste_parametrization = ( @@ -43,9 +46,6 @@ def __init__( ) defaults["parametrization"] = ste_parametrization super().__init__(defaults) - self._logger = logging.getLogger(__name__) - self.prepared_ = False - self._step_count = 0 # @overide @torch.no_grad() @@ -66,6 +66,8 @@ def prepare( # @override def step(self) -> bool: + self._logger.debug("SR-STE step() in prog...") + start = time.time() _topo_updated = False self._step_count += 1 prune_ratio = self.scheduler(self._step_count) @@ -73,16 +75,35 @@ def step(self) -> bool: _topo_updated = True self.distribution((1 - self.n / self.m), self.groups) for config in self.groups: - self._update_mask(**config) + self.update_mask(**config) + self._logger.debug( + f"SR-STE step completd in {time.time() - start} seconds" + ) return _topo_updated # @override def update_mask( - self, module: nn.Module, tensor_name: str, sparsity: float, **kwargs + self, + module: nn.Module, + tensor_name: str, + sparsity: float, + tensor_fqn: str, + **kwargs, ): - parametrization = get_parametrization(module, tensor_name) - new_n = math.floor((1 - sparsity) * self.m) - parametrization.n = new_n + self._logger.debug(f"Updating mask for {tensor_fqn}...") + mask = get_mask(module, tensor_name) + # set all elements to active after optim step and reprune + mask.data = torch.ones_like(mask, dtype=torch.bool) + if sparsity == 0: + return + original_weights = getattr( + module.parametrizations, tensor_name + ).original + mask.data = self.pruner.calculate_mask( + self.sparsity, mask, values=original_weights + ) + self._assert_sparsity_level(mask, sparsity) + self._assert_structure(mask, tensor_fqn) # @override def _prepare(self, *args, **kwargs): @@ -91,10 +112,13 @@ def _prepare(self, *args, **kwargs): module = config["module"] tensor_name = config["tensor_name"] parametrization = config.get("parametrization") + mask = torch.ones_like( + getattr(module, tensor_name), dtype=torch.bool + ) register_parametrization( module, tensor_name, - parametrization(n=self.n, m=self.m, decay=self.decay), + parametrization(mask, n=self.n, m=self.m, decay=self.decay), ) def calculate_global_sparsity(self): @@ -141,3 +165,41 @@ def __str__(self) -> str: def __repr__(self) -> str: return self.__str__() + + def _assert_sparsity_level(self, mask: torch.Tensor, sparsity_level: float): + n_ones = mask.count_nonzero() + target_n_ones = int(mask.numel() * (1 - sparsity_level)) + # We ignore off-by-one errors as these will be due to floor ops + if n_ones != target_n_ones and abs(n_ones - target_n_ones) > 1: + # With very large mask tensors, we may have some precision errors + # with exact n_ones. Therefore, we simply log the warning instead of + # raising. + # Also naturally occurs in structured pruning + # TODO: For structured pruning we may wish to calculate + # actual_n_ones based on network topology + self._logger.warning( + f"n_ones actual {n_ones} != n_one target {target_n_ones}" + ) + + def _assert_structure(self, mask, fqn: str) -> None: + if self.n == 2 and self.m == 4: + if mask.shape[1] % 64 != 0: + self._logger.warning( + f"Mask shape is not a multiple of 64, this weight tensor " + "may not work with torch 2:4 semi-structured kernels!\n" + f"Mask shape: {mask.shape} found at {fqn}" + ) + try: + mask_view = mask.view(-1, self.m) + except RuntimeError as e: + self._logger.error(f"fqn: {fqn}") + raise e + ones = torch.count_nonzero(mask_view, dim=-1) + if (ones != self.n).all(): + self._logger.warning( + f"{fqn} mask is not {self.n}:{self.m} pruned! Ones Tensor:\n" + f"{ones}" + ) + raise RuntimeError( + f"N:M Violation found: {ones.unique()} n's in layer {fqn}" + ) diff --git a/sparsimony/schedulers/base.py b/sparsimony/schedulers/base.py index 04c4b20..6854f12 100644 --- a/sparsimony/schedulers/base.py +++ b/sparsimony/schedulers/base.py @@ -32,6 +32,14 @@ def __call__(self, *args, **kwargs): return None +class AlwaysTrueScheduler(BaseScheduler): + def __init__(self, *args, **kwargs): + return + + def __call__(self, *args, **kwargs): + return True + + class ConstantScheduler(BaseScheduler): def __init__( diff --git a/sparsimony/utils.py b/sparsimony/utils.py index 22f6aa2..27a2948 100644 --- a/sparsimony/utils.py +++ b/sparsimony/utils.py @@ -4,6 +4,9 @@ from math import prod from contextlib import contextmanager # noqa import math +from functools import wraps +from time import time + import torch import torch.nn as nn @@ -15,6 +18,26 @@ ) _logger = logging.getLogger(__name__) +global __view_tensors_as_warning_logged +__view_tensors_as_warning_logged = False + + +def timing(f): + _logger = logging.getLogger(__name__) + + @wraps(f) + def wrap(*args, **kw): + ts = time() + result = f(*args, **kw) + te = time() + _logger.debug(f"func: {f.__name__} took {te-ts:.4f} sec") + # _logger.debug( + # "func:%r args:[%r, %r] took: %2.4f sec" + # % (f.__name__, args, kw, te - ts) + # ) + return result + + return wrap def get_mask( @@ -187,7 +210,21 @@ def wrapped_fn(*args, **kwargs) -> torch.Tensor: out = out.view(-1) indx = torch.argwhere(~torch.isnan(out))[:, 0] out = out[indx] - return out.reshape(original_size) + try: + return out.view(original_size) + except Exception as e: + global __view_tensors_as_warning_logged + if not __view_tensors_as_warning_logged: + _logger.warning( + "Had to reshape output from view_tensor_as. Please " + "check if your input tensors to calculate_mask() are " + "contiguous. If so please report this on GitHub. " + "Exception:" + ) + _logger.warning(e) + __view_tensors_as_warning_logged = True + pass + return out.reshape(original_size) return wrapped_fn