Convert twiddle to corresponding dense matrix #28

Open
JCBrouwer opened this issue Apr 12, 2022 · 1 comment

Comments

@JCBrouwer

Hello, thanks for open sourcing this code, it's very interesting work!

I'm wondering how the parameters of the Butterfly layer can be converted to their corresponding dense matrix. Is code for this somewhere in the repository? (I've looked around quite a bit, but might just be missing it). Otherwise could you point me to a resource to learn how to implement this myself?

A little bit of context: I want to use a low-rank parameterization of a batch of DxD matrices whose structure should be well represented by butterfly matrices. However, this BxDxD tensor is used in a tensor product that can't be factored into a single batch matrix multiply, which means the drop-in replacement modules and the butterfly_multiply function don't work for my case. Therefore I'd like to convert to a dense matrix and perform the tensor product in the regular way. I understand that this gives up the speedy CUDA implementation, but I'm willing to make this compromise because the number of learnable parameters will still be small.

tridao (Collaborator) commented Apr 12, 2022

To get the dense matrix from a butterfly you can just multiply the butterfly with an identity matrix:

import torch
import torch.nn.functional as F

from torch_butterfly import Butterfly
b = Butterfly(32, 32, bias=False)
# Need to transpose since we're following nn.Linear convention of multiply input @ weight.t()
dense_matrix = b(torch.eye(32)).t()
# Check that this gives the same answer
x = torch.randn(16, 32)
assert torch.allclose(b(x), F.linear(x, dense_matrix), rtol=1e-5, atol=1e-6)  
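
For the batched use case described in the question, the same identity-matrix trick might be wrapped in a small helper and stacked into a BxDxD tensor. The following is only a sketch under stated assumptions: `to_dense` is a hypothetical helper (not part of torch_butterfly), `nn.Linear` is used as a stand-in so the snippet runs without the library installed, and the einsum is just one example of a contraction over the dense matrices. Subtracting the module's output at zero removes a bias term if there is one, which relies on the module being an affine map.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def to_dense(module: nn.Module, in_features: int) -> torch.Tensor:
    """Hypothetical helper: dense weight W of an affine module.

    Assumes module(x) == F.linear(x, W) + bias on CPU float32 inputs.
    """
    eye = torch.eye(in_features)
    zero = torch.zeros(1, in_features)
    # Row i of module(eye) is W.t()[i] + bias; module(zero) is the bias alone.
    # The result follows the nn.Linear convention (out_features, in_features).
    return (module(eye) - module(zero)).t()


torch.manual_seed(0)
B, D, C = 4, 8, 3

# Stand-in layers (assumption); with the library installed these could be
# Butterfly(D, D, bias=False) modules instead.
layers = nn.ModuleList([nn.Linear(D, D) for _ in range(B)])

# Stack into a (B, D, D) tensor. No detach is used, so gradients still
# flow back to each layer's own (small) set of parameters.
dense = torch.stack([to_dense(layer, D) for layer in layers])

# Example contraction: apply every matrix to every vector -> (B, C, D).
x = torch.randn(C, D)
out = torch.einsum("bij,cj->bci", dense, x)

# Backpropagate through the dense matrices to the original parameters.
out.sum().backward()
```

Because the dense matrices are rebuilt from the layer parameters on every forward pass, the parameter count stays that of the structured layers; only the compute and memory cost becomes dense.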
