Commit

data helper added for global pruning

mklasby committed Aug 9, 2024
1 parent 69010d1 commit e8c58fd
Showing 3 changed files with 92 additions and 88 deletions.
103 changes: 59 additions & 44 deletions sparsimony/dst/base.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Tuple
from typing import Dict, Any, List
import copy
import logging
import torch
@@ -26,7 +26,7 @@ def __init__(
self,
optimizer: torch.optim.Optimizer,
random_mask_init: bool = True,
global_pruining: bool = False,
global_pruning: bool = False,
*args,
**kwargs,
):
@@ -45,7 +45,7 @@ def __init__(
)
self.optimizer = optimizer
self.random_mask_init = random_mask_init
self.global_pruning = global_pruining
self.global_pruning = global_pruning
self._step_count = 0
self._logger = logging.getLogger(__name__)
self.prepared_ = False
@@ -202,7 +202,7 @@ def _assert_sparsity_level(self, mask: torch.Tensor, sparsity_level: float):
# with exact n_ones. Therefore, we simply log the warning instead of
# raising.
self._logger.warning(
f"n_ones actual{n_ones} != n_one target {actual_n_ones}"
f"n_ones actual {n_ones} != n_one target {actual_n_ones}"
)

# @override
@@ -296,51 +296,66 @@ def _is_replica(
return True
return False

# TODO: Move following to mixin interface for global pruning?
def _global_reshape_and_assign(
self,
concantenated_mask: torch.Tensor,
original_shapes: List[Tuple[int]],
original_numels: List[int],
) -> None:
for idx, config in enumerate(self.groups):
module = config["module"]
tensor_name = config["tensor_name"]
mask = get_mask(module, tensor_name)
stride_start = sum(original_numels[:idx])
stride_end = sum(original_numels[: idx + 1])
shape = original_shapes[idx]
mask.data = concantenated_mask[stride_start:stride_end].reshape(
shape
)
# TODO: Move to global pruner Mixin?
def _global_step(self, *args, **kwargs) -> None:
raise NotImplementedError(
"self.global_prune is True but _global_step has not been "
f"implemented for {self.__class__.__name__}."
)

def _global_init_prune(self) -> None:
global_data_helper = GlobalPruningDataHelper(self.groups)
if self.random_mask_init:
global_data_helper.masks.data = (
UnstructuredRandomPruner.calculate_mask(
self.sparsity, global_data_helper.masks
)
)
else:
# use pruning criterion
self.prune_mask(
self.sparsity,
global_data_helper.masks,
global_data_helper.sparse_weights,
)
self._assert_sparsity_level(global_data_helper.masks, self.sparsity)
global_data_helper.reshape_and_assign_masks()


class GlobalPruningDataHelper:

def __init__(self, groups: List[Dict[str, Any]]):
self.groups = groups
original_weights = []
masks = []
sparse_weights = []
original_shapes = []
original_numels = []
masks = []
for config in self.groups:
module = config["module"]
tensor_name = config["tensor_name"]
masks.append(get_mask(module, tensor_name))
original_weights.append(get_original_tensor(module, tensor_name))
sparse_weights.append(getattr(module, tensor_name))
original_shapes = [t.shape for t in masks]
original_numels = [t.numel() for t in masks]
original_weights = torch.concat(original_weights).flatten()
masks = torch.concat(masks).flatten()
sparse_weights = torch.concat(sparse_weights).flatten()
if self.random_mask_init:
masks.data = UnstructuredRandomPruner.calculate_mask(
self.sparsity, masks
mask = get_mask(module, tensor_name)
original_shapes.append(mask.shape)
original_numels.append(mask.numel())
masks.append(mask.flatten())
original_weights.append(
get_original_tensor(module, tensor_name).flatten()
)
else:
# use pruning criterion
self.prune_mask(self.sparsity, masks, sparse_weights)
self._assert_sparsity_level(masks, self.sparsity)
self._global_reshape_and_assign(masks, original_shapes, original_numels)

def _global_step(self, *args, **kwargs) -> None:
raise NotImplementedError(
"self.global_prune is True but _global_step has not been "
f"implemented for {self.__class__.__name__}."
)
sparse_weights.append(getattr(module, tensor_name).flatten())
self.original_weights = torch.concat(original_weights)
self.sparse_weights = torch.concat(sparse_weights)
self.masks = torch.concat(masks)
self.original_shapes = original_shapes
self.original_numels = original_numels

def reshape_and_assign_masks(
self,
) -> None:
for idx, config in enumerate(self.groups):
module = config["module"]
tensor_name = config["tensor_name"]
mask = get_mask(module, tensor_name)
stride_start = sum(self.original_numels[:idx])
stride_end = sum(self.original_numels[: idx + 1])
shape = self.original_shapes[idx]
mask.data = self.masks[stride_start:stride_end].reshape(shape)
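
Taken together, the new GlobalPruningDataHelper packages the flatten → prune → reshape flow that each method's _global_step previously open-coded. The following is a minimal, self-contained sketch of that flow in plain PyTorch; it uses hypothetical `weight_mask` buffers on ordinary `nn.Linear` modules rather than sparsimony's parametrization utilities, so names and details are illustrative only.

```python
# Sketch of the global-pruning data flow: concatenate masks/weights across
# layers, prune by global magnitude, then slice the mask back per layer.
import torch
import torch.nn as nn

layers = [nn.Linear(8, 8), nn.Linear(8, 4)]
for m in layers:
    m.register_buffer("weight_mask", torch.ones_like(m.weight))

# 1. Record shapes/numels and flatten everything into single vectors.
shapes = [m.weight_mask.shape for m in layers]
numels = [m.weight_mask.numel() for m in layers]
masks = torch.concat([m.weight_mask.flatten() for m in layers]).clone()
weights = torch.concat([m.weight.detach().flatten() for m in layers])

# 2. Prune globally by magnitude to a target sparsity.
sparsity = 0.5
n_drop = int(sparsity * weights.numel())
drop_idx = torch.argsort(weights.abs())[:n_drop]  # smallest-magnitude weights
masks[drop_idx] = 0.0

# 3. Reshape each slice back onto its layer (reshape_and_assign_masks).
for i, m in enumerate(layers):
    start, end = sum(numels[:i]), sum(numels[: i + 1])
    m.weight_mask.data = masks[start:end].reshape(shapes[i])

print([float((m.weight_mask == 0).float().mean()) for m in layers])
```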
32 changes: 11 additions & 21 deletions sparsimony/dst/gmp.py
@@ -1,4 +1,4 @@
from typing import Optional, Dict, Any, List, Tuple
from typing import Optional, Dict, Any

import torch
import torch.nn as nn
@@ -7,8 +7,8 @@
from sparsimony.distributions.base import BaseDistribution
from sparsimony.schedulers.base import BaseScheduler
from sparsimony.parametrization.fake_sparsity import FakeSparsity
from sparsimony.utils import get_mask, get_parametrization, get_original_tensor
from sparsimony.dst.base import DSTMixin
from sparsimony.utils import get_mask
from sparsimony.dst.base import DSTMixin, GlobalPruningDataHelper
from sparsimony.pruners.unstructured import (
UnstructuredMagnitudePruner,
)
@@ -112,21 +112,11 @@ def _assert_sparsity_level(self, mask: torch.Tensor, sparsity_level: float):
)

def _global_step(self) -> None:
original_weights = []
masks = []
sparse_weights = []
for config in self.groups:
module = config["module"]
tensor_name = config["tensor_name"]
masks.append(get_mask(module, tensor_name))
original_weights.append(get_original_tensor(module, tensor_name))
sparse_weights.append(getattr(module, tensor_name))
original_shapes = [t.shape for t in masks]
original_numels = [t.numel() for t in masks]
original_weights = torch.concat(original_weights).flatten()
masks = torch.concat(masks).flatten()
sparse_weights = torch.concat(sparse_weights).flatten()
dense_grads = torch.concat(dense_grads).flatten()
self.prune_mask(self.sparsity, masks, sparse_weights)
self._assert_sparsity_level(masks, self.sparsity)
self._global_reshape_and_assign(masks, original_shapes, original_numels)
global_data_helper = GlobalPruningDataHelper(self.groups)
self.prune_mask(
self.sparsity,
global_data_helper.masks,
global_data_helper.sparse_weights,
)
self._assert_sparsity_level(global_data_helper.masks, self.sparsity)
global_data_helper.reshape_and_assign_masks()
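
For intuition on the behavioural difference global pruning enables here: layer-wise magnitude pruning hits the target sparsity in every layer, while global pruning ranks all weights together, so layers with smaller-magnitude weights absorb more of the pruning budget. A toy comparison with made-up layer scales (illustrative only, not sparsimony code):

```python
import torch

torch.manual_seed(0)
w_small = torch.randn(100) * 0.1   # layer with small-magnitude weights
w_large = torch.randn(100) * 1.0   # layer with large-magnitude weights
sparsity = 0.5

# Layer-wise: each layer is pruned to roughly 50% on its own.
def layerwise_sparsity(w):
    thresh = w.abs().kthvalue(int(sparsity * w.numel())).values
    return float((w.abs() <= thresh).float().mean())

# Global: rank all weights together; the small-magnitude layer ends up
# far sparser than the target while the large-magnitude layer is barely pruned.
w_all = torch.concat([w_small, w_large])
thresh = w_all.abs().kthvalue(int(sparsity * w_all.numel())).values
global_small = float((w_small.abs() <= thresh).float().mean())
global_large = float((w_large.abs() <= thresh).float().mean())

print(layerwise_sparsity(w_small), layerwise_sparsity(w_large))  # ~0.5, ~0.5
print(global_small, global_large)  # heavily skewed toward w_small
```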
45 changes: 22 additions & 23 deletions sparsimony/dst/rigl.py
@@ -1,13 +1,13 @@
from typing import List, Optional, Dict, Any, Tuple
from typing import Optional, Dict, Any
import torch
import torch.nn as nn
from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier

from sparsimony.distributions.base import BaseDistribution
from sparsimony.schedulers.base import BaseScheduler
from sparsimony.parametrization.fake_sparsity import FakeSparsityDenseGradBuffer
from sparsimony.utils import get_mask, get_parametrization, get_original_tensor
from sparsimony.dst.base import DSTMixin
from sparsimony.utils import get_mask, get_parametrization
from sparsimony.dst.base import DSTMixin, GlobalPruningDataHelper
from sparsimony.pruners.unstructured import (
UnstructuredMagnitudePruner,
UnstructuredGradientGrower,
@@ -157,25 +157,24 @@ def update_mask(
self._assert_sparsity_level(mask, sparsity)

def _global_step(self, prune_ratio: float) -> None:
original_weights = []
masks = []
sparse_weights = []
global_data_helper = GlobalPruningDataHelper(self.groups)
dense_grads = []
for config in self.groups:
module = config["module"]
tensor_name = config["tensor_name"]
masks.append(get_mask(module, tensor_name))
original_weights.append(get_original_tensor(module, tensor_name))
sparse_weights.append(getattr(module, tensor_name))
dense_grads.append(self._get_dense_grads(**config))
original_shapes = [t.shape for t in masks]
original_numels = [t.numel() for t in masks]
original_weights = torch.concat(original_weights).flatten()
masks = torch.concat(masks).flatten()
sparse_weights = torch.concat(sparse_weights).flatten()
dense_grads = torch.concat(dense_grads).flatten()
target_sparsity = self.get_sparsity_from_prune_ratio(masks, prune_ratio)
self.prune_mask(target_sparsity, masks, sparse_weights)
self.grow_mask(self.sparsity, masks, original_weights, dense_grads)
self._assert_sparsity_level(masks, self.sparsity)
self._global_reshape_and_assign(masks, original_shapes, original_numels)
dense_grads.append(self._get_dense_grads(**config).flatten())
dense_grads = torch.concat(dense_grads)
target_sparsity = self.get_sparsity_from_prune_ratio(
global_data_helper.masks, prune_ratio
)
self.prune_mask(
target_sparsity,
global_data_helper.masks,
global_data_helper.sparse_weights,
)
self.grow_mask(
self.sparsity,
global_data_helper.masks,
global_data_helper.original_weights,
dense_grads,
)
self._assert_sparsity_level(global_data_helper.masks, self.sparsity)
global_data_helper.reshape_and_assign_masks()
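
RigL's _global_step additionally regrows connections using the buffered dense gradients. Below is a simplified sketch of one global prune-then-grow update on flattened tensors; the real criteria live in UnstructuredMagnitudePruner and UnstructuredGradientGrower, and the prune target comes from get_sparsity_from_prune_ratio, so treat this as an approximation of the idea rather than the library's implementation.

```python
import torch

def global_prune_then_grow(mask, weights, dense_grads, prune_ratio):
    """Drop the smallest-magnitude active weights, then regrow the same
    number of connections where the dense gradient magnitude is largest."""
    mask = mask.clone()
    n_swap = int(prune_ratio * int(mask.sum()))

    # Prune: among active connections, drop the smallest |weight|.
    prune_scores = torch.where(
        mask.bool(), weights.abs(), torch.full_like(weights, float("inf"))
    )
    mask[torch.argsort(prune_scores)[:n_swap]] = 0.0

    # Grow: among inactive connections, activate the largest |gradient|.
    grow_scores = torch.where(
        mask.bool(), torch.full_like(dense_grads, -float("inf")), dense_grads.abs()
    )
    mask[torch.argsort(grow_scores, descending=True)[:n_swap]] = 1.0
    return mask

torch.manual_seed(0)
mask = (torch.rand(12) > 0.5).float()
new_mask = global_prune_then_grow(mask, torch.randn(12), torch.randn(12), 0.3)
print(mask.sum(), new_mask.sum())  # number of active connections is preserved
```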
