Skip to content

Commit

Permalink
linting correction
Browse files Browse the repository at this point in the history
  • Loading branch information
cemuyuk committed Jul 17, 2024
1 parent 70d7039 commit a1b5b55
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
1 change: 1 addition & 0 deletions sparsimony/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def set(
sparsity=sparsity,
)


def static(
optimizer: torch.optim.Optimizer,
sparsity: float,
Expand Down
3 changes: 1 addition & 2 deletions sparsimony/dst/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sparsimony.utils import get_mask


# TODO - double check if the current default init_method is good to go for static sparsity
# TODO - double check init_method
class StaticMagnitudeSparsifier(DSTMixin, BaseSparsifier):
def __init__(
self,
Expand All @@ -17,7 +17,6 @@ def __init__(
sparsity: float,
init_method: str = "sparse_torch",
):

optimizer = optimizer
self.distribution = distribution
self.sparsity = sparsity
Expand Down
4 changes: 3 additions & 1 deletion sparsimony/pruners/unstructured.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def calculate_mask(
_, indices = torch.topk(scores.reshape(-1), k=n_drop, largest=False)
mask = (
mask.reshape(-1)
.scatter(dim=0, index=indices, src=torch.zeros_like(mask.reshape(-1)))
.scatter(
dim=0, index=indices, src=torch.zeros_like(mask.reshape(-1))
)
.reshape(mask.shape)
)
return mask
Expand Down

0 comments on commit a1b5b55

Please sign in to comment.