tt_conv is not implemented here #10

Open
Silk760 opened this issue Jul 21, 2020 · 1 comment

Comments

@Silk760
Silk760 commented Jul 21, 2020

I have read the paper, and it is very interesting work. I was thinking of using it with Conv layers, but a Conv layer is not implemented.

I have searched other GitHub repositories, and none of them has a PyTorch implementation of a TT Conv layer. Could you point me to any code or work that shows how it can be implemented in PyTorch?

@elena-orlova
Collaborator

Yes, convolutional layers are not the focus of our paper. You can have a look at this repository https://github.com/musco-ai/musco-pytorch , in particular the musco/pytorch/compressor/decompositions/ directory. There is no tt_conv example there, but there is a fairly similar Tucker-2 decomposition for conv layers.
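To make the Tucker-2 idea concrete: a Tucker-2 decomposition keeps the spatial modes of the kernel intact and factorizes only the two channel modes, so it can be deployed as three stacked convolutions (1x1, k x k, 1x1). Below is a minimal sketch that obtains the factors by truncated HOSVD, i.e. an SVD of each channel-mode unfolding. It is an illustration under that assumption, not the musco-pytorch code, which may choose ranks and factors differently; the function name tucker2_conv and its rank arguments r_in and r_out are hypothetical.

```python
import torch
import torch.nn as nn


def tucker2_conv(conv: nn.Conv2d, r_in: int, r_out: int) -> nn.Sequential:
    """Approximate a dense Conv2d (groups=1) by 1x1 -> kxk -> 1x1 convs.

    Hypothetical helper: Tucker-2 factors are taken from a truncated HOSVD.
    """
    W = conv.weight.detach()  # PyTorch layout: (C_out, C_in, kh, kw)
    c_out, c_in, kh, kw = W.shape
    # Left singular vectors of the C_out-mode and C_in-mode unfoldings
    U_out = torch.linalg.svd(W.reshape(c_out, -1), full_matrices=False)[0]
    U_in = torch.linalg.svd(W.transpose(0, 1).reshape(c_in, -1), full_matrices=False)[0]
    # Ranks cannot exceed the number of available singular vectors
    r_out, r_in = min(r_out, U_out.shape[1]), min(r_in, U_in.shape[1])
    U_out, U_in = U_out[:, :r_out], U_in[:, :r_in]
    # Core tensor: W projected onto both factor bases, shape (r_out, r_in, kh, kw)
    core = torch.einsum('oihw,or,is->rshw', W, U_out, U_in)

    first = nn.Conv2d(c_in, r_in, 1, bias=False)  # C_in -> r_in
    middle = nn.Conv2d(r_in, r_out, (kh, kw), stride=conv.stride,
                       padding=conv.padding, dilation=conv.dilation,
                       bias=False)  # spatial conv carrying the core
    last = nn.Conv2d(r_out, c_out, 1, bias=conv.bias is not None)  # r_out -> C_out
    with torch.no_grad():
        first.weight.copy_(U_in.t().reshape(r_in, c_in, 1, 1))
        middle.weight.copy_(core)
        last.weight.copy_(U_out.reshape(c_out, r_out, 1, 1))
        if conv.bias is not None:
            last.bias.copy_(conv.bias)  # the original bias is applied at the very end
    return nn.Sequential(first, middle, last)


# Example: compress a 64 -> 128 channel 3x3 conv to ranks (16, 32)
conv = nn.Conv2d(64, 128, 3, padding=1)
compressed = tucker2_conv(conv, r_in=16, r_out=32)
```

The three layers hold c_in*r_in + r_in*r_out*k*k + r_out*c_out weights instead of c_in*c_out*k*k; for the example ranks that is 9,728 weights instead of 73,728 (biases aside). With full ranks the sequence reproduces the original conv up to floating-point error.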
Another close example is available at https://github.com/NVlabs/conv-tt-lstm ; the accompanying paper is devoted to compressing conv_LSTM layers with a TT decomposition.

In general, it is common to reshape the 4D convolution kernel (C_in, C_out, k, k) into a 3D tensor (C_in, C_out, k*k), where C_in is the number of input channels, C_out the number of output channels, and k the kernel size (note that nn.Conv2d itself stores its weight as (C_out, C_in, k, k)). Applying a TT decomposition to this 3D tensor gives 3 cores. I'd suggest implementing this as a custom class (where r1 and r2 are the TT ranks) along these lines:

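import torch
import torch.nn as nn
import torch.nn.functional as F

class tt_conv2d(nn.Module):
    # r1, r2: TT ranks; kernel_size: a (kh, kw) tuple
    def __init__(self, c_in, c_out, kernel_size, r1, r2, stride=1, padding=0, dilation=1):
        super().__init__()
        self.r2, self.stride, self.padding, self.dilation = r2, stride, padding, dilation
        # 1x1 conv: c_in -> r1*r2 channels
        self.conv1 = nn.Conv2d(c_in, r1*r2, 1, stride=1, padding=0)
        # one spatial kernel of shape (1, r1, kh, kw), shared by all r2 groups
        self.conv2_weight = nn.Parameter(torch.empty(1, r1, *kernel_size))
        nn.init.kaiming_uniform_(self.conv2_weight, a=5 ** 0.5)
        # 1x1 conv: r2 -> c_out channels
        self.conv3 = nn.Conv2d(r2, c_out, 1, stride=1, padding=0)

    def forward(self, input):
        out = self.conv1(input)
        # grouped conv: each of the r2 groups maps its r1 channels to 1 channel
        out = F.conv2d(out, self.conv2_weight.repeat(self.r2, 1, 1, 1),
                       stride=self.stride, groups=self.r2,
                       dilation=self.dilation, padding=self.padding)
        return self.conv3(out)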
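The reshape-and-decompose step described in the comment (a 4D kernel viewed as a 3D (C_in, C_out, k*k) tensor and split into 3 TT cores with ranks r1 and r2) can be sketched with a plain truncated TT-SVD, i.e. two successive SVDs. This is a minimal illustration, assuming SVD-based TT-SVD is used to obtain the cores; tt_svd_3d and tt_reconstruct are hypothetical helper names, and how a particular mode ordering is mapped onto a sequence of conv layers is a separate design choice not shown here.

```python
import torch


def tt_svd_3d(T: torch.Tensor, r1: int, r2: int):
    """Hypothetical helper: truncated TT-SVD of a 3D tensor T of shape (n1, n2, n3).

    Returns cores G1 (n1, r1), G2 (r1, n2, r2), G3 (r2, n3).
    """
    n1, n2, n3 = T.shape
    # First split: mode 1 vs. modes (2, 3)
    U, S, Vh = torch.linalg.svd(T.reshape(n1, n2 * n3), full_matrices=False)
    r1 = min(r1, S.numel())
    G1 = U[:, :r1]
    rest = (S[:r1, None] * Vh[:r1]).reshape(r1 * n2, n3)
    # Second split: (rank, mode 2) vs. mode 3
    U, S, Vh = torch.linalg.svd(rest, full_matrices=False)
    r2 = min(r2, S.numel())
    G2 = U[:, :r2].reshape(r1, n2, r2)
    G3 = S[:r2, None] * Vh[:r2]
    return G1, G2, G3


def tt_reconstruct(G1, G2, G3):
    # T[i, j, l] = sum_{a, b} G1[i, a] * G2[a, j, b] * G3[b, l]
    return torch.einsum('ia,ajb,bl->ijl', G1, G2, G3)


# Example: a 3x3 kernel in PyTorch layout (C_out=6, C_in=4, k, k),
# permuted and reshaped to (C_in, C_out, k*k) as in the comment
W = torch.randn(6, 4, 3, 3)
T = W.permute(1, 0, 2, 3).reshape(4, 6, 9)
G1, G2, G3 = tt_svd_3d(T, r1=2, r2=3)
rel_err = (tt_reconstruct(G1, G2, G3) - T).norm() / T.norm()
```

With ranks at least as large as the ranks of the two unfoldings, the reconstruction is exact up to floating-point error; smaller r1 and r2 trade accuracy for fewer parameters.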
