This repository contains the official code used for language modeling experiments in the paper(s):
- Linear Transformers are Secretly Fast Weight Programmers (ICML 2021)
- Going Beyond Linear Transformers with Recurrent Fast Weight Programmers
- ...
More generally, this can be used as a language modeling toolkit in PyTorch to experiment with:
- Fast Weight Programmers with different update rules and linear attention functions (a minimal sketch is given after this list):
- Update rules: "sum" and our "delta" rule (as proposed in our paper; Sec 4.2)
- Linear attention functions: "ELU-based" linear attention, "FAVOR+", "deterministic parameter-free projection (DPFP)"
e.g., some combinations correspond to well-known models:
- Linear Transformers = "sum" update rule + "ELU-based" linear attention
- Performers = "sum" update rule + "FAVOR+"
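To make the update rules concrete, here is a minimal sketch of a single fast weight step with the ELU-based feature map, contrasting the "sum" and "delta" rules. This is not code from this repository: normalization, the learned beta projection, and batching over heads and time steps are omitted, and all names are illustrative.

```python
import torch

def elu_feature_map(x):
    # "ELU-based" linear attention feature map: phi(x) = elu(x) + 1 (element-wise, positive)
    return torch.nn.functional.elu(x) + 1.0

def fast_weight_step(W, k, v, q, beta=0.5, update_rule="delta"):
    # One time step of a fast weight programmer (normalization omitted).
    # W: (d_out, d_feat) fast weight matrix; k, q: (d_feat,); v: (d_out,).
    phi_k, phi_q = elu_feature_map(k), elu_feature_map(q)
    if update_rule == "sum":
        # Linear Transformer / Performer style: simply add the new association.
        W = W + torch.outer(v, phi_k)
    else:
        # Delta rule: retrieve the value currently stored under phi_k,
        # then write only the (gated) difference to the new value.
        v_old = W @ phi_k
        W = W + beta * torch.outer(v - v_old, phi_k)
    y = W @ phi_q  # retrieve with the query
    return W, y

# Toy usage: process a random sequence step by step.
d_feat, d_out, T = 8, 8, 5
W = torch.zeros(d_out, d_feat)
for k, v, q in zip(torch.randn(T, d_feat), torch.randn(T, d_out), torch.randn(T, d_feat)):
    W, y = fast_weight_step(W, k, v, q)
```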
This repository contains two implementations of fast weights:
- Custom CUDA kernel (see utils/fast_fast_weight and utils/cuda_fast_weight_layer.py)
- Custom `torch.autograd.Function` (see utils/fast_weight.py)
While we used only the CUDA implementation for all our final experiments (it is faster and has much better GPU utilization), the `torch.autograd.Function` version can be useful for quick prototyping of new extensions.
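As a quick illustration of the `torch.autograd.Function` pattern (explicit `forward`/`backward` with `ctx.save_for_backward`), below is a small example implementing unnormalized causal linear attention with the "sum" update rule via an explicit T x T score matrix. It is deliberately simple and memory-inefficient and is not the code in utils/fast_weight.py; it only shows the kind of scaffolding one would extend when prototyping.

```python
import torch

class CausalLinearAttention(torch.autograd.Function):
    # Illustrative only: y_t = sum_{i<=t} (phi(q_t) . phi(k_i)) v_i ("sum" update rule),
    # written with an explicit T x T score matrix for clarity.

    @staticmethod
    def forward(ctx, phi_q, phi_k, v):
        # phi_q, phi_k: (T, d_feat); v: (T, d_out)
        T = phi_q.shape[0]
        mask = torch.tril(torch.ones(T, T, device=phi_q.device, dtype=phi_q.dtype))
        ctx.save_for_backward(phi_q, phi_k, v)
        return ((phi_q @ phi_k.t()) * mask) @ v

    @staticmethod
    def backward(ctx, grad_y):
        phi_q, phi_k, v = ctx.saved_tensors
        T = phi_q.shape[0]
        mask = torch.tril(torch.ones(T, T, device=phi_q.device, dtype=phi_q.dtype))
        scores = (phi_q @ phi_k.t()) * mask
        grad_scores = (grad_y @ v.t()) * mask
        # Gradients w.r.t. phi_q, phi_k and v, in the order of the forward inputs.
        return grad_scores @ phi_k, grad_scores.t() @ phi_q, scores.t() @ grad_y

# Quick correctness check of the hand-written backward pass.
phi_q, phi_k, v = (torch.randn(4, 3, dtype=torch.double, requires_grad=True) for _ in range(3))
torch.autograd.gradcheck(CausalLinearAttention.apply, (phi_q, phi_k, v))
```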
This toolkit requires PyTorch (`torch`) and Ninja (`ninja`, needed to compile the CUDA kernels).
The experiments for the paper were conducted with Python 3.6 and PyTorch 1.4.0 (note on Aug 24, 2023: the code also works with Python 3.11 and PyTorch 2.0.1+cu117).
More recent versions of PyTorch are not yet well supported by this toolkit, which still uses `torch.nn.DataParallel` for multi-GPU training.
If you really need a more recent version of PyTorch, check the PyTorch documentation and use `torch.nn.parallel.DistributedDataParallel` instead. We hope to fix this soon, but we cannot say exactly when.
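For orientation, the two options roughly look as follows. The model below is a stand-in `torch.nn.Linear`, not one of the toolkit's models, and the DistributedDataParallel part is only sketched in comments because it needs a distributed launcher such as torchrun.

```python
import torch

# Stand-in module for illustration; the toolkit builds its own Transformer/FWP models.
model = torch.nn.Linear(512, 512)

# What this toolkit currently does: one process drives all visible GPUs.
if torch.cuda.is_available():
    model = torch.nn.DataParallel(model).cuda()

# Rough shape of the DistributedDataParallel alternative (one process per GPU,
# typically launched with torchrun); not wired into this toolkit yet:
# torch.distributed.init_process_group(backend="nccl")
# local_rank = int(os.environ["LOCAL_RANK"])
# torch.cuda.set_device(local_rank)
# model = torch.nn.parallel.DistributedDataParallel(
#     model.to(local_rank), device_ids=[local_rank])
```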
The toolkit supports Weights & Biases for monitoring jobs. If you use it, also install `wandb`.
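In case it helps, typical W&B usage boils down to the calls below; the project name, config, and metric names are placeholders, not the ones used by this toolkit's scripts.

```python
import wandb

# Placeholder project/config/metric names; the toolkit's scripts set their own.
run = wandb.init(project="my-lm-experiments", config={"lr": 2.5e-4, "batch_size": 64})
wandb.log({"train_loss": 3.21, "step": 1000})
run.finish()
```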
This repository contains many lines of code taken and adapted from the following sources:
- This repository was originally forked from the official Transformer-XL implementation, kimiyoung/transformer-xl. The code for Transformer-XL and standard Transformer models, as well as the basic functionality needed for language modeling (including adaptive input and output embeddings) and data preparation (WikiText-103, enwik8, ...), comes from that repository.
- For Performers, helper functions from lucidrains/performer-pytorch are used.
- For the CUDA implementations of our fast weight programmers with the delta rule:
- Code from idiap/fast-transformers is used with minor changes for the sum update rule.
- We modified it to implement our update rule. See comments in code for exact locations and modifications.
Please check the files under example_scripts for general instructions and examples of how to train and evaluate models.
@inproceedings{schlag2021linear,
  title={Linear Transformers Are Secretly Fast Weight Programmers},
  author={Imanol Schlag and Kazuki Irie and J\"urgen Schmidhuber},
  booktitle={Proc. Int. Conf. on Machine Learning (ICML)},
  address={Virtual only},
  month=jul,
  year={2021}
}

@article{irie2021going,
  title={Going Beyond Linear Transformers with Recurrent Fast Weight Programmers},
  author={Kazuki Irie and Imanol Schlag and R\'obert Csord\'as and J\"urgen Schmidhuber},
  journal={Preprint arXiv:2106.06295},
  year={2021}
}
- The code for synthetic retrieval experiments in the paper "Linear Transformers are Secretly Fast Weight Programmers" (ICML 2021) can be found at ischlag/fast-weight-transformers.
- The full repository for the paper "Going Beyond Linear Transformers with Recurrent Fast Weight Programmers" can be found at: IDSIA/recurrent-fwp
- Jürgen Schmidhuber's AI blog post on Fast Weight Programmers (March 26, 2021).