Customizable Fastai+PyTorch implementation of sparse model training methods (SET, SNFS, RigL).
pip install fastsparse
This package implements the following sparse training algorithms:
| Abbr. | Sparse Algorithm | in FastSparse | Notes |
|---|---|---|---|
|  | static sparsity baseline | omit DynamicSparseTrainingCallback |  |
| SET | Sparse Evolutionary Training (Jan 2019) | DynamicSparseTrainingCallback(**SET_presets) |  |
| SNFS | Sparse Networks From Scratch (Jul 2019) | DynamicSparseTrainingCallback(**SNFS_presets) | *redistribution not implemented |
| RigL | Rigged Lottery (Nov 2019) | DynamicSparseTrainingCallback(**RigL_presets) |  |

*Authors of the RigL paper demonstrate that SNFS with an Erdos-Renyi-Kernel distribution and no redistribution outperforms SNFS with uniform sparsity and redistribution (at least on the measured benchmarks).
With just 4 additional lines of code, you can train your model using the latest dynamic sparse training techniques. This example achieves >99% accuracy on MNIST using a ResNet34 with only 1% of the weights.
# (0) install the library
# ! pip install fastsparse
from fastai.vision.all import *
# (1) import this package
import fastsparse as sparse
path = untar_data(URLs.MNIST)
dls = ImageDataLoaders.from_folder(path, 'training', 'testing')
learn = cnn_learner(dls, resnet34, metrics=error_rate, pretrained=False)
# (2) sparsify initial model + enforce masks
sparse_hooks = sparse.sparsify_model(learn.model,
model_sparsity=0.99,
sparse_f=sparse.erdos_renyi_sparsity)
# (3) schedule dynamic mask updates
cbs = [sparse.DynamicSparseTrainingCallback(**sparse.SNFS_presets,
batches_per_update=32)]
learn.fit_one_cycle(5, cbs=cbs)
# (4) remove hooks that enforce masks
sparse_hooks.remove()
Simply omit the DynamicSparseTrainingCallback
to train a fixed-sparsity model as a baseline.
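For example, reusing the quickstart setup above, a static-sparsity baseline is just the sparsify step followed by ordinary training (a sketch; with no callback, the masks stay fixed for the whole run):

# sparsify as before, but train without DynamicSparseTrainingCallback
sparse_hooks = sparse.sparsify_model(learn.model,
                                     model_sparsity=0.99,
                                     sparse_f=sparse.erdos_renyi_sparsity)
learn.fit_one_cycle(5)  # no dynamic mask updates
sparse_hooks.remove()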
FastSparse can also be used in a plain PyTorch training loop by wrapping your optimizer:

import torch
from torchvision import models
from fastsparse import *

data = ...
model = ...
opt = ...
sparse_opt = DynamicSparseTrainingOptimizerWrapper(model, opt, **RigL_kwargs)
### Modified training step
# sparse_opt.step(...) will determine whether to:
# (A) take a regular opt step, or
# (B) update network connectivity
def sparse_train_step(model, xb, yb, loss_func, sparse_opt, step, pct_train):
    preds = model(xb)
    loss = loss_func(preds, yb)
    loss.backward()
    sparse_opt.step(step, pct_train)
    sparse_opt.zero_grad()
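For context, here is a rough sketch of a loop that could drive this step function; the epoch count, batch iteration, and loss_func below are assumptions standing in for your own training setup:

# drive sparse_train_step with a global step counter and the fraction of training completed
n_epochs = 5
steps_per_epoch = len(data)
for epoch in range(n_epochs):
    for i, (xb, yb) in enumerate(data):
        step = epoch * steps_per_epoch + i
        pct_train = step / (n_epochs * steps_per_epoch)
        sparse_train_step(model, xb, yb, loss_func, sparse_opt, step, pct_train)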
Here is an example of saving a model and reloading it to resume training.
from fastai.vision.all import *
from fastsparse import *
path = untar_data(URLs.MNIST_TINY)
dls = ImageDataLoaders.from_folder(path)
learn = cnn_learner(dls, resnet18, metrics=accuracy, pretrained=False)
sparse_hooks = sparsify_model(learn.model, model_sparsity=0.9, sparse_f=erdos_renyi_sparsity)
dst_kwargs = {**SNFS_presets, **{'batches_per_update': 8}}
cbs = DynamicSparseTrainingCallback(**dst_kwargs)
learn.fit_flat_cos(5, cbs=cbs)
# (0) save model as usual (masks are stored automatically)
save_model('sparse_tiny_mnist', learn.model, learn.opt)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.334119 | 0.679619 | 0.505007 | 00:03 |
1 | 0.271645 | 0.555170 | 0.848355 | 00:02 |
2 | 0.237115 | 0.072088 | 0.978541 | 00:02 |
3 | 0.220553 | 0.044927 | 0.987124 | 00:02 |
4 | 0.174585 | 0.006496 | 1.000000 | 00:02 |
### manually restart notebook ###
# (1) then recreate learner as usual
from fastai.vision.all import *
from fastsparse import *
path = untar_data(URLs.MNIST_TINY)
dls = ImageDataLoaders.from_folder(path)
learn = cnn_learner(dls, resnet18, metrics=accuracy, pretrained=False)
# (2) re-sparsify model (this adds the masks to the parameters)
sparse_hooks = sparsify_model(learn.model, model_sparsity=0.9, sparse_f=erdos_renyi_sparsity) # <-- initial sparsity + enforce masks
# (3) load model as usual
load_model('sparse_tiny_mnist', learn.model, learn.opt)
# (4) check validation loss & accuracy to verify we've loaded it successfully
val_loss, val_acc = learn.validate()
print(f'validation loss: {val_loss}, validation accuracy: {val_acc}')
# (5) optionally, continue training; otherwise remove the sparsity-preserving hooks
sparse_hooks.remove()
/home/dc/anaconda3/envs/fastai/lib/python3.8/site-packages/fastai/learner.py:53: UserWarning: Could not load the optimizer state.
if with_opt: warn("Could not load the optimizer state.")
validation loss: 0.006496043410152197, validation accuracy: 1.0
Authors of the Rigged Lottery paper hypothesize that the effectiveness of using the gradient magnitude to determine which connections to grow is partly due to their large batch size (4096 for ImageNet). Those without access to multi-GPU clusters can reach comparably large effective batch sizes by using fastai's GradientAccumulation callback, which has been tested to be compatible with this package's DynamicSparseTrainingCallback.
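For example, a sketch combining the two callbacks (this assumes a learn object that has already been through sparsify_model; the n_acc value is illustrative):

cbs = [GradientAccumulation(n_acc=4096),              # step the optimizer only every ~4096 samples
       DynamicSparseTrainingCallback(**RigL_presets)]
learn.fit_one_cycle(5, cbs=cbs)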
Dynamic sparse training algorithms work by modifying the network connectivity during training, dropping some weights and allowing others to regrow. By default, network connectivity is modified at the end of each epoch. When training for only a few epochs, however, there are few chances to explore which weights to connect. To update more frequently, set batches_per_update in DynamicSparseTrainingCallback to a number smaller than the number of batches in one training epoch. Varying the number of batches per update trades off the frequency of updates against the stability of making good updates.
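For instance, a sketch that ties the update frequency to the length of the training epoch (the factor of 4 is arbitrary):

n_batches = len(learn.dls.train)  # number of batches in one training epoch
cbs = [DynamicSparseTrainingCallback(**SNFS_presets,
                                     batches_per_update=max(1, n_batches // 4))]  # ~4 mask updates per epoch
learn.fit_one_cycle(5, cbs=cbs)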
There are many ways to implement and test your own dynamic sparse algorithms using FastSparse.
Define your own initial sparsity distribution by setting sparse_f in sparsify_model to a custom function. For example, this function (included in the library) keeps the first layer dense and sets the remaining layers to a fixed sparsity:
def first_layer_dense_uniform(params:list, model_sparsity:float):
    sparsities = [1.] + [model_sparsity] * (len(params) - 1)
    return sparsities
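It can then be passed to sparsify_model in place of the built-in distributions (a sketch reusing the learner from the examples above):

sparse_hooks = sparsify_model(learn.model,
                              model_sparsity=0.9,
                              sparse_f=first_layer_dense_uniform)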
While published papers like SNFS and RigL refer to a 'drop criterion', this library implements the reverse, a 'keep criterion': a function that returns a score for each weight, where the weights with the largest M scores are kept and M is determined by the decay schedule. For example, both Sparse Networks From Scratch and Rigged Lottery use the magnitude of the weights (in FastSparse: weight_magnitude).
This can easily be customized in FastSparse by defining your own keep score function:
def custom_keep_scoring_function(param, opt):
    score = ...
    assert param.shape == score.shape
    return score
Then pass your custom function into the sparse training callback:
DynamicSparseTrainingCallback(..., keep_score_f=custom_keep_scoring_function)
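As a concrete illustration, a keep score based on weight magnitude might look like the sketch below (the library ships its own weight_magnitude implementation, which may differ in detail):

def magnitude_keep_score(param, opt):
    # keep the connections whose weights have the largest absolute value
    return param.data.abs()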
The grow criterion is a function that returns a score for each weight, where the weights with the largest N scores are grown and N is determined by the decay schedule. For example, Sparse Networks From Scratch grows weights according to the momentum of the gradient, while Rigged Lottery uses the magnitude of the gradient (in FastSparse, gradient_momentum and gradient_magnitude respectively).
def custom_grow_scoring_function(param, opt):
    score = ...
    assert param.shape == score.shape
    return score
Then pass your custom function into the sparse training callback:
DynamicSparseTrainingCallback(..., grow_score_f=custom_grow_scoring_function)
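For instance, a RigL-style grow score based on the dense gradient magnitude could be sketched as follows (the built-in gradient_magnitude may differ in detail):

def magnitude_grow_score(param, opt):
    # grow the connections whose gradients have the largest absolute value
    return param.grad.abs()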
In machine learning, it is very easy for seemingly insignificant differences in algorithmic implementation to have a noticeable impact on final results. Therefore, this section compares results from this implementation to results reported in the published papers.
TODO...
Here's what's going on.
When you run sparsify_model(learn.model, 0.9), it adds sparse masks to the model and registers pre_forward hooks that enforce the masks on the weights during the forward pass.
By default, a uniform sparsity distribution is used. Change the sparsity distribution to Erdos-Renyi with sparsify_model(learn.model, 0.9, sparse_f=erdos_renyi_sparsity), or pass in your own custom function (see the Customization section above).
To avoid adding pre_forward hooks, use sparsify_model(learn.model, 0.9, enforce_masks=False).
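To give a feel for what these hooks do, here is a rough sketch of the idea (not the library's exact implementation): a forward pre-hook re-applies a binary mask to the weight before every forward pass.

import torch

def make_mask_hook(param, mask):
    def hook(module, inputs):
        with torch.no_grad():
            param.mul_(mask)  # zero out the masked weights in place
    return hook

# handle = layer.register_forward_pre_hook(make_mask_hook(layer.weight, mask))
# handle.remove()  # later: stop enforcing the mask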
When you add the DynamicSparseTrainingCallback callback, ... TODO complete section