Skip to content

Commit

Permalink
init mask update
Browse files Browse the repository at this point in the history
  • Loading branch information
cemuyuk committed Jul 17, 2024
1 parent 154b13d commit ce0f408
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 9 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ packages= [
python = ">=3.9,<=3.12"
torch=">=2.0"
numpy="1.*"
timm = "^1.0.7"

[tool.poetry.group.dev.dependencies]
flake8 = "^7.0.0"
Expand Down
11 changes: 2 additions & 9 deletions sparsimony/dst/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,10 @@ def _initialize_masks(self):
for config in self.groups:
# Prune to target sparsity for this step
mask = get_mask(config["module"], config["tensor_name"])
original_weights = getattr(
config["module"].parametrizations, config["tensor_name"]
).original

print(f"Original weights shape: {original_weights.shape}")
print(f"Mask shape: {mask.shape}")

weights = getattr(config["module"], config["tensor_name"])
mask.data = UnstructuredMagnitudePruner.calculate_mask(
config["sparsity"], mask, original_weights
config["sparsity"], mask, weights
)
print(f"Mask 1s after pruning: {mask.sum()}")
self._assert_sparsity_level(mask.data, self.sparsity)

def _step(self):
Expand Down

0 comments on commit ce0f408

Please sign in to comment.