Skip to content

Dynamic sparse training and pruning algorithms for PyTorch

License

Notifications You must be signed in to change notification settings

mklasby/sparsimony

Repository files navigation

sparsimony

CI Pipeline CD Pipeline

Under development.

Example usage:

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel

from sparsimony import rigl

# NOTE: `model`, `optimizer`, `train_loader`, `world_size` and `local_rank`
# are assumed to be defined by the surrounding training script.
sparsifier = rigl(
    optimizer,
    sparsity=0.5,
    t_end=70000,
)

# Target the `weight` tensor of every nn.Linear and nn.Conv2d module.
sparse_config = [
    {"tensor_fqn": f"{fqn}.weight"}
    for fqn, module in model.named_modules()
    if isinstance(module, (nn.Linear, nn.Conv2d))
]

sparsifier.prepare(model, sparse_config)


# init DDP AFTER reparametrization!!
if world_size > 1:
    # Distributed: convert BatchNorm layers to SyncBatchNorm BEFORE wrapping
    # the model in DDP; this is the order the PyTorch documentation
    # prescribes.
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = DistributedDataParallel(model, device_ids=[local_rank])


for data, target in train_loader:
    ...
    optimizer.step()
    # Step the sparsifier after each optimizer step.
    sparsifier.step()