PyTorch Language Modeling Toolkit (for Fast Weight Programmers)

This repository contains the official code used for the language modeling experiments in the papers:

  • Linear Transformers Are Secretly Fast Weight Programmers (ICML 2021)

  • Going Beyond Linear Transformers with Recurrent Fast Weight Programmers (arXiv:2106.06295)

More generally, this can be used as a language modeling toolkit in PyTorch to experiment with:

  • Standard Transformers

  • Transformer-XL

  • Fast Weight Programmers with different update rules and linear attention functions:

    • Update rules: "sum" and our "delta" rule (as proposed in our paper; Sec 4.2)
    • Linear attention functions: "ELU-based" linear attention, "FAVOR+", "deterministic parameter-free projection (DPFP)"

    e.g., some combinations correspond to well-known models: ELU-based linear attention with the sum update rule yields the Linear Transformer, and FAVOR+ with the sum update rule yields the Performer (a minimal sketch of the two update rules is given below).
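For concreteness, the two update rules can be sketched in plain PyTorch as follows. This is a minimal, single-head, per-step sketch; the function name is illustrative (not the toolkit's API), and the feature map and normalization used in practice are omitted.

import torch

def fast_weight_step(W, k, v, q, beta=None, rule="sum"):
    # One fast weight memory step for a single head.
    # W: [d_v, d_k] fast weight matrix, k/q: [d_k] feature-mapped key/query,
    # v: [d_v] value, beta: scalar write strength (delta rule only).
    if rule == "sum":
        # purely additive outer-product write (Linear Transformer / Performer style)
        W = W + torch.outer(v, k)
    elif rule == "delta":
        # delta rule: first retrieve the value currently associated with k,
        # then write only the difference, scaled by the learned rate beta
        v_old = W @ k
        W = W + beta * torch.outer(v - v_old, k)
    y = W @ q  # read out with the query
    return W, y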

Fast Weight Implementations

This repository contains two implementations of fast weights.

While we only used the CUDA implementation for all our final experiments (it is faster and achieves much better GPU utilization), the torch.autograd.Function version can be useful for quick prototyping of new extensions.
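For illustration, a naive torch.autograd.Function along these lines for the sum update rule could look as follows. This is a single-head, per-step-loop sketch meant only to show the prototyping pattern, not the repository's actual implementation; the class name is hypothetical.

import torch

class SumLinearAttention(torch.autograd.Function):
    # Causal linear attention with the additive ("sum") fast weight update.
    # q and k are assumed to be already passed through the feature map.

    @staticmethod
    def forward(ctx, q, k, v):
        # q, k: [T, d_k], v: [T, d_v]
        ctx.save_for_backward(q, k, v)
        T, d_k = k.shape
        d_v = v.shape[1]
        W = v.new_zeros(d_v, d_k)
        out = v.new_empty(T, d_v)
        for t in range(T):
            W = W + torch.outer(v[t], k[t])  # fast weight update (sum rule)
            out[t] = W @ q[t]                # read out with the query
        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v = ctx.saved_tensors
        T = q.shape[0]
        grad_q = torch.zeros_like(q)
        grad_k = torch.zeros_like(k)
        grad_v = torch.zeros_like(v)
        # grad w.r.t. q_t needs the fast weight state W_t (forward scan)
        W = v.new_zeros(v.shape[1], k.shape[1])
        for t in range(T):
            W = W + torch.outer(v[t], k[t])
            grad_q[t] = W.t() @ grad_out[t]
        # grads w.r.t. k_t and v_t accumulate over future steps (reverse scan)
        S = v.new_zeros(v.shape[1], q.shape[1])
        for t in range(T - 1, -1, -1):
            S = S + torch.outer(grad_out[t], q[t])
            grad_k[t] = S.t() @ v[t]
            grad_v[t] = S @ k[t]
        return grad_q, grad_k, grad_v

Such a function would be called via SumLinearAttention.apply(q, k, v); the CUDA kernels in this repository serve the same role with far better GPU utilization.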

Requirements

This toolkit requires PyTorch (torch) and Ninja (ninja; needed to compile the CUDA kernels).

The experiments for the paper were conducted with Python 3.6 and PyTorch 1.4.0 (note on Aug 24, 2023: the code also works with Python 3.11 and PyTorch 2.0.1+cu117).

More recent versions of PyTorch are not yet well supported by this toolkit, which still uses torch.nn.DataParallel for multi-GPU training. If you need a more recent version of PyTorch, check the PyTorch documentation for switching to torch.nn.parallel.DistributedDataParallel. We will hopefully fix this soon, but we cannot say exactly when.
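For reference, switching to DistributedDataParallel typically amounts to something like the following generic PyTorch sketch (assuming a torchrun launch; this is not code from this toolkit):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_ddp(model):
    # assumes the script was launched with `torchrun --nproc_per_node=<num_gpus> <train script>`,
    # which sets the LOCAL_RANK environment variable for each process
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    model = model.to(local_rank)
    # replaces the torch.nn.DataParallel wrapper used in this toolkit
    return DDP(model, device_ids=[local_rank])

Each process then works on its own shard of the data, e.g. via torch.utils.data.distributed.DistributedSampler.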

The toolkit supports Weights & Biases for monitoring jobs. If you use it, also install wandb.

Acknowledgements

This repository contains many lines of code taken and adapted from the following sources:

  • This repository was originally forked from the official Transformer-XL implementation, kimiyoung/transformer-xl. The code for the Transformer-XL and standard Transformer models, as well as the basic functionality needed for language modeling (including adaptive input and output embeddings) and data preparation (WikiText-103, enwik8, ...), comes from that repository.
  • For Performers, helper functions from lucidrains/performer-pytorch are used.
  • For the CUDA implementations of our fast weight programmers with the delta rule:
    • Code from idiap/fast-transformers is used with minor changes for the sum update rule.
    • We modified it to implement our update rule. See comments in code for exact locations and modifications.

General Instructions

Please check the files under example_scripts for general instructions and examples for training and evaluating models.

BibTeX

@inproceedings{schlag2021linear,
  title     = {Linear Transformers Are Secretly Fast Weight Programmers},
  author    = {Imanol Schlag and Kazuki Irie and J\"urgen Schmidhuber},
  booktitle = {Proc. Int. Conf. on Machine Learning (ICML)},
  address   = {Virtual only},
  month     = jul,
  year      = {2021}
}

@article{irie2021going,
  title   = {Going Beyond Linear Transformers with Recurrent Fast Weight Programmers},
  author  = {Kazuki Irie and Imanol Schlag and R\'obert Csord\'as and J\"urgen Schmidhuber},
  journal = {Preprint arXiv:2106.06295},
  year    = {2021}
}