Commit

various hotfixes from dev branch
mklasby committed Jul 16, 2024
1 parent fe1167c commit d7bb3f5
Showing 5 changed files with 101 additions and 9 deletions.
2 changes: 2 additions & 0 deletions sparsimony/distributions/base.py
@@ -24,6 +24,8 @@ def __init__(
         self.excluded_types = excluded_types
         if excluded_mod_name_regexs is None:
             excluded_mod_name_regexs = []
+        elif isinstance(excluded_mod_name_regexs, str):
+            excluded_mod_name_regexs = [excluded_mod_name_regexs]
         self.excluded_mod_name_regexs = excluded_mod_name_regexs
         self._logger = logging.getLogger(__name__)
         self._cache: Dict[float, List[float]] = dict()
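Note: the two added lines normalize a single regex string into a one-element list. A minimal sketch of the failure mode this guards against, assuming a hypothetical is_excluded helper (not sparsimony's actual API) that iterates the regex list:

import re

def is_excluded(mod_name, excluded_mod_name_regexs):
    # Without the normalization above, passing "layer1" instead of
    # ["layer1"] would iterate the bare string, so each character
    # ("l", "a", "y", ...) would be treated as its own pattern and
    # match almost every module name.
    if isinstance(excluded_mod_name_regexs, str):
        excluded_mod_name_regexs = [excluded_mod_name_regexs]
    return any(re.search(p, mod_name) for p in excluded_mod_name_regexs)

print(is_excluded("encoder.layer1.conv", "layer1"))  # True
print(is_excluded("encoder.layer2.conv", "layer1"))  # False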
3 changes: 3 additions & 0 deletions sparsimony/dst/rigl.py
@@ -26,6 +26,8 @@ def __init__(
         sparsity: float = 0.5,
         grown_weights_init: float = 0.0,
         init_method: Optional[str] = "grad_flow",
+        *args,
+        **kwargs,
     ):
         self.scheduler = scheduler
         self.distribution = distribution
@@ -78,6 +80,7 @@ def _step(self) -> bool:
         self._step_count += 1
         prune_ratio = self.scheduler(self._step_count)
         if prune_ratio is not None:
+            self._logger.info(f"Updating topology at step {self._step_count}")
             self._distribute_sparsity(self.sparsity)
             for config in self.groups:
                 config["prune_ratio"] = prune_ratio
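Note: accepting a trailing *args/**kwargs in __init__ lets shared construction code pass options that only some sparsifiers understand without raising TypeError. A short sketch of that pattern (the classes here are hypothetical stand-ins, not sparsimony's):

class BaseSparsifier:
    def __init__(self, sparsity: float = 0.5, *args, **kwargs):
        # Unknown extra arguments are absorbed rather than rejected.
        self.sparsity = sparsity

class FancySparsifier(BaseSparsifier):
    def __init__(self, *args, temperature: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.temperature = temperature

# The same keyword set now constructs either class without TypeError:
for cls in (BaseSparsifier, FancySparsifier):
    s = cls(sparsity=0.9, temperature=2.0)
    print(type(s).__name__, s.sparsity)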
14 changes: 8 additions & 6 deletions sparsimony/dst/set.py
@@ -31,6 +31,8 @@ def __init__(
         sparsity: float = 0.5,
         grown_weights_init: float = 0.0,
         init_method: Optional[str] = "grad_flow",
+        *args,
+        **kwargs,
     ):
         self.scheduler = scheduler
         self.distribution = distribution
@@ -47,7 +49,7 @@ def prune_mask(
         mask: torch.Tensor,
         weights: torch.Tensor,
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         mask.data = UnstructuredMagnitudePruner.calculate_mask(
             prune_ratio, mask, weights
@@ -77,7 +79,8 @@ def grow_mask(
         # Overwrite old mask
         mask.data = new_mask.data

-    def _step(self) -> None:
+    def _step(self) -> bool:
+        _topo_updated = False
         self._step_count += 1
         prune_ratio = self.scheduler(self._step_count)
         if prune_ratio is not None:
@@ -86,18 +89,17 @@ def update_mask(
             config["prune_ratio"] = prune_ratio
             self.update_mask(**config)
             self._broadcast_masks()
+            _topo_updated = True
-        self._step_count += 1
-
-    def _assert_sparsity_level(self, mask, sparsity_level):
-        assert mask.sum() == int(mask.numel() * sparsity_level)
+        return _topo_updated

     def update_mask(
         self,
         module: nn.Module,
         tensor_name: str,
         sparsity: float,
         prune_ratio: float,
-        **kwargs
+        **kwargs,
     ):
         mask = get_mask(module, tensor_name)
         if sparsity == 0:
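Note: _step now matches the bool-returning signature seen in rigl.py above and reports whether the mask topology changed; the duplicated self._step_count += 1 at the bottom of the method is also removed, so the counter advances exactly once per call. A self-contained sketch of the resulting control flow (the class and scheduler are simplified stand-ins, not sparsimony's implementation):

from typing import Callable, Optional

class TopologyStepper:
    def __init__(self, scheduler: Callable[[int], Optional[float]]):
        self.scheduler = scheduler
        self._step_count = 0

    def _step(self) -> bool:
        _topo_updated = False
        self._step_count += 1  # incremented exactly once, at the top
        prune_ratio = self.scheduler(self._step_count)
        if prune_ratio is not None:
            # the real implementation prunes and regrows masks here
            _topo_updated = True
        return _topo_updated

stepper = TopologyStepper(lambda step: 0.3 if step % 100 == 0 else None)
print([s for s in range(1, 301) if stepper._step()])  # [100, 200, 300]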
13 changes: 11 additions & 2 deletions sparsimony/schedulers/base.py
@@ -75,17 +75,26 @@ def __init__(
         assert t_grow < delta_t

     def next_step_update(self, last_step: int) -> bool:
         if last_step + 1 > self.t_end:
             return False
         if last_step % self.delta_t == (self.delta_t - self.t_grow):
             # start filling buffers for grow step
             return True
         # elif (last_step + 1) % self.delta_t == self.t_grow:
         # elif last_step % self.delta_t == 0:
         #     return True
         # elif (
         #     last_step % self.delta_t == self.t_grow and last_step > self.delta_t  # noqa
         # ):
         #     # Prune next step (need plus one?)
         #     return True
         return False

     def __call__(self, step: int) -> Optional[float]:
         if step > self.t_end:
             return None
         if step % self.delta_t == 0:
             return -self.pruning_ratio  # Grow by prune ratio
-        elif step % (self.delta_t + self.t_grow) == 0:
+        elif step % self.delta_t == self.t_grow and step > self.delta_t:
             return self.pruning_ratio  # Prune by prune ratio
         else:
             return None
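Note: the rewritten elif is the substantive fix. The old modulus delta_t + t_grow drifts relative to the grow steps, while the new condition fires exactly t_grow steps after every grow. A quick check using the fixture values from the tests below (delta_t=100, t_grow=20), run past t_end so the drift is visible:

delta_t, t_grow, horizon = 100, 20, 400

old_prune = [s for s in range(1, horizon + 1) if s % (delta_t + t_grow) == 0]
new_prune = [
    s for s in range(1, horizon + 1) if s % delta_t == t_grow and s > delta_t
]

print(old_prune)  # [120, 240, 360] -- drifts away from the grow steps
print(new_prune)  # [120, 220, 320] -- always t_grow after a grow at 100, 200, 300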
78 changes: 77 additions & 1 deletion tests/sparsimony/test_schedulers.py
@@ -1,7 +1,11 @@
 import pytest
 import numpy as np

-from sparsimony.schedulers.base import ConstantScheduler, CosineDecayScheduler
+from sparsimony.schedulers.base import (
+    ConstantScheduler,
+    CosineDecayScheduler,
+    SoftMemoryBoundScheduler,
+)


 class TestSchedulers:
@@ -32,3 +36,75 @@ def test_cosine_decay_scheduler_call(self, cosine_decay_scheduler):
         assert cosine_decay_scheduler(3) is None
         # Test for step after t_end
         assert cosine_decay_scheduler(11) is None


+def id_fn(args):
+    pruning_ratio, t_end, delta_t, t_grow = args
+    return (
+        f"pruning_ratio: {pruning_ratio} t_end: {t_end} delta_t: {delta_t} "
+        f"t_grow: {t_grow}"
+    )
+
+
+class TestSoftMemoryBound:
+
+    @pytest.fixture(params=[(0.3, 200, 100, 20)], ids=id_fn)
+    def scheduler(self, request):
+        pruning_ratio, t_end, delta_t, t_grow = request.param
+        _scheduler = SoftMemoryBoundScheduler(
+            pruning_ratio=pruning_ratio,
+            t_end=t_end,
+            delta_t=delta_t,
+            t_grow=t_grow,
+        )
+        yield _scheduler
+        del _scheduler
+
+    def test_next_step_update(self, scheduler):
+        t_end = scheduler.t_end
+        delta_t = scheduler.delta_t
+        t_grow = scheduler.t_grow
+
+        update_cycles = t_end // delta_t
+        start_buffer_steps = [
+            delta_t * n - t_grow for n in list(range(1, update_cycles + 1))
+        ]
+        # grow_next_steps = [
+        #     (delta_t * n) - 1 for n in list(range(1, update_cycles + 1))
+        # ]
+        # prune_next_steps = [
+        #     delta_t * n + t_grow for n in list(range(1, update_cycles + 1))
+        # ]
+        for step in range(1, t_end + 1):
+            # Test cases for next_step_update
+            if step in start_buffer_steps:
+                assert (
+                    scheduler.next_step_update(step) is True
+                )  # Grow next step
+            # elif step in grow_next_steps:
+            #     print(step)
+            #     assert scheduler.next_step_update(step) is True
+            # elif step in prune_next_steps:
+            #     assert scheduler.next_step_update(step) is True
+            else:
+                assert scheduler.next_step_update(step) is False
+
+    def test_call(self, scheduler):
+        pruning_ratio = scheduler.pruning_ratio
+        t_end = scheduler.t_end
+        delta_t = scheduler.delta_t
+        t_grow = scheduler.t_grow
+
+        update_cycles = t_end // delta_t
+        grow_steps = [delta_t * n for n in list(range(1, update_cycles + 1))]
+        prune_steps = [
+            delta_t * n + t_grow for n in list(range(1, update_cycles + 1))
+        ]
+        for step in range(1, t_end + 1):
+            # Test cases for next_step_update
+            if step in grow_steps:
+                assert scheduler(step) == -pruning_ratio  # Grow next step
+            elif step in prune_steps:
+                assert scheduler(step) == pruning_ratio
+            else:
+                assert scheduler(step) is None
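
Note: for the fixture values (pruning_ratio=0.3, t_end=200, delta_t=100, t_grow=20) the expected step lists work out as below; this is a recomputation of the tests' arithmetic, not part of the commit:

t_end, delta_t, t_grow = 200, 100, 20
cycles = t_end // delta_t  # 2 update cycles

print([delta_t * n - t_grow for n in range(1, cycles + 1)])  # [80, 180]: start filling buffers
print([delta_t * n for n in range(1, cycles + 1)])           # [100, 200]: grow steps
print([delta_t * n + t_grow for n in range(1, cycles + 1)])  # [120, 220]: prune steps (220 > t_end, never reached)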
