I have read the paper; it is very interesting work. I was thinking of using it with Conv layers, but the Conv layer is not implemented.
I have looked through other GitHub repositories, and there is no PyTorch implementation of the Conv layer. Can you point me to any code or work that shows how it can be implemented in PyTorch?
Yes, convolutional layers are not the focus of our paper. You can have a look at the https://github.com/musco-ai/musco-pytorch repository, in particular the musco/pytorch/compressor/decompositions/ directory. It has no TT example for conv layers (tt_conv), but it does include the closely related Tucker-2 decomposition for conv layers.
Another close example is https://github.com/NVlabs/conv-tt-lstm; the corresponding paper is devoted to compressing conv_LSTM layers with a TT decomposition.
In general, it's common to reshape the 4D convolution kernel into a 3D tensor by merging its two spatial dimensions: PyTorch stores a Conv2d kernel as (C_out, C_in, k, k), which becomes (C_out, C_in, k*k), where C_in is the number of input channels, C_out the number of output channels, and k the kernel size. Applying a TT decomposition to this 3D tensor gives 3 cores. I'd suggest implementing this as a custom class, where r1 and r2 are the TT ranks (in the code, r2 links the C_out and C_in cores, and r1 links the C_in and spatial cores), in this manner:
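```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class tt_conv2d(nn.Module):
    # constructor signature assumed from the names used in the layer; kernel_size is a (kH, kW) tuple
    def __init__(self, c_in, c_out, kernel_size, r1, r2, stride=1, padding=0, dilation=1):
        super().__init__()
        self.r2, self.stride, self.padding, self.dilation = r2, stride, padding, dilation
        # 1x1 conv: C_in -> r1*r2 channels (no bias, so zero padding in the next step stays exact)
        self.conv1 = nn.Conv2d(c_in, r1 * r2, 1, stride=1, padding=0, bias=False)
        # one (1, r1, kH, kW) spatial kernel, shared by all r2 groups via repeat() in forward
        self.conv2_weight = nn.Parameter(torch.empty(1, r1, *kernel_size))
        nn.init.kaiming_uniform_(self.conv2_weight, a=5 ** 0.5)  # torch.Tensor(...) would be uninitialized
        self.conv3 = nn.Conv2d(r2, c_out, 1, stride=1, padding=0)  # 1x1 conv: r2 -> C_out channels

    def forward(self, input):
        out = self.conv1(input)
        # grouped conv: each of the r2 groups maps its r1 channels to 1 output channel
        out = F.conv2d(out, self.conv2_weight.repeat(self.r2, 1, 1, 1),
                       stride=self.stride, groups=self.r2, dilation=self.dilation, padding=self.padding)
        return self.conv3(out)
```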
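If you want to start from a pretrained Conv2d rather than from random weights, one option (not discussed in this thread, so treat it as an assumption) is the standard TT-SVD algorithm: reshape conv.weight to (C_out, C_in, k*k), split off the C_out mode with a truncated SVD at rank r2, then split the remainder into an (r2, C_in, r1) core and an (r1, k*k) spatial core with a second truncated SVD at rank r1. The sketch below copies those cores into the tt_conv2d structure (redefined so the block runs on its own). The helpers tt_svd_3d and tt_conv2d_from_conv are hypothetical names, and truncated TT-SVD is quasi-optimal rather than the best rank-(r2, r1) approximation, so you would typically fine-tune afterwards.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class tt_conv2d(nn.Module):
    """1x1 conv -> grouped kxk conv with one shared kernel -> 1x1 conv."""

    def __init__(self, c_in, c_out, kernel_size, r1, r2, stride=1, padding=0, dilation=1):
        super().__init__()
        self.r2, self.stride, self.padding, self.dilation = r2, stride, padding, dilation
        self.conv1 = nn.Conv2d(c_in, r1 * r2, 1, bias=False)
        self.conv2_weight = nn.Parameter(torch.empty(1, r1, *kernel_size))
        nn.init.kaiming_uniform_(self.conv2_weight, a=5 ** 0.5)
        self.conv3 = nn.Conv2d(r2, c_out, 1)

    def forward(self, x):
        out = self.conv1(x)
        out = F.conv2d(out, self.conv2_weight.repeat(self.r2, 1, 1, 1), stride=self.stride,
                       padding=self.padding, dilation=self.dilation, groups=self.r2)
        return self.conv3(out)


def tt_svd_3d(kernel, r2, r1):
    """Hypothetical helper: TT-SVD of a (C_out, C_in, S) tensor into
    cores of shape (C_out, r2), (r2, C_in, r1) and (r1, S)."""
    c_out, c_in, s = kernel.shape
    assert r2 <= min(c_out, c_in * s) and r1 <= min(r2 * c_in, s), "rank too large"
    # first split: C_out | (C_in, S)
    u, sv, vh = torch.linalg.svd(kernel.reshape(c_out, c_in * s), full_matrices=False)
    core_out = u[:, :r2]
    rest = (sv[:r2, None] * vh[:r2]).reshape(r2 * c_in, s)
    # second split: (r2, C_in) | S
    u, sv, vh = torch.linalg.svd(rest, full_matrices=False)
    core_in = u[:, :r1].reshape(r2, c_in, r1)
    core_sp = sv[:r1, None] * vh[:r1]
    return core_out, core_in, core_sp


def tt_conv2d_from_conv(conv, r1, r2):
    """Hypothetical helper: build a tt_conv2d from a TT-SVD of conv.weight."""
    c_out, c_in, kh, kw = conv.weight.shape
    kernel = conv.weight.detach().reshape(c_out, c_in, kh * kw)
    core_out, core_in, core_sp = tt_svd_3d(kernel, r2, r1)
    layer = tt_conv2d(c_in, c_out, (kh, kw), r1, r2, stride=conv.stride,
                      padding=conv.padding, dilation=conv.dilation)
    with torch.no_grad():
        # conv1 output channel g*r1 + a (group g, position a) holds core_in[g, :, a]
        layer.conv1.weight.copy_(core_in.permute(0, 2, 1).reshape(r2 * r1, c_in, 1, 1))
        layer.conv2_weight.copy_(core_sp.reshape(1, r1, kh, kw))
        layer.conv3.weight.copy_(core_out.reshape(c_out, r2, 1, 1))
        if conv.bias is not None:
            layer.conv3.bias.copy_(conv.bias)
        else:
            layer.conv3.bias.zero_()
    return layer


torch.manual_seed(0)
conv = nn.Conv2d(4, 6, 3, padding=1)
x = torch.randn(2, 4, 8, 8)

# full ranks (r2 = min(6, 4*9) = 6, r1 = min(6*4, 9) = 9): the TT is exact
exact = tt_conv2d_from_conv(conv, r1=9, r2=6)
print(torch.allclose(conv(x), exact(x), atol=1e-4))  # True

# truncated ranks: a compressed approximation of the same layer
small = tt_conv2d_from_conv(conv, r1=3, r2=2)
print(sum(p.numel() for p in conv.parameters()))   # 222
print(sum(p.numel() for p in small.parameters()))  # 69
```

With full ranks the decomposition is exact, so both layers produce the same output. With r1=3 and r2=2 this 4-to-6-channel 3x3 layer drops from 222 to 69 parameters (24 in conv1, 27 in the shared spatial kernel, 12 + 6 in conv3).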