Under development.
Example usage:
from sparsimony import rigl
sparsifier = rigl(
    optimizer,
    sparsity=0.5,
    t_end=70000,
)
sparse_config = [
    {"tensor_fqn": f"{fqn}.weight"}
    for fqn, module in model.named_modules()
    if isinstance(module, (nn.Linear, nn.Conv2d))
]
sparsifier.prepare(model, sparse_config)
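# Init DDP only AFTER sparsifier.prepare() has reparametrized the model!!
if world_size > 1:
    # Distributed: convert BatchNorm layers to SyncBatchNorm *before*
    # wrapping the model in DistributedDataParallel.
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = DistributedDataParallel(model, device_ids=[local_rank])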
for data, target in train_loader:
    ...
    optimizer.step()
    sparsifier.step()