Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor package #3

Draft
wants to merge 1 commit into
base: cnn
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion butterfly/__init__.py

This file was deleted.

24 changes: 0 additions & 24 deletions butterfly/factor_multiply/setup.py

This file was deleted.

8 changes: 2 additions & 6 deletions tests/test_butterfly.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import math
import unittest

import numpy as np

import torch

from butterfly import Butterfly
from butterfly.butterfly import ButterflyBmm

from torch_butterfly.butterfly import Butterfly
from torch_butterfly.butterfly import ButterflyBmm

class ButterflyTest(unittest.TestCase):

Expand Down
14 changes: 5 additions & 9 deletions tests/test_butterfly_multiply.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import math
import unittest

import torch
import torch.nn.functional as F

from butterfly import Butterfly
from cnn.models.butterfly_conv import ButterflyConv2d
from torch_butterfly import Butterfly

from butterfly.butterfly_multiply import butterfly_mult_torch, butterfly_mult, butterfly_mult_inplace, butterfly_mult_factors
from butterfly.butterfly_multiply import butterfly_mult_untied_torch, butterfly_mult_untied
from butterfly.butterfly_multiply import butterfly_mult_conv2d_torch, butterfly_mult_conv2d
from butterfly.butterfly_multiply import butterfly_mult_untied_svd_torch, butterfly_mult_untied_svd
from torch_butterfly.butterfly_multiply import butterfly_mult_torch, butterfly_mult, butterfly_mult_inplace, butterfly_mult_factors
from torch_butterfly.butterfly_multiply import butterfly_mult_untied_torch, butterfly_mult_untied
from torch_butterfly.butterfly_multiply import butterfly_mult_conv2d_torch, butterfly_mult_conv2d
from torch_butterfly.butterfly_multiply import butterfly_mult_untied_svd_torch, butterfly_mult_untied_svd


class ButterflyMultTest(unittest.TestCase):
Expand Down
5 changes: 1 addition & 4 deletions tests/test_permutation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import math
import unittest

import numpy as np

import torch

from butterfly.permutation import Permutation, FixedPermutation, PermutationFactor
from torch_butterfly.permutation import Permutation, FixedPermutation, PermutationFactor


class PermutationTest(unittest.TestCase):
Expand Down
7 changes: 2 additions & 5 deletions tests/test_permutation_multiply.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import math
import unittest

import torch

from butterfly.permutation_multiply import permutation_mult_torch, permutation_mult
from butterfly.permutation_multiply import permutation_mult_single_factor_torch, permutation_mult_single
from torch_butterfly.permutation_multiply import permutation_mult_torch, permutation_mult
from torch_butterfly.permutation_multiply import permutation_mult_single_factor_torch, permutation_mult_single


class PermutationMultTest(unittest.TestCase):
Expand Down
1 change: 1 addition & 0 deletions torch_butterfly/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .butterfly import Butterfly
8 changes: 2 additions & 6 deletions butterfly/benchmark.py → torch_butterfly/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import os, sys
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

import torch

from butterfly import Butterfly
from butterfly.butterfly_multiply import butterfly_mult, butterfly_mult_untied, butterfly_mult_untied_svd, butterfly_mult_factors, butterfly_mult_inplace
from torch_butterfly.butterfly import Butterfly
from torch_butterfly.butterfly_multiply import butterfly_mult, butterfly_mult_untied, butterfly_mult_untied_svd, butterfly_mult_factors, butterfly_mult_inplace

batch_size = 8192
n = 256
Expand Down
8 changes: 1 addition & 7 deletions cnn/benchmark_cnn.py → torch_butterfly/benchmark_cnn.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
import os, sys
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

import torch

from cnn.models.butterfly_conv import ButterflyConv2d
from butterfly.butterfly import ButterflyBmm
from butterfly.butterfly_multiply import butterfly_conv2d
from torch_butterfly.butterfly import ButterflyConv2d

import time
nsteps = 1000
Expand Down
219 changes: 217 additions & 2 deletions butterfly/butterfly.py → torch_butterfly/butterfly.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import torch
from torch import nn
import torch.nn.functional as F

from .butterfly_multiply import butterfly_mult, butterfly_mult_untied, butterfly_mult_untied_svd
from torch_butterfly.butterfly_multiply import butterfly_mult, butterfly_mult_untied, butterfly_mult_untied_svd
from torch_butterfly.butterfly_multiply import butterfly_mult_conv2d

class Butterfly(nn.Module):
"""Product of log N butterfly factors, each is a block 2x2 of diagonal matrices.
Expand Down Expand Up @@ -218,4 +220,217 @@ def post_process(self, output, batch):
output = output.view(batch, self.matrix_batch, self.in_size_extended // out_size_extended, out_size_extended, 2).mean(dim=2)
if self.out_size != out_size_extended: # Take top rows
output = output[:, :, :self.out_size]
return output if self.bias is None else output + self.bias
return output if self.bias is None else output + self.bias

class Butterfly1x1Conv(Butterfly):
    """Pointwise (1x1) convolution built on the parent ``Butterfly`` transform.

    Every spatial location of the input is treated as an independent vector of
    ``c`` channel values and pushed through ``Butterfly.forward``.
    """

    def forward(self, input):
        """
        Parameters:
            input: (batch, c, h, w) if real or (batch, c, h, w, 2) if complex
        Return:
            output: (batch, nstack * c, h, w) if real or (batch, nstack * c, h, w, 2) if complex
        """
        # TODO: Only doing real for now
        batch, c, h, w = input.shape
        n_pixels = h * w
        # Channels last, then fold (batch, pixel) into a single leading axis so
        # the parent class sees a plain 2-D batch of length-c vectors.
        channels_last = input.view(batch, c, n_pixels).transpose(1, 2)
        flat_pixels = channels_last.reshape(-1, c)
        transformed = super().forward(flat_pixels)
        c_out = self.nstack * c
        # Undo the folding and move channels back in front of the spatial axes.
        unfolded = transformed.view(batch, n_pixels, c_out)
        return unfolded.transpose(1, 2).view(batch, c_out, h, w)


class ButterflyConv2d(ButterflyBmm):
    """2-D convolution whose per-kernel-position channel mixing is a butterfly matrix.

    One butterfly matrix (a product of log N butterfly factors, each a block
    2x2 of diagonal matrices) is kept per kernel position; the results over all
    kernel positions are averaged.

    Parameters:
        in_channels: size of input
        out_channels: size of output
        kernel_size: int or (int, int)
        stride: int or (int, int)
        padding: int or (int, int)
        dilation: int or (int, int)
        bias: If set to False, the layer will not learn an additive bias.
            Default: ``True``
        tied_weight: whether the weights in the butterfly factors are tied.
            If True, will have 4N parameters, else will have 2 N log N parameters (not counting bias)
        increasing_stride: whether to multiply with increasing stride (e.g. 1, 2, ..., n/2) or
            decreasing stride (e.g., n/2, n/4, ..., 1).
            Note that this only changes the order of multiplication, not how twiddle is stored.
            In other words, twiddle[@log_stride] always stores the twiddle for @stride.
        ortho_init: whether the weight matrix should be initialized to be orthogonal/unitary.
        param: The parameterization of the 2x2 butterfly factors, either 'regular' or 'ortho' or 'svd'.
            'ortho' and 'svd' only support real, not complex.
        max_gain: (only for svd parameterization) controls the maximum and minimum singular values
            of the whole matrix (not of each factor).
            For example, max_gain=10.0 means that the singular values are in [0.1, 10.0].
        fused_unfold: if True, CUDA inputs with at most 1024 channels are sent to
            ``butterfly_mult_conv2d`` instead of the ``F.unfold`` + batched multiply path.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True,
                 tied_weight=True, increasing_stride=True, ortho_init=False, param='regular', max_gain=10.0,
                 fused_unfold=False):
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Normalise every geometry argument to an (h, w) pair.
        self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        self.stride = (stride, stride) if isinstance(stride, int) else stride
        self.padding = (padding, padding) if isinstance(padding, int) else padding
        self.dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
        self.fused_unfold = fused_unfold
        # One butterfly matrix per kernel position; the fifth argument (False) keeps the layer real.
        super().__init__(in_channels, out_channels, self.kernel_size[0] * self.kernel_size[1], bias, False,
                         tied_weight, increasing_stride, ortho_init, param, max_gain)

    def forward(self, input):
        """
        Parameters:
            input: (batch, c, h, w). Only the real case is handled for now.
        Return:
            output: (batch, out_channels, h_out, w_out)
        """
        # TODO: Only doing real for now
        batch, c, h, w = input.shape
        # Standard convolution output-size formula, applied per spatial axis.
        h_out = (h + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1
        # Fixed: the width must be derived from w (the original used h, which broke non-square inputs).
        w_out = (w + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1
        if not self.fused_unfold or c > 1024 or not input.is_cuda:
            # unfold input into patches and call batch matrix multiply
            input_patches = F.unfold(input, self.kernel_size, self.dilation, self.padding, self.stride).view(
                batch, c, self.kernel_size[0] * self.kernel_size[1], h_out * w_out)
            input = input_patches.permute(0, 3, 2, 1).reshape(batch * h_out * w_out, self.kernel_size[0] * self.kernel_size[1], c)
            output = super().forward(input)
        else:
            batch_out = batch * h_out * w_out
            # NOTE(review): only kernel_size[0] and padding[0] are forwarded here, so this
            # path presumably assumes square kernels/padding with default stride and
            # dilation -- confirm against butterfly_mult_conv2d.
            output = butterfly_mult_conv2d(self.twiddle, input, self.kernel_size[0],
                                           self.padding[0], self.increasing_stride)
            output = super().post_process(output, batch_out)
        # combine matrix batches
        output = output.mean(dim=1)
        return output.view(batch, h_out * w_out, self.out_channels).transpose(1, 2).view(batch, self.out_channels, h_out, w_out)


class ButterflyConv2dBBT(nn.Module):
    """2-D convolution whose channel mixing is a product of ``nblocks`` BB^T blocks.

    Each block is a pair of ``ButterflyBmm`` layers, the first built with
    ``increasing_stride=False`` and the second with ``increasing_stride=True``.
    One such matrix is kept per kernel position and the results over all kernel
    positions are averaged.

    Parameters:
        in_channels: size of input
        out_channels: size of output
        kernel_size: int or (int, int)
        stride: int or (int, int)
        padding: int or (int, int)
        dilation: int or (int, int)
        bias: If set to False, the layer will not learn an additive bias.
            Only the very last butterfly layer carries the bias.
            Default: ``True``
        tied_weight: whether the weights in the butterfly factors are tied.
            If True, will have 4N parameters, else will have 2 N log N parameters (not counting bias)
        nblocks: number of BBT blocks in the product
        ortho_init: whether the weight matrix should be initialized to be orthogonal/unitary.
        param: The parameterization of the 2x2 butterfly factors, either 'regular' or 'ortho' or 'svd'.
            'ortho' and 'svd' only support real, not complex.
        max_gain: (only for svd parameterization) controls the maximum and minimum singular values
            of the whole BB^T matrix (not of each factor).
            For example, max_gain=10.0 means that the singular values are in [0.1, 10.0].
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True,
                 tied_weight=True, nblocks=1, ortho_init=False, param='regular', max_gain=10.0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Normalise every geometry argument to an (h, w) pair.
        self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        self.stride = (stride, stride) if isinstance(stride, int) else stride
        self.padding = (padding, padding) if isinstance(padding, int) else padding
        self.dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
        self.nblocks = nblocks
        # 2 * nblocks layers are multiplied together, so each one gets an equal share of the gain.
        max_gain_per_block = max_gain ** (1 / (2 * nblocks))
        matrix_batch = self.kernel_size[0] * self.kernel_size[1]
        layers = []
        for i in range(nblocks):
            is_last_block = i == nblocks - 1
            # Positional arguments 4 and 5 are (bias, complex), matching the other
            # ButterflyBmm call sites in this file.
            layers.append(ButterflyBmm(in_channels if i == 0 else out_channels,
                                       out_channels, matrix_batch, False, False,
                                       tied_weight, increasing_stride=False,
                                       ortho_init=ortho_init, param=param,
                                       max_gain=max_gain_per_block))
            # Fixed: the bias flag belongs in the 4th slot. The original passed it in the
            # 5th (complex) slot, so the final layer never got a bias and was switched
            # to complex whenever bias was truthy.
            layers.append(ButterflyBmm(out_channels, out_channels, matrix_batch,
                                       bias if is_last_block else False, False,
                                       tied_weight, increasing_stride=True,
                                       ortho_init=ortho_init, param=param,
                                       max_gain=max_gain_per_block))
        self.layers = nn.Sequential(*layers)

    def forward(self, input):
        """
        Parameters:
            input: (batch, c, h, w). Only the real case is handled for now.
        Return:
            output: (batch, out_channels, h_out, w_out)
        """
        # TODO: Only doing real for now
        batch, c, h, w = input.shape
        kernel_area = self.kernel_size[0] * self.kernel_size[1]
        # Standard convolution output-size formula, applied per spatial axis.
        h_out = (h + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1
        # Fixed: the width must be derived from w (the original used h, which broke non-square inputs).
        w_out = (w + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1
        # Unfold into patches: (batch, c, kernel_area, h_out * w_out).
        input_patches = F.unfold(input, self.kernel_size, self.dilation, self.padding, self.stride).view(
            batch, c, kernel_area, h_out * w_out)
        # One row per output pixel, one matrix-batch entry per kernel position.
        input_reshape = input_patches.permute(0, 3, 2, 1).reshape(batch * h_out * w_out, kernel_area, c)
        # Average the per-kernel-position results.
        output = self.layers(input_reshape).mean(dim=1)
        return output.view(batch, h_out * w_out, self.out_channels).transpose(1, 2).view(batch, self.out_channels, h_out, w_out)


class ButterflyConv2dBBTBBT(nn.Module):
    """2-D convolution whose channel mixing is a BB^T BB^T product of butterfly matrices.

    Four ``ButterflyBmm`` layers are chained, alternating
    ``increasing_stride=False`` and ``increasing_stride=True``. One such matrix
    is kept per kernel position and the results over all kernel positions are
    averaged.

    Parameters:
        in_channels: size of input
        out_channels: size of output
        kernel_size: int or (int, int)
        stride: int or (int, int)
        padding: int or (int, int)
        dilation: int or (int, int)
        bias: If set to False, the layer will not learn an additive bias.
            Only the last of the four butterfly layers carries the bias.
            Default: ``True``
        tied_weight: whether the weights in the butterfly factors are tied.
            If True, will have 4N parameters, else will have 2 N log N parameters (not counting bias)
        ortho_init: whether the weight matrix should be initialized to be orthogonal/unitary.
        param: The parameterization of the 2x2 butterfly factors, either 'regular' or 'ortho' or 'svd'.
            'ortho' and 'svd' only support real, not complex.
        max_gain: (only for svd parameterization) controls the maximum and minimum singular values
            of the whole BB^T BB^T matrix (not of each factor).
            For example, max_gain=10.0 means that the singular values are in [0.1, 10.0].
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True,
                 tied_weight=True, ortho_init=False, param='regular', max_gain=10.0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Normalise every geometry argument to an (h, w) pair.
        self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        self.stride = (stride, stride) if isinstance(stride, int) else stride
        self.padding = (padding, padding) if isinstance(padding, int) else padding
        self.dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
        matrix_batch = self.kernel_size[0] * self.kernel_size[1]
        # Four layers are multiplied together, so each one gets the fourth root of the gain.
        gain_per_layer = max_gain ** (1 / 4)
        # Positional arguments 4 and 5 of ButterflyBmm are (bias, complex).
        self.layers = nn.Sequential(
            ButterflyBmm(in_channels, out_channels, matrix_batch, False, False, tied_weight,
                         increasing_stride=False, ortho_init=ortho_init, param=param, max_gain=gain_per_layer),
            ButterflyBmm(out_channels, out_channels, matrix_batch, False, False, tied_weight,
                         increasing_stride=True, ortho_init=ortho_init, param=param, max_gain=gain_per_layer),
            ButterflyBmm(out_channels, out_channels, matrix_batch, False, False, tied_weight,
                         increasing_stride=False, ortho_init=ortho_init, param=param, max_gain=gain_per_layer),
            ButterflyBmm(out_channels, out_channels, matrix_batch, bias, False, tied_weight,
                         increasing_stride=True, ortho_init=ortho_init, param=param, max_gain=gain_per_layer)
        )

    def forward(self, input):
        """
        Parameters:
            input: (batch, c, h, w). Only the real case is handled for now.
        Return:
            output: (batch, out_channels, h_out, w_out)
        """
        # TODO: Only doing real for now
        batch, c, h, w = input.shape
        kernel_area = self.kernel_size[0] * self.kernel_size[1]
        # Standard convolution output-size formula, applied per spatial axis.
        h_out = (h + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1
        # Fixed: the width must be derived from w (the original used h, which broke non-square inputs).
        w_out = (w + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1
        # Unfold into patches: (batch, c, kernel_area, h_out * w_out).
        input_patches = F.unfold(input, self.kernel_size, self.dilation, self.padding, self.stride).view(
            batch, c, kernel_area, h_out * w_out)
        # One row per output pixel, one matrix-batch entry per kernel position.
        input_reshape = input_patches.permute(0, 3, 2, 1).reshape(batch * h_out * w_out, kernel_area, c)
        # Average the per-kernel-position results.
        output = self.layers(input_reshape).mean(dim=1)
        return output.view(batch, h_out * w_out, self.out_channels).transpose(1, 2).view(batch, self.out_channels, h_out, w_out)
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.