From 4040aecf2daed6e6653179fd0f277e5878f65587 Mon Sep 17 00:00:00 2001
From: Mostafa Elhoushi
Date: Mon, 17 Feb 2020 10:23:58 -0500
Subject: [PATCH] refactor code and make PS faster

---
 pytorch/cifar10.py                            |  16 +-
 pytorch/cifar10_models/densenet.py            |   5 +-
 pytorch/cpu_kernal/setup.py                   |   5 -
 pytorch/cuda_kernel/shift_cuda.cpp            |  94 ---
 pytorch/cuda_kernel/shift_cuda_kernel.cu      | 438 --------------
 pytorch/deepshift/__init__.py                 |   0
 .../convert.py}                               |  52 +-
 pytorch/deepshift/kernels/__init__.py         |   1 +
 pytorch/deepshift/kernels/cpu/setup.py        |  14 +
 .../kernels/cpu/shift_cpu.cpp}                |   2 +-
 pytorch/deepshift/kernels/cuda/__init__.py    |   1 +
 .../kernels/cuda/convert_to_unoptimized.py    |  48 ++
 pytorch/deepshift/kernels/cuda/setup.py       |  26 +
 pytorch/deepshift/kernels/cuda/shift.cu       | 559 ++++++++++++++++++
 pytorch/deepshift/kernels/cuda/shift_cuda.cpp |  68 +++
 .../kernels/cuda/unoptimized_conv.py          | 109 ++++
 .../kernels/cuda/unoptimized_cuda.cpp         |  56 ++
 .../kernels/cuda/unoptimized_cuda_kernel.cu   | 219 +++++++
 .../kernels/cuda/unoptimized_linear.py        |  56 ++
 pytorch/deepshift/kernels/kernels.py          |  74 +++
 pytorch/{shift.py => deepshift/modules.py}    | 193 +++---
 .../{shift_q.py => deepshift/modules_q.py}    | 231 ++------
 pytorch/deepshift/ste.py                      |  90 +++
 pytorch/deepshift/utils.py                    |  67 +++
 pytorch/imagenet.py                           |  65 +-
 pytorch/install_kernels.sh                    |  12 +
 pytorch/mnist.py                              |  47 +-
 pytorch/test.py                               |  23 -
 pytorch/unoptimized/convert.py                |  48 ++
 pytorch/unoptimized/kernels/__init__.py       |   1 +
 .../kernels/cuda}/setup.py                    |   8 +-
 pytorch/unoptimized/kernels/cuda/unoptimized.cu | 219 +++++++
 .../kernels/cuda/unoptimized_cuda.cpp         |  56 ++
 pytorch/unoptimized/kernels/kernels.py        |  34 ++
 pytorch/unoptimized/modules/conv.py           |  93 +++
 pytorch/unoptimized/modules/linear.py         |  51 ++
 requirements.txt                              |  25 -
 37 files changed, 2154 insertions(+), 952 deletions(-)
 delete mode 100644 pytorch/cpu_kernal/setup.py
 delete mode 100644 pytorch/cuda_kernel/shift_cuda.cpp
 delete mode 100644 pytorch/cuda_kernel/shift_cuda_kernel.cu
 create mode 100644 pytorch/deepshift/__init__.py
 rename pytorch/{convert_to_shift.py => deepshift/convert.py} (57%)
 create mode 100644 pytorch/deepshift/kernels/__init__.py
 create mode 100644 pytorch/deepshift/kernels/cpu/setup.py
 rename pytorch/{cpu_kernal/shift_kernel.cpp => deepshift/kernels/cpu/shift_cpu.cpp} (99%)
 create mode 100644 pytorch/deepshift/kernels/cuda/__init__.py
 create mode 100644 pytorch/deepshift/kernels/cuda/convert_to_unoptimized.py
 create mode 100644 pytorch/deepshift/kernels/cuda/setup.py
 create mode 100644 pytorch/deepshift/kernels/cuda/shift.cu
 create mode 100644 pytorch/deepshift/kernels/cuda/shift_cuda.cpp
 create mode 100644 pytorch/deepshift/kernels/cuda/unoptimized_conv.py
 create mode 100644 pytorch/deepshift/kernels/cuda/unoptimized_cuda.cpp
 create mode 100644 pytorch/deepshift/kernels/cuda/unoptimized_cuda_kernel.cu
 create mode 100644 pytorch/deepshift/kernels/cuda/unoptimized_linear.py
 create mode 100644 pytorch/deepshift/kernels/kernels.py
 rename pytorch/{shift.py => deepshift/modules.py} (64%)
 rename pytorch/{shift_q.py => deepshift/modules_q.py} (56%)
 create mode 100644 pytorch/deepshift/ste.py
 create mode 100644 pytorch/deepshift/utils.py
 create mode 100755 pytorch/install_kernels.sh
 delete mode 100644 pytorch/test.py
 create mode 100644 pytorch/unoptimized/convert.py
 create mode 100644 pytorch/unoptimized/kernels/__init__.py
 rename pytorch/{cuda_kernel => unoptimized/kernels/cuda}/setup.py (59%)
 create mode 100644 pytorch/unoptimized/kernels/cuda/unoptimized.cu
 create mode 100644 pytorch/unoptimized/kernels/cuda/unoptimized_cuda.cpp
 create mode 100644 pytorch/unoptimized/kernels/kernels.py
 create mode 100644 pytorch/unoptimized/modules/conv.py
 create mode 100644 pytorch/unoptimized/modules/linear.py

diff --git a/pytorch/cifar10.py b/pytorch/cifar10.py
index dcdbc5f..7f3ef0d 100644
--- a/pytorch/cifar10.py
+++ b/pytorch/cifar10.py
@@ -26,10 +26,20 @@ from torchsummary import summary
 
 import optim
 
-from convert_to_shift import convert_to_shift, round_shift_weights, count_layer_type
+from deepshift.convert import convert_to_shift, round_shift_weights, count_layer_type
+from unoptimized.convert import convert_to_unoptimized
 
 import cifar10_models as models
 
+'''
+Unfortunately, none of the pytorch repositories with ResNets on CIFAR10 provides an
+implementation as described in the original paper. If you just use torchvision's
+models on CIFAR10, you'll get a model that differs in the number of layers and parameters.
+This is unacceptable if you want to directly compare ResNets on CIFAR10 with the
+original paper. The purpose of resnet_cifar10 (obtained from https://github.com/akamaster/pytorch_resnet_cifar10)
+is to provide a valid pytorch implementation of ResNets for CIFAR10 as described in the original paper.
+'''
+
 model_names = sorted(name for name in models.__dict__
     if name.islower() and not name.startswith("__")
                      and callable(models.__dict__[name]))
@@ -46,8 +56,8 @@ help='path to file to load its weights (default: none)')
 parser.add_argument('-s', '--shift-depth', type=int, default=0,
                     help='how many layers to convert to shift')
-parser.add_argument('-st', '--shift-type', default='Q', choices=['Q', 'PS'],
-                    help='type of DeepShift method for training and representing weights (default: Q)')
+parser.add_argument('-st', '--shift-type', default='PS', choices=['Q', 'PS'],
+                    help='type of DeepShift method for training and representing weights (default: PS)')
 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 4)')
 parser.add_argument('--epochs', default=200, type=int, metavar='N',
diff --git a/pytorch/cifar10_models/densenet.py b/pytorch/cifar10_models/densenet.py
index c0d918d..5710825 100644
--- a/pytorch/cifar10_models/densenet.py
+++ b/pytorch/cifar10_models/densenet.py
@@ -4,7 +4,7 @@
 import torch.nn.functional as F
 from torch.autograd import Variable
 
-__all__ = ['densenet121', 'densenet169', 'densenet201', 'densenet264']
+__all__ = ['densenet40', 'densenet121', 'densenet169', 'densenet201', 'densenet264']
 
 """
 densenet with basic block.
@@ -129,6 +129,9 @@ def forward(self, x): return x +def densenet40(): + return densenet(depth=40) + def densenet121(): return densenet(depth=121) diff --git a/pytorch/cpu_kernal/setup.py b/pytorch/cpu_kernal/setup.py deleted file mode 100644 index 2e56d66..0000000 --- a/pytorch/cpu_kernal/setup.py +++ /dev/null @@ -1,5 +0,0 @@ -from setuptools import setup, Extension -from torch.utils import cpp_extension - -setup(name='shift_kernel', ext_modules=[cpp_extension.CppExtension('shift_kernel', ['shift_kernel.cpp'], extra_compile_args=['-fopenmp', '-O3'])], cmdclass={'build_ext': cpp_extension.BuildExtension}) - diff --git a/pytorch/cuda_kernel/shift_cuda.cpp b/pytorch/cuda_kernel/shift_cuda.cpp deleted file mode 100644 index a6a0933..0000000 --- a/pytorch/cuda_kernel/shift_cuda.cpp +++ /dev/null @@ -1,94 +0,0 @@ -#include -#include -#include - -// CUDA forward declarations - -void linear_shift_cuda( - torch::Tensor& input, - torch::Tensor& shift, - torch::Tensor& sign, - torch::Tensor& bias, - torch::Tensor& output); -void conv2d_shift_cuda( - torch::Tensor& input, - torch::Tensor& shift, - torch::Tensor& sign, - torch::Tensor& bias, - torch::Tensor& output, - torch::IntArrayRef strides, - torch::IntArrayRef padding); - -void GEMM_CUDA( - torch::Tensor& input, - torch::Tensor& shift, - torch::Tensor& sign, - torch::Tensor& bias, - torch::Tensor& output, - torch::IntArrayRef strides, - torch::IntArrayRef padding); -// C++ interface - -#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -void linear_shift( - torch::Tensor& input, - torch::Tensor& shift, - torch::Tensor& sign, - torch::Tensor& bias, - torch::Tensor& output) -{ - - CHECK_INPUT(input); - CHECK_INPUT(shift); - CHECK_INPUT(sign); - CHECK_INPUT(bias); - linear_shift_cuda(input, shift, sign, bias, output); -} - -void conv2d_shift( - torch::Tensor& input, - torch::Tensor& shift, - torch::Tensor& sign, - torch::Tensor& bias, - torch::Tensor& output, - torch::IntArrayRef strides, - torch::IntArrayRef padding) -{ - CHECK_INPUT(input); - CHECK_INPUT(shift); - CHECK_INPUT(sign); - CHECK_INPUT(bias); - - - // printf("here\n"); - conv2d_shift_cuda(input, shift, sign, bias, output,strides ,padding ); -} - -void GEMM( - torch::Tensor& input, - torch::Tensor& shift, - torch::Tensor& sign, - torch::Tensor& bias, - torch::Tensor& output, - torch::IntArrayRef strides, - torch::IntArrayRef padding) -{ - CHECK_INPUT(input); - CHECK_INPUT(shift); - CHECK_INPUT(sign); - CHECK_INPUT(bias); - - - // printf("here\n"); - GEMM_CUDA(input, shift, sign, bias, output,strides ,padding ); -} - - -PYBIND11_MODULE(shift_cuda_kernel, m) { - m.def("linear_shift", &linear_shift, "linear shift kernel(CUDA)"); - m.def("conv2d_shift", &conv2d_shift, "conv2d shift kernel(CUDA)"); - m.def("GEMM", &GEMM, "GEMM kernel(CUDA)"); -} \ No newline at end of file diff --git a/pytorch/cuda_kernel/shift_cuda_kernel.cu b/pytorch/cuda_kernel/shift_cuda_kernel.cu deleted file mode 100644 index dab3160..0000000 --- a/pytorch/cuda_kernel/shift_cuda_kernel.cu +++ /dev/null @@ -1,438 +0,0 @@ - -#include -#include -#include -#include -#include -// #include -template -__global__ void linear_shift_cuda_kernel( - const scalar_t* __restrict__ input, - const scalar_t* __restrict__ shift, - const scalar_t* __restrict__ sign, - const scalar_t* __restrict__ bias, - scalar_t* __restrict__ output, - size_t input_features, - int 
out_height, - int out_width) -{ - if(blockIdx.x * blockDim.x + threadIdx.x < out_height * out_width){ - int idx_h = (blockIdx.x * blockDim.x + threadIdx.x) / out_width; - int idx_w = (blockIdx.x * blockDim.x + threadIdx.x) % out_width; - - for(int i = 0; i < input_features; i++){ - auto x = input[idx_h * input_features + i]; - auto s = shift[idx_w * input_features + i]; - auto y = output[blockIdx.x * blockDim.x + threadIdx.x]; - if(sign[idx_w * input_features + i] < 0){ - if(s >= 0){ - y -= (x << s); - } - else{ - y -= (x >> (-s)); - } - } - else if(sign[idx_w * input_features + i] > 0){ - if(s >= 0){ - y += (x << s); - } - else{ - y += (x >> (-s)); - } - - } - output[blockIdx.x * blockDim.x + threadIdx.x]=y; - - } - - output[blockIdx.x * blockDim.x + threadIdx.x] += bias[idx_w]; - } - -} - -template -__global__ void conv2d_shift_cuda_kernel( - const scalar_t* __restrict__ input, - const scalar_t* __restrict__ shift, - const scalar_t* __restrict__ sign, - const scalar_t* __restrict__ bias, - scalar_t* __restrict__ output, - int filter_height, - int filter_width, - int input_features, - int out_height, - int out_width, - int strides_h, - int strides_w, - int oc, - int in_width, - int in_height) -{ - - - int xx = blockIdx.x * blockDim.x + threadIdx.x; - int yy = blockIdx.y * blockDim.y + threadIdx.y; - int idx = yy * gridDim.x * blockDim.x + xx; - int batch = blockIdx.y; - idx = idx % (gridDim.x * blockDim.x * blockDim.y); - if(idx < out_height * out_width * oc){ - - - int out_channel = idx / (out_height * out_width); - int h = (idx % (out_height * out_width)) / out_width; - int w = (idx % (out_height * out_width)) % out_width; - output[w + h * out_width + out_channel * out_width * out_height - + batch * oc * out_width * out_height] = 0; - for(int i = 0; i < filter_height; i++){ - for(int j = 0 ; j < filter_width; j++){ - for(int k = 0 ; k < input_features; k++){ - // auto s = shift[out_channel][k][i][j]; - auto s = shift[j + i * filter_width - + k * filter_height * filter_width + - out_channel * filter_height * filter_width * input_features]; - // auto y = output[batch][out_channel][h][w]; - auto y = output[w + h * out_width + out_channel * out_width * out_height - + batch * oc * out_width * out_height]; - - // auto x = input[batch][k][i + strides_h * h][j + strides_w * w]; - auto x = input[j + strides_w * w + (i + strides_h * h) * in_width - + k * in_width * in_height - + batch * in_width * in_height * input_features]; - if(sign[j + i * filter_width - + k * filter_height * filter_width + - out_channel * filter_height * filter_width * input_features] < 0){ - if(s >= 0){ - y -= (x << s); - } - else{ - y -= (x >> (-s)); - } - } - else if(sign[j + i * filter_width - + k * filter_height * filter_width + - out_channel * filter_height * filter_width * input_features] > 0){ - if(s >= 0){ - y += (x << s); - } - else{ - y += (x >> (-s)); - } - } - output[w + h * out_width + out_channel * out_width * out_height - + batch * oc * out_width * out_height] = y; - - } - } - } - - output[w + h * out_width + out_channel * out_width * out_height - + batch * oc * out_width * out_height] += bias[out_channel]; - } - - -} - -template -__global__ void im2col( - const scalar_t* __restrict__ im, - scalar_t* __restrict__ col, - int filter_height, - int filter_width, - int input_features, - int out_height, - int out_width, - int strides_h, - int strides_w, - int in_height, - int in_width, - int batch) -{ - int xx = blockIdx.x * blockDim.x + threadIdx.x; - int yy = blockIdx.y * blockDim.y + threadIdx.y; - int index = 
yy * gridDim.x * blockDim.x + xx; - - int k = filter_height * filter_width * input_features; - int num = out_height * out_width * batch; - if(index < k * num){ - int h = index / num; - int w = index % num; - int n = w / (out_height * out_width); - int out_idx = w % (out_height * out_width); - int h_out = out_idx / out_width; - int w_out = out_idx % out_width; - int ic = h / (filter_height * filter_width); - int hh_f = (h % (filter_height * filter_width)) / filter_width; - int ww_f = (h % (filter_height * filter_width)) % filter_width; - - col[index] = im[ww_f + strides_w * w_out + - (hh_f + strides_h * h_out) * in_width + - ic * in_width * in_height + - n * in_width * in_height * input_features]; - } -} - -template -__global__ void col2im( - const scalar_t* __restrict__ col, - scalar_t* __restrict__ im, - int filter_height, - int filter_width, - int input_features, - int out_height, - int out_width, - int strides_h, - int strides_w, - int in_height, - int in_width, - int batch, - int oc) -{ - int xx = blockIdx.x * blockDim.x + threadIdx.x; - int yy = blockIdx.y * blockDim.y + threadIdx.y; - int index = yy * gridDim.x * blockDim.x + xx; - int num = out_height * out_width * batch; - if(index < num * oc){ - int h = index / oc; - int w = index % oc; - int n = h / (out_height * out_width); - int out_idx = h % (out_height * out_width); - int h_out = out_idx / out_width; - int w_out = out_idx % out_width; - im[w_out + h_out * out_width + out_width * out_height * w + n * oc * out_width * out_height] = col[index]; - - } -} - - -template -__global__ void GEMM_CUDA_KERNEL( - const scalar_t* __restrict__ col, - const scalar_t* __restrict__ filter, - const scalar_t* __restrict__ sign, - const scalar_t* __restrict__ bias, - scalar_t* __restrict__ result, - int im_num, - int filter_num, - int k, - int filter_height, - int filter_width, - int input_features) -{ - int xx = blockIdx.x * blockDim.x + threadIdx.x; - int yy = blockIdx.y * blockDim.y + threadIdx.y; - int index = yy * gridDim.x * blockDim.x + xx; - if(index < filter_num * im_num){ - int h = index / filter_num; - int w = index % filter_num; - for(int i = 0; i < k; i++){ - auto f = filter[i * filter_num + w]; - auto y = result[w + h * filter_num]; - auto x = col[h + i * im_num]; - - auto s = sign[ i + - w * filter_height * filter_width * input_features]; - if(s < 0){ - if(f >= 0){ - y -= (x << f); - } - else{ - y -= (x >> (-f)); - } - } - else if (s > 0){ - if(f >= 0){ - y += (x << f); - } - else{ - y += (x >> (-f)); - } - } - - result[w + h * filter_num] = y; - } - result[w + h * filter_num] += bias[w]; - - } - -} - -void linear_shift_cuda( - torch::Tensor& input, - torch::Tensor& shift, - torch::Tensor& sign, - torch::Tensor& bias, - torch::Tensor& output) -{ - - - const int block = (input.size(0) * shift.size(0) + 1024 -1) / 1024; - const int threads = 1024; - - AT_DISPATCH_INTEGRAL_TYPES(input.type(), "linear shift kernel", ([&] { - linear_shift_cuda_kernel<<>>( - input.data(), - shift.data(), - sign.data(), - bias.data(), - output.data(), - (int)input.size(1), - input.size(0), - shift.size(0)); - })); - - -} - - - - -void conv2d_shift_cuda( - torch::Tensor& input, - torch::Tensor& shift, - torch::Tensor& sign, - torch::Tensor& bias, - torch::Tensor& output, - torch::IntArrayRef strides, - torch::IntArrayRef padding) -{ - - int strides_h; - int strides_w; - if(strides.size() ==1){ - strides_h = strides[0]; - strides_w = strides[0]; - } - else{ - strides_h = strides[0]; - strides_w = strides[1]; - } - - int temp = (output.size(2) * 
output.size(3) * shift.size(0) + 1024 - 1) / 1024; - const dim3 block(temp, output.size(0)); - const dim3 threads(32,32); - AT_DISPATCH_INTEGRAL_TYPES(input.type(), "conv2d cuda", ([&] { - conv2d_shift_cuda_kernel<<>>( - input.data(), - shift.data(), - sign.data(), - bias.data(), - output.data(), - shift.size(2), - shift.size(3), - input.size(1), - output.size(2), - output.size(3), - strides_h, - strides_w, - shift.size(0), - input.size(3), - input.size(2)); - })); - -} - - -void GEMM_CUDA( - torch::Tensor& input, - torch::Tensor& shift, - torch::Tensor& sign, - torch::Tensor& bias, - torch::Tensor& output, - torch::IntArrayRef strides, - torch::IntArrayRef padding) -{ - - int strides_h; - int strides_w; - if(strides.size() ==1){ - strides_h = strides[0]; - strides_w = strides[0]; - } - else{ - strides_h = strides[0]; - strides_w = strides[1]; - } - - int k = shift.size(2)*shift.size(3)*shift.size(1); - int num_p = output.size(2)*output.size(3)*output.size(0); - auto options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA, 0); - auto col = torch::zeros({k, num_p},options); - - int tmp = (k * num_p + 1024 -1) / 1024; - int tmp1 = (tmp + 65535 -1) / 65535; - tmp = (tmp > 65535) ? 65535: tmp; - const dim3 blk(tmp,tmp1); - AT_DISPATCH_INTEGRAL_TYPES(input.type(), "im2col cuda", ([&] { - im2col<<>>( - input.data(), - col.data(), - shift.size(2), - shift.size(3), - input.size(1), - output.size(2), - output.size(3), - strides_h, - strides_w, - input.size(2), - input.size(3), - input.size(0)); - })); - int filter_p = 1 * 1 * shift.size(0); - auto filter = torch::zeros({k, filter_p},options); - tmp = (k * filter_p + 1024 -1) / 1024; - tmp1 = (tmp + 65535 -1) / 65535; - tmp = (tmp > 65535) ? 65535: tmp; - const dim3 block(tmp,tmp1); - AT_DISPATCH_INTEGRAL_TYPES(shift.type(), "im2col cuda", ([&] { - im2col<<>>( - shift.data(), - filter.data(), - shift.size(2), - shift.size(3), - shift.size(1), - 1, - 1, - strides_h, - strides_w, - shift.size(2), - shift.size(3), - shift.size(0)); - })); - - tmp = (num_p * filter_p + 1024 -1) / 1024; - tmp1 = (tmp + 65535 -1) / 65535; - tmp = (tmp > 65535) ? 
65535: tmp; - const dim3 block1(tmp,tmp1); - auto result = torch::zeros({num_p, filter_p},options); - AT_DISPATCH_INTEGRAL_TYPES(shift.type(), "GEMM_CUDA_KERNEL", ([&] { - GEMM_CUDA_KERNEL<<>>( - col.data(), - filter.data(), - sign.data(), - bias.data(), - result.data(), - num_p, - filter_p, - k, - shift.size(2), - shift.size(3), - shift.size(1)); - })); - AT_DISPATCH_INTEGRAL_TYPES(result.type(), "col2im cuda", ([&] { - col2im<<>>( - result.data(), - output.data(), - shift.size(2), - shift.size(3), - input.size(1), - output.size(2), - output.size(3), - strides_h, - strides_w, - input.size(2), - input.size(3), - input.size(0), - shift.size(0)); - })); - -} \ No newline at end of file diff --git a/pytorch/deepshift/__init__.py b/pytorch/deepshift/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pytorch/convert_to_shift.py b/pytorch/deepshift/convert.py similarity index 57% rename from pytorch/convert_to_shift.py rename to pytorch/deepshift/convert.py index 95dffeb..02ca14b 100644 --- a/pytorch/convert_to_shift.py +++ b/pytorch/deepshift/convert.py @@ -1,9 +1,12 @@ import torch import torch.nn as nn import numpy as np +import math +import copy -import shift, shift_q -from shift import round_to_fixed, get_shift_and_sign, round_power_of_2 +import deepshift.modules +import deepshift.modules_q +import deepshift.utils as utils def convert_to_shift(model, shift_depth, shift_type, convert_all_linear=True, convert_weights=False, freeze_sign = False, use_kernel=False, use_cuda=True): conversion_count = 0 @@ -16,16 +19,22 @@ def convert_to_shift(model, shift_depth, shift_type, convert_all_linear=True, co linear = module if shift_type == 'Q': - shift_linear = shift_q.LinearShiftQ(module.in_features, module.out_features, module.bias is not None, use_kernel=use_kernel, use_cuda = use_cuda) + shift_linear = deepshift.modules_q.LinearShiftQ(module.in_features, module.out_features, module.bias is not None, use_kernel=use_kernel, use_cuda = use_cuda) shift_linear.weight = linear.weight if linear.bias is not None: - shift_linear.bias.data = round_to_fixed(linear.bias, fraction=16, integer=16) + shift_linear.bias.data = utils.round_to_fixed(linear.bias, fraction=16, integer=16) + + if use_cuda==True and use_kernel == True: + shift_linear.conc_weight = utils.compress_bits(*utils.get_shift_and_sign(linear.weight)) elif shift_type == 'PS': - shift_linear = shift.LinearShift(module.in_features, module.out_features, module.bias is not None, freeze_sign = freeze_sign, use_kernel=use_kernel, use_cuda = use_cuda) + shift_linear = deepshift.modules.LinearShift(module.in_features, module.out_features, module.bias is not None, freeze_sign = freeze_sign, use_kernel=use_kernel, use_cuda = use_cuda) if convert_weights == True: - shift_linear.shift.data, shift_linear.sign.data = get_shift_and_sign(linear.weight) + shift_linear.shift.data, shift_linear.sign.data = utils.get_shift_and_sign(linear.weight) shift_linear.bias = linear.bias + + if use_cuda==True and use_kernel == True: + shift_linear.conc_weight = utils.compress_bits(shift_linear.shift.data, shift_linear.sign.data) else: raise ValueError('Unsupported shift_type argument: ', shift_type) @@ -37,48 +46,55 @@ def convert_to_shift(model, shift_depth, shift_type, convert_all_linear=True, co conv2d = module if shift_type == 'Q': - shift_conv2d = shift_q.Conv2dShiftQ(module.in_channels, module.out_channels, module.kernel_size, module.stride, + shift_conv2d = deepshift.modules_q.Conv2dShiftQ(module.in_channels, module.out_channels, module.kernel_size, 
module.stride, module.padding, module.dilation, module.groups, module.bias is not None, module.padding_mode, use_kernel=use_kernel, use_cuda=use_cuda) shift_conv2d.weight = conv2d.weight if conv2d.bias is not None: - shift_conv2d.bias.data = round_to_fixed(conv2d.bias, fraction=16, integer=16) + shift_conv2d.bias.data = utils.round_to_fixed(conv2d.bias, fraction=16, integer=16) + + if use_cuda==True and use_kernel == True: + shift_conv2d.conc_weight = utils.compress_bits(*utils.get_shift_and_sign(conv2d.weight)) elif shift_type == 'PS': - shift_conv2d = shift.Conv2dShift(module.in_channels, module.out_channels, module.kernel_size, module.stride, + shift_conv2d = deepshift.modules.Conv2dShift(module.in_channels, module.out_channels, module.kernel_size, module.stride, module.padding, module.dilation, module.groups, module.bias is not None, module.padding_mode, freeze_sign=freeze_sign, use_kernel=use_kernel, use_cuda=use_cuda) if convert_weights == True: - shift_conv2d.shift.data, shift_conv2d.sign.data = get_shift_and_sign(conv2d.weight) + shift_conv2d.shift.data, shift_conv2d.sign.data = utils.get_shift_and_sign(conv2d.weight) shift_conv2d.bias = conv2d.bias + if use_cuda==True and use_kernel == True: + shift_conv2d.conc_weight = utils.compress_bits(shift_conv2d.shift.data, shift_conv2d.sign.data) + model._modules[name] = shift_conv2d conversion_count += 1 return model, conversion_count -def round_shift_weights(model): +def round_shift_weights(model, clone=False): + if(clone): + model = copy.deepcopy(model) + for name, module in reversed(model._modules.items()): if len(list(module.children())) > 0: # recurse model._modules[name] = round_shift_weights(model=module) - print(type(module)) - - if type(module) == shift.LinearShift or type(module) == shift.Conv2dShift: + if type(module) == deepshift.modules.LinearShift or type(module) == deepshift.modules.Conv2dShift: module.shift.data = module.shift.round() module.sign.data = module.sign.round().sign() if (module.bias is not None): - module.bias.data = round_to_fixed(module.bias, fraction=16, integer=16) - elif type(module) == shift_q.LinearShiftQ or type(module) == shift_q.Conv2dShiftQ: - module.weight.data = round_power_of_2(module.weight) + module.bias.data = utils.round_to_fixed(module.bias, fraction=16, integer=16) + elif type(module) == deepshift.modules_q.LinearShiftQ or type(module) == deepshift.modules_q.Conv2dShiftQ: + module.weight.data = utils.round_power_of_2(module.weight) if (module.bias is not None): - module.bias.data = round_to_fixed(module.bias, fraction=16, integer=16) + module.bias.data = utils.round_to_fixed(module.bias, fraction=16, integer=16) return model diff --git a/pytorch/deepshift/kernels/__init__.py b/pytorch/deepshift/kernels/__init__.py new file mode 100644 index 0000000..2cdd6b4 --- /dev/null +++ b/pytorch/deepshift/kernels/__init__.py @@ -0,0 +1 @@ +from .kernels import * \ No newline at end of file diff --git a/pytorch/deepshift/kernels/cpu/setup.py b/pytorch/deepshift/kernels/cpu/setup.py new file mode 100644 index 0000000..1e65cca --- /dev/null +++ b/pytorch/deepshift/kernels/cpu/setup.py @@ -0,0 +1,14 @@ +from setuptools import setup, Extension +from torch.utils import cpp_extension + +setup( + name='deepshift_cpu', + ext_modules=[ + cpp_extension.CppExtension('deepshift_cpu', [ + 'shift_cpu.cpp' + ], extra_compile_args=['-fopenmp', '-O3']) + ], + cmdclass={ + 'build_ext': cpp_extension.BuildExtension + }) + diff --git a/pytorch/cpu_kernal/shift_kernel.cpp b/pytorch/deepshift/kernels/cpu/shift_cpu.cpp 
similarity index 99% rename from pytorch/cpu_kernal/shift_kernel.cpp rename to pytorch/deepshift/kernels/cpu/shift_cpu.cpp index 175ff05..547110a 100644 --- a/pytorch/cpu_kernal/shift_kernel.cpp +++ b/pytorch/deepshift/kernels/cpu/shift_cpu.cpp @@ -373,7 +373,7 @@ vector>>> convolution_kernel( return output; } -PYBIND11_MODULE(shift_kernel, m) { +PYBIND11_MODULE(deepshift_cpu, m) { m.def("linear_kernel", &linear_kernel, "linear_kernel"); m.def("convolution_kernel", &convolution_kernel, "convolution_kernel"); } diff --git a/pytorch/deepshift/kernels/cuda/__init__.py b/pytorch/deepshift/kernels/cuda/__init__.py new file mode 100644 index 0000000..ddf94f2 --- /dev/null +++ b/pytorch/deepshift/kernels/cuda/__init__.py @@ -0,0 +1 @@ +from .convert_to_unoptimized import * \ No newline at end of file diff --git a/pytorch/deepshift/kernels/cuda/convert_to_unoptimized.py b/pytorch/deepshift/kernels/cuda/convert_to_unoptimized.py new file mode 100644 index 0000000..a0da75b --- /dev/null +++ b/pytorch/deepshift/kernels/cuda/convert_to_unoptimized.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn +import numpy as np + +from .unoptimized_linear import UnoptimizedLinear +from .unoptimized_conv import UnoptimizedConv2d + +def convert_to_unoptimized(model): + for name, module in model._modules.items(): + if len(list(module.children())) > 0: + # recurse + model._modules[name] = convert_to_unoptimized(model=module) + if type(module) == nn.Linear: + linear = module + unoptimized_linear = UnoptimizedLinear(module.in_features, module.out_features, module.bias is not None) + unoptimized_linear.weight = linear.weight + unoptimized_linear.bias = linear.bias + + model._modules[name] = unoptimized_linear + if type(module) == nn.Conv2d: + conv2d = module + unoptimized_conv = UnoptimizedConv2d(module.in_channels, module.out_channels, module.kernel_size, module.stride, + module.padding, module.dilation, module.groups, + module.bias is not None, module.padding_mode) + unoptimized_conv.bias = conv2d.bias + unoptimized_conv.weight = conv2d.weight + + model._modules[name] = unoptimized_conv + + return model + + +if __name__ == '__main__': + # this test will be run if you type in the command: + # > python convert_to_unoptimized + import torchvision.models as models + model = models.__dict__['resnet18'](pretrained=True) + model = model.to("cuda:0") + input = torch.rand((32, 3, 224, 224)).to("cuda:0") + output1 = model(input) + + + model = convert_to_unoptimized(model).to("cuda:0") + output2 = model(input) + + max_error = torch.max(torch.abs(output1 - output2)) + print(max_error.detach().cpu().numpy()) + diff --git a/pytorch/deepshift/kernels/cuda/setup.py b/pytorch/deepshift/kernels/cuda/setup.py new file mode 100644 index 0000000..4f53bff --- /dev/null +++ b/pytorch/deepshift/kernels/cuda/setup.py @@ -0,0 +1,26 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='deepshift_cuda', + ext_modules=[ + CUDAExtension('deepshift_cuda', [ + 'shift_cuda.cpp', + 'shift.cu', + ],extra_compile_args=['-O3']) + ], + cmdclass={ + 'build_ext': BuildExtension + }) + +setup( + name='unoptimized_cuda_kernel', + ext_modules=[ + CUDAExtension('unoptimized_cuda_kernel', [ + 'unoptimized_cuda.cpp', + 'unoptimized_cuda_kernel.cu', + ],extra_compile_args=['-O3']) + ], + cmdclass={ + 'build_ext': BuildExtension + }) \ No newline at end of file diff --git a/pytorch/deepshift/kernels/cuda/shift.cu b/pytorch/deepshift/kernels/cuda/shift.cu new file mode 100644 index 
0000000..f93bea5 --- /dev/null +++ b/pytorch/deepshift/kernels/cuda/shift.cu @@ -0,0 +1,559 @@ + +#include +#include +#include +#include +#include +#define BLOCK_SIZE 16 +#define MAX_BITS 32 +#define MAX_THREADS 1024 +#define MAX_BLOCKS 65535 +#define ZERO_BASE 1 +#define NON_ZERO_BASE 0 +#define BIT_3 3 +#define BIT_4 4 +#define BIT_5 5 +#define BIT_6 6 +#define BIT_7 7 +#define NUM_4 4 +#define NUM_5 5 +#define NUM_6 6 +#define NUM_8 8 + +__device__ int COMPRESS(const int* __restrict__ shift, const int* __restrict__ sign, int length, int base, int bits) +{ + int value = 0; + int s = 0; + for (int i = 0; i < length; i++) { + value = (value) | ((shift[i] - base) << s); + s = s + bits; + value = (value) | ((sign[i] > 0 ? 1 : 0) << s); + s = s + 1; + } + return value; +} + +__global__ void COMPRESS_SIGN_SHIFT_GPU_KERNEL( int* __restrict__ shift, int* __restrict__ sign, int* __restrict__ weight, int oc, int in_c, int height, int width, int num,int base, int bits, int compressed_row_length, int row_length) +{ + int index = blockIdx.x * blockDim.x + threadIdx.x; + if(index < compressed_row_length) { + int* shift_sub = &shift[oc * in_c * height * width]; + int* sign_sub = &sign[oc * in_c * height * width]; + int* weight_sub = &weight[oc * compressed_row_length]; + int length = num; + if( (index + 1) * num >= in_c * height * width) { + length = in_c * height * width - index * num; + } + weight_sub[index] = COMPRESS(&shift_sub[index * num], &sign_sub[index * num], length, base, bits); + + } + __syncthreads(); +} + +template +__global__ void DEEP_SHIFT_GEMM_GPU_KERNEL( + const int* __restrict__ input, + const int* __restrict__ shift, + const int* __restrict__ bias, + int* __restrict__ output, + const int n, + const int m, + const int k, + const int base, + const int max, + const int row_length) +{ + const int row = threadIdx.y; + const int col = threadIdx.x; + for(int blockRow = blockIdx.y;blockDim.y * blockRow < m; blockRow = blockRow + gridDim.y){ + for(int blockCol = blockIdx.x;blockDim.x * blockCol < k; blockCol = blockCol + gridDim.x){ + + const int compressed_row = row_length / num; + int* Csub = &output[BLOCK_SIZE * k * blockRow + BLOCK_SIZE * blockCol]; + __shared__ int As[BLOCK_SIZE *BLOCK_SIZE * num]; + __shared__ int Bs[BLOCK_SIZE *BLOCK_SIZE]; + int Cvalue = 0; + for (int i = 0; i < max; ++i) { + const int* Asub = &input[BLOCK_SIZE * blockRow * n + BLOCK_SIZE * i * num ]; + const int* Bsub = &shift[(BLOCK_SIZE * blockCol * row_length + BLOCK_SIZE * i * num) / num]; + #pragma unroll + for( int d = 0; d < num; d++) { + As[row * BLOCK_SIZE * num + col * num + d] = Asub[row * n + col * num + d]; + } + Bs[row * BLOCK_SIZE + col] = Bsub[row * compressed_row + col]; + + __syncthreads(); + + #pragma unroll + for (int j = 0; j < BLOCK_SIZE ; ++j){ + if(col + blockCol* BLOCK_SIZE< k + && row + blockRow* BLOCK_SIZE< m ){ + int whole = Bs[col * BLOCK_SIZE + j]; + #pragma unroll + for(int d = 0; d < num; d++) { + if(i * BLOCK_SIZE * num + j * num + d < n){ + int get_sign = int(whole & mask_sign); + get_sign = get_sign == 0 ? 
-1 : 1; + int get_shift = int(whole & mask_shift); + whole = int(whole >> (bits + 1)); + if(zero_base){ + Cvalue += get_sign * (As[row * BLOCK_SIZE * num+ j * num + d] >> (get_shift)); + } + else{ + Cvalue += get_sign * ((As[row * BLOCK_SIZE * num+ j * num + d] >> (get_shift))<<(-base)); + } + } + } + } + } + __syncthreads(); + } + if(col + blockCol* BLOCK_SIZE< k && row + blockRow* BLOCK_SIZE< m) Csub[row*k+col] = Cvalue + bias[col + blockCol* BLOCK_SIZE]; + __syncthreads(); + } + + } +} + +__global__ void IM2COL( + const int total, + const int* __restrict__ im, + int* __restrict__ col, + const int filter_height, + const int filter_width, + const int input_features, + const int out_height, + const int out_width, + const int strides_h, + const int strides_w, + const int in_height, + const int in_width, + const int k) +{ + for(int index = blockIdx.x * blockDim.x + threadIdx.x; index < total; index = index + gridDim.x * blockDim.x) { + const int h = index / k; + const int w = index % k; + const int n = h / (out_height * out_width); + const int out_idx = h % (out_height * out_width); + const int h_out = out_idx / out_width; + const int w_out = out_idx % out_width; + const int ic = w / (filter_height * filter_width); + const int hh_f = (w % (filter_height * filter_width)) / filter_width; + const int ww_f = (w % (filter_height * filter_width)) % filter_width; + + col[index] = im[ww_f + strides_w * w_out + + (hh_f + strides_h * h_out) * in_width + + ic * in_width * in_height + + n * in_width * in_height * input_features]; + } +} + +__global__ void COL2IM( + const int total, + const int* __restrict__ col, + int* __restrict__ im, + const int out_height, + const int out_width, + const int oc) +{ + for(int index = blockIdx.x * blockDim.x + threadIdx.x; index < total; index = index + gridDim.x * blockDim.x){ + const int h = index / oc; + const int w = index % oc; + const int n = h / (out_height * out_width); + const int out_idx = h % (out_height * out_width); + const int h_out = out_idx / out_width; + const int w_out = out_idx % out_width; + im[w_out + h_out * out_width + out_width * out_height * w + n * oc * out_width * out_height] = col[index]; + } +} + +void DEEP_SHIFT_LINEAR_GPU( + torch::Tensor input, + torch::Tensor shift, + torch::Tensor bias, + torch::Tensor output, + int base, int bits, int out_features) +{ + dim3 blockDim(BLOCK_SIZE, BLOCK_SIZE); + int a1=out_features/ BLOCK_SIZE + 1; + if(a1> MAX_BLOCKS){ + a1 = MAX_BLOCKS; + } + int a2=input.size(0) / BLOCK_SIZE + 1; + if(a2> MAX_BLOCKS) { + a2= MAX_BLOCKS; + } + dim3 gridDim( a1, a2); + int num = int(MAX_BITS / (bits + 1)); + int comm = (input.size(1) + num - 1) / num; + int max =(comm + BLOCK_SIZE - 1) / BLOCK_SIZE; + if(bits == 3) { + if(base == 0){ + AT_DISPATCH_INTEGRAL_TYPES(input.type(), "linear shift kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + input.data(), + shift.data(), + bias.data(), + output.data(), + input.size(1), + input.size(0), + out_features, base,max, comm * num); + })); + } + else { + AT_DISPATCH_INTEGRAL_TYPES(input.type(), "linear shift kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + input.data(), + shift.data(), + bias.data(), + output.data(), + input.size(1), + input.size(0), + out_features, base,max, comm * num); + })); + } + + } + else if(bits == 4){ + if(base == 0){ + AT_DISPATCH_INTEGRAL_TYPES(input.type(), "linear shift kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + input.data(), + shift.data(), + bias.data(), + output.data(), + input.size(1), + input.size(0), + out_features, base,max, comm * 
num); + })); + } + else { + AT_DISPATCH_INTEGRAL_TYPES(input.type(), "linear shift kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + input.data(), + shift.data(), + bias.data(), + output.data(), + input.size(1), + input.size(0), + out_features, base,max, comm * num); + })); + } + + } + else if(bits == 5){ + if(base == 0){ + AT_DISPATCH_INTEGRAL_TYPES(input.type(), "linear shift kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + input.data(), + shift.data(), + bias.data(), + output.data(), + input.size(1), + input.size(0), + out_features,base, max, comm * num); + })); + } + else { + AT_DISPATCH_INTEGRAL_TYPES(input.type(), "linear shift kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + input.data(), + shift.data(), + bias.data(), + output.data(), + input.size(1), + input.size(0), + out_features,base, max, comm * num); + })); + } + + } + else if(bits == 6){ + if(base == 0){ + AT_DISPATCH_INTEGRAL_TYPES(input.type(), "linear shift kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + input.data(), + shift.data(), + bias.data(), + output.data(), + input.size(1), + input.size(0), + out_features, base, max, comm * num); + })); + } + else{ + AT_DISPATCH_INTEGRAL_TYPES(input.type(), "linear shift kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + input.data(), + shift.data(), + bias.data(), + output.data(), + input.size(1), + input.size(0), + out_features, base,max, comm * num); + })); + } + + } + else if(bits == 7){ + if(base == 0){ + AT_DISPATCH_INTEGRAL_TYPES(input.type(), "linear shift kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + input.data(), + shift.data(), + bias.data(), + output.data(), + input.size(1), + input.size(0), + out_features, base,max, comm * num); + })); + } + else { + AT_DISPATCH_INTEGRAL_TYPES(input.type(), "linear shift kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + input.data(), + shift.data(), + bias.data(), + output.data(), + input.size(1), + input.size(0), + out_features, base,max, comm * num); + })); + } + + } + else{ + std::cout<<"ERROR: unhandled case\n"; + } + +} +void COMPRESS_SIGN_SHIFT_GPU(torch::Tensor shift, torch::Tensor sign, torch::Tensor weight, int base, int bits, int out_c, int in_c, int height, int width, int row_length, int num) +{ + int threads = MAX_THREADS; + int compressed_row_length = row_length; + dim3 block ( (compressed_row_length + threads - 1) / threads); + for(int i = 0; i < out_c; i++) { + COMPRESS_SIGN_SHIFT_GPU_KERNEL<<>>(shift.data(), sign.data(),weight.data(), + i, in_c, height, width, num, base, bits,compressed_row_length, row_length); + } +} + +void DEEP_SHIFT_CONV_GPU(torch::Tensor data_im, + torch::Tensor shift, + torch::Tensor bias, + torch::Tensor output, + torch::IntArrayRef strides, + torch::IntArrayRef padding, int filter_height, int filter_width, int base, int bits) +{ + int strides_h; + int strides_w; + if(strides.size() ==1){ + strides_h = strides[0]; + strides_w = strides[0]; + } + else{ + strides_h = strides[0]; + strides_w = strides[1]; + } + int k = filter_height * filter_width * data_im.size(1); + int num_patch = output.size(0) * output.size(2) * output.size(3); + + int* data_col; + cudaMalloc(&data_col, num_patch * k * sizeof(int)); + + int threads = MAX_THREADS; + int tmp = (k * num_patch + threads -1) / threads; + tmp = (tmp > MAX_BLOCKS) ? 
MAX_BLOCKS: tmp; + const dim3 blk(tmp); + AT_DISPATCH_INTEGRAL_TYPES(data_im.type(), "IM2COL cuda", ([&] { + IM2COL<<>>( + k * num_patch, + data_im.data(), + data_col, + filter_height, + filter_width, + data_im.size(1), + output.size(2), + output.size(3), + strides_h, + strides_w, + data_im.size(2), + data_im.size(3), + k); + })); + + int filter_patch = output.size(1); + dim3 blockDim(BLOCK_SIZE, BLOCK_SIZE); + + int a1=filter_patch/ BLOCK_SIZE + 1; + if(a1> MAX_BLOCKS){ + a1 = MAX_BLOCKS; + } + int a2=num_patch / BLOCK_SIZE + 1; + if(a2> MAX_BLOCKS) { + a2 = MAX_BLOCKS; + } + dim3 gridDim( a1, a2); + + int *out_col; + int num = int(MAX_BITS / (bits +1 )); + int comm = (k + num -1 ) / num; + int max =(comm + BLOCK_SIZE - 1) / BLOCK_SIZE; + cudaMalloc(&out_col, num_patch * filter_patch * sizeof(int)); + if(bits == 3) { + if(base == 0){ + AT_DISPATCH_INTEGRAL_TYPES(data_im.type(), "DEEP_SHIFT_GEMM_GPU_KERNEL kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + data_col, + shift.data(), + bias.data(), + out_col, + k, + num_patch, + filter_patch, base,max,comm * num); + })); + } + else { + AT_DISPATCH_INTEGRAL_TYPES(data_im.type(), "DEEP_SHIFT_GEMM_GPU_KERNEL kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + data_col, + shift.data(), + bias.data(), + out_col, + k, + num_patch, + filter_patch, base,max,comm * num); + })); + } + + } + else if(bits == 4) { + if(base == 0){ + AT_DISPATCH_INTEGRAL_TYPES(data_im.type(), "DEEP_SHIFT_GEMM_GPU_KERNEL kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + data_col, + shift.data(), + bias.data(), + out_col, + k, + num_patch, + filter_patch, base,max,comm * num); + })); + } + else { + AT_DISPATCH_INTEGRAL_TYPES(data_im.type(), "DEEP_SHIFT_GEMM_GPU_KERNEL kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + data_col, + shift.data(), + bias.data(), + out_col, + k, + num_patch, + filter_patch, base,max,comm * num); + })); + } + + } + else if(bits == 5){ + if(base == 0){ + AT_DISPATCH_INTEGRAL_TYPES(data_im.type(), "DEEP_SHIFT_GEMM_GPU_KERNEL kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + data_col, + shift.data(), + bias.data(), + out_col, + k, + num_patch, + filter_patch, base,max,comm * num); + })); + } + else { + AT_DISPATCH_INTEGRAL_TYPES(data_im.type(), "DEEP_SHIFT_GEMM_GPU_KERNEL kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + data_col, + shift.data(), + bias.data(), + out_col, + k, + num_patch, + filter_patch, base,max,comm * num); + })); + } + + } + else if(bits == 6){ + if(base == 0){ + AT_DISPATCH_INTEGRAL_TYPES(data_im.type(), "DEEP_SHIFT_GEMM_GPU_KERNEL kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + data_col, + shift.data(), + bias.data(), + out_col, + k, + num_patch, + filter_patch, base,max,comm * num); + })); + } + else{ + AT_DISPATCH_INTEGRAL_TYPES(data_im.type(), "DEEP_SHIFT_GEMM_GPU_KERNEL kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + data_col, + shift.data(), + bias.data(), + out_col, + k, + num_patch, + filter_patch, base,max,comm * num); + })); + } + + } + else if(bits == 7){ + if(base == 0){ + AT_DISPATCH_INTEGRAL_TYPES(data_im.type(), "DEEP_SHIFT_GEMM_GPU_KERNEL kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + data_col, + shift.data(), + bias.data(), + out_col, + k, + num_patch, + filter_patch, base,max,comm * num); + })); + } + else { + AT_DISPATCH_INTEGRAL_TYPES(data_im.type(), "DEEP_SHIFT_GEMM_GPU_KERNEL kernel", ([&] { + DEEP_SHIFT_GEMM_GPU_KERNEL<<>>( + data_col, + shift.data(), + bias.data(), + out_col, + k, + num_patch, + filter_patch, base,max,comm * num); + })); + } + + } + else{ + std::cout<<"ERROR: 
unhandled case\n"; + } + + tmp = (num_patch * output.size(1) + threads -1) / threads; + tmp = (tmp > MAX_BLOCKS) ? MAX_BLOCKS: tmp; + const dim3 block1(tmp); + AT_DISPATCH_INTEGRAL_TYPES(data_im.type(), "COL2IM cuda", ([&] { + COL2IM<<>>( + num_patch * output.size(1), + out_col, + output.data(), + output.size(2), + output.size(3), + output.size(1)); + })); + cudaFree(data_col); + cudaFree(out_col); +} + + + diff --git a/pytorch/deepshift/kernels/cuda/shift_cuda.cpp b/pytorch/deepshift/kernels/cuda/shift_cuda.cpp new file mode 100644 index 0000000..7ffb480 --- /dev/null +++ b/pytorch/deepshift/kernels/cuda/shift_cuda.cpp @@ -0,0 +1,68 @@ +#include +#include +#include + +// CUDA forward declarations +void COMPRESS_SIGN_SHIFT_GPU(torch::Tensor shift, torch::Tensor sign, torch::Tensor weight, + int base, int bits, int out_c, int in_c, int height, int width, int row_length, int num); +void DEEP_SHIFT_CONV_GPU(torch::Tensor data_im, + torch::Tensor shift, + torch::Tensor bias, + torch::Tensor output, + torch::IntArrayRef strides, + torch::IntArrayRef padding, int filter_height, int filter_width, int base, int bits); +void DEEP_SHIFT_LINEAR_GPU( + torch::Tensor input, + torch::Tensor shift, + torch::Tensor bias, + torch::Tensor output, + int base, int bits, int out_features); +// C++ interface + +#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +void DEEP_SHIFT_LINEAR( + torch::Tensor input, + torch::Tensor shift, + torch::Tensor bias, + torch::Tensor output, + int base, int bits, int out_features) +{ + + CHECK_INPUT(input); + CHECK_INPUT(shift); + CHECK_INPUT(bias); + CHECK_INPUT(output); + DEEP_SHIFT_LINEAR_GPU(input, shift, bias, output, base, bits, out_features); +} + +void DEEP_SHIFT_CONV(torch::Tensor data_im, + torch::Tensor shift, + torch::Tensor bias, + torch::Tensor output, + torch::IntArrayRef strides, + torch::IntArrayRef padding, int filter_height, int filter_width, int base, int bits) +{ + CHECK_INPUT(data_im); + CHECK_INPUT(shift); + CHECK_INPUT(bias); + CHECK_INPUT(output); + DEEP_SHIFT_CONV_GPU(data_im, shift, bias, output, strides, padding, filter_height, filter_width, base,bits); +} + +void COMPRESS_SIGN_SHIFT(torch::Tensor shift, torch::Tensor sign, torch::Tensor weight, int base, int bits, int out_c, int in_c, int height, int width, int row_length, int num) +{ + CHECK_INPUT(shift); + CHECK_INPUT(sign); + CHECK_INPUT(weight); + COMPRESS_SIGN_SHIFT_GPU(shift,sign, weight, base,bits, out_c, in_c, height,width,row_length, num); +} + + +PYBIND11_MODULE(deepshift_cuda, m) { + m.def("DEEP_SHIFT_LINEAR", &DEEP_SHIFT_LINEAR, "DEEP_SHIFT_LINEAR kernel(CUDA)"); + m.def("DEEP_SHIFT_CONV", &DEEP_SHIFT_CONV, "DEEP_SHIFT_CONV kernel(CUDA)"); + m.def("COMPRESS_SIGN_SHIFT", &COMPRESS_SIGN_SHIFT, "COMPRESS_SIGN_SHIFT kernel(CUDA)"); +} diff --git a/pytorch/deepshift/kernels/cuda/unoptimized_conv.py b/pytorch/deepshift/kernels/cuda/unoptimized_conv.py new file mode 100644 index 0000000..4b16780 --- /dev/null +++ b/pytorch/deepshift/kernels/cuda/unoptimized_conv.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function +from torch.nn.modules.utils import _pair +from torch.nn import init +import unoptimized_cuda_kernel +import math +import numpy as np +import time + +class _UnoptimizedConvNd(nn.Module): + + __constants__ = ['stride', 'padding', 
'dilation', 'groups', 'bias', 'padding_mode'] + + def __init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, transposed, output_padding, + groups, bias, padding_mode): + super(_UnoptimizedConvNd, self).__init__() + if in_channels % groups != 0: + raise ValueError('in_channels must be divisible by groups') + if out_channels % groups != 0: + raise ValueError('out_channels must be divisible by groups') + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = transposed + self.output_padding = output_padding + self.groups = groups + self.padding_mode = padding_mode + + if transposed: + self.weight = nn.Parameter(torch.Tensor( + in_channels, out_channels // groups, *kernel_size)) + else: + self.weight = nn.Parameter(torch.Tensor( + out_channels, in_channels // groups, *kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self): + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def extra_repr(self): + s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' + ', stride={stride}') + if self.padding != (0,) * len(self.padding): + s += ', padding={padding}' + if self.dilation != (1,) * len(self.dilation): + s += ', dilation={dilation}' + if self.output_padding != (0,) * len(self.output_padding): + s += ', output_padding={output_padding}' + if self.groups != 1: + s += ', groups={groups}' + if self.bias is None: + s += ', bias=False' + return s.format(**self.__dict__) + +class UnoptimizedConv2d(_UnoptimizedConvNd): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, + bias=True, padding_mode='zeros'): + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + super(UnoptimizedConv2d, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + False, _pair(0), groups, bias, padding_mode) + + #@weak_script_method + def forward(self, input): + # start_time = time.time() + if self.padding_mode == 'circular': + print('circular') + if len(self.padding) == 2: + padding = (self.padding[0],self.padding[0],self.padding[1],self.padding[1]) + else: + padding = self.padding + input = F.pad(input = input, pad = padding, mode = 'constant', value = 0) + if len(self.stride) == 1: + strides_h = self.stride[0] + strides_w = self.stride[0] + else: + strides_h = self.stride[0] + strides_w = self.stride[1] + out_height = int((input.size(2) - self.weight.size(2)) / strides_h +1) + out_width = int((input.size(3) - self.weight.size(3)) / strides_w +1) + out = torch.zeros([input.size(0), self.weight.size(0), out_height, out_width], dtype=torch.float, device=torch.device('cuda:0')) + + if self.bias is not None: + unoptimized_cuda_kernel.UNOPTIMIZED_CONV(input, self.weight, self.bias, out, self.stride, self.padding ) + else: + temp = torch.zeros([self.weight.size(0)], dtype=torch.float, device=torch.device('cuda:0')) + unoptimized_cuda_kernel.UNOPTIMIZED_CONV(input, self.weight, temp, out, self.stride, self.padding ) + # end_time = time.time() + # print("Conv Time:", end_time - start_time ) + return out diff --git 
a/pytorch/deepshift/kernels/cuda/unoptimized_cuda.cpp b/pytorch/deepshift/kernels/cuda/unoptimized_cuda.cpp new file mode 100644 index 0000000..48e4b11 --- /dev/null +++ b/pytorch/deepshift/kernels/cuda/unoptimized_cuda.cpp @@ -0,0 +1,56 @@ +#include +#include +#include + +// CUDA forward declarations + +void UNOPTIMIZED_LINEAR_GPU( + torch::Tensor input, + torch::Tensor weight, + torch::Tensor bias, + torch::Tensor output); +void UNOPTIMIZED_CONV_GPU( + torch::Tensor input, + torch::Tensor weight, + torch::Tensor bias, + torch::Tensor output, + torch::IntArrayRef strides, + torch::IntArrayRef padding); +// C++ interface + +#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +void UNOPTIMIZED_LINEAR( + torch::Tensor input, + torch::Tensor weight, + torch::Tensor bias, + torch::Tensor output) +{ + + CHECK_INPUT(input); + CHECK_INPUT(weight); + CHECK_INPUT(bias); + UNOPTIMIZED_LINEAR_GPU(input, weight, bias, output); +} + +void UNOPTIMIZED_CONV( + torch::Tensor input, + torch::Tensor weight, + torch::Tensor bias, + torch::Tensor output, + torch::IntArrayRef strides, + torch::IntArrayRef padding) +{ + CHECK_INPUT(input); + CHECK_INPUT(weight); + CHECK_INPUT(bias); + + UNOPTIMIZED_CONV_GPU(input, weight, bias, output,strides ,padding ); +} + +PYBIND11_MODULE(unoptimized_cuda_kernel, m) { + m.def("UNOPTIMIZED_LINEAR", &UNOPTIMIZED_LINEAR, "UNOPTIMIZED_LINEAR kernel(CUDA)"); + m.def("UNOPTIMIZED_CONV", &UNOPTIMIZED_CONV, "UNOPTIMIZED_CONV kernel(CUDA)"); +} diff --git a/pytorch/deepshift/kernels/cuda/unoptimized_cuda_kernel.cu b/pytorch/deepshift/kernels/cuda/unoptimized_cuda_kernel.cu new file mode 100644 index 0000000..28a10d0 --- /dev/null +++ b/pytorch/deepshift/kernels/cuda/unoptimized_cuda_kernel.cu @@ -0,0 +1,219 @@ + +#include +#include +#include +#include +#include +#define BLOCK_SIZE 16 +#define MAX_THREADS 1024 +#define MAX_BLOCKS 65535 +__global__ void IM2COL( + const int total, + const float* __restrict__ im, + float* __restrict__ col, + const int filter_height, + const int filter_width, + const int input_features, + const int out_height, + const int out_width, + const int strides_h, + const int strides_w, + const int in_height, + const int in_width, + const int k, const int num) +{ + for(int index = blockIdx.x * blockDim.x + threadIdx.x; index < total; index = index + gridDim.x * blockDim.x) { + const int h = index / k; + const int w = index % k; + const int n = h / (out_height * out_width); + const int out_idx = h % (out_height * out_width); + const int h_out = out_idx / out_width; + const int w_out = out_idx % out_width; + const int ic = w / (filter_height * filter_width); + const int hh_f = (w % (filter_height * filter_width)) / filter_width; + const int ww_f = (w % (filter_height * filter_width)) % filter_width; + + col[index] = im[ww_f + strides_w * w_out + + (hh_f + strides_h * h_out) * in_width + + ic * in_width * in_height + + n * in_width * in_height * input_features]; + } +} + + +__global__ void COL2IM( + const int total, + const float* __restrict__ col, + float* __restrict__ im, + const int out_height, + const int out_width, + const int oc) +{ + for(int index = blockIdx.x * blockDim.x + threadIdx.x; index < total; index = index + gridDim.x * blockDim.x){ + const int h = index / oc; + const int w = index % oc; + const int n = h / (out_height * out_width); + const int out_idx = h % (out_height * 
out_width); + const int h_out = out_idx / out_width; + const int w_out = out_idx % out_width; + im[w_out + h_out * out_width + out_width * out_height * w + n * oc * out_width * out_height] = col[index]; + } +} + +__global__ void GEMM( + const float* __restrict__ input, + const float* __restrict__ shift, + const float* __restrict__ bias, + float* __restrict__ output, + const int n, + const int m, + const int k, + const int max) +{ + + const int row = threadIdx.y; + const int col = threadIdx.x; + for(int blockRow = blockIdx.y;blockDim.y * blockRow < m; blockRow = blockRow + gridDim.y){ + for(int blockCol = blockIdx.x;blockDim.x * blockCol < k; blockCol = blockCol + gridDim.x){ + float* Csub = &output[BLOCK_SIZE * k * blockRow + BLOCK_SIZE * blockCol]; + __shared__ float As[BLOCK_SIZE*BLOCK_SIZE]; + __shared__ float Bs[BLOCK_SIZE*BLOCK_SIZE]; + float Cvalue = 0; + for (int i = 0; i < max; ++i) { + const float* Asub = &input[BLOCK_SIZE * blockRow * n + BLOCK_SIZE * i ]; + const int original_index = BLOCK_SIZE * blockCol * n + BLOCK_SIZE * i + row * n + col; + As[row * BLOCK_SIZE + col] = Asub[row*n+col]; + Bs[row * BLOCK_SIZE + col] = shift[(original_index)]; + __syncthreads(); + + #pragma unroll + for (int j = 0; j < BLOCK_SIZE ; ++j){ + if(col + blockCol* BLOCK_SIZE< k + && row + blockRow* BLOCK_SIZE< m + && i * BLOCK_SIZE + j < n){ + Cvalue += (As[row * BLOCK_SIZE + j] * Bs[col * BLOCK_SIZE + j]); + } + } + __syncthreads(); + } + if(col + blockCol* BLOCK_SIZE< k && row + blockRow* BLOCK_SIZE< m) Csub[row*k+col] = Cvalue + bias[col + blockCol* BLOCK_SIZE]; + } + __syncthreads(); + } +} + + +void UNOPTIMIZED_LINEAR_GPU( + torch::Tensor input, + torch::Tensor weight, + torch::Tensor bias, + torch::Tensor output) +{ + + dim3 blockDim(BLOCK_SIZE, BLOCK_SIZE); + int a1=weight.size(0)/ BLOCK_SIZE + 1; + if(a1> MAX_BLOCKS){ + a1 = MAX_BLOCKS; + } + int a2=input.size(0) / BLOCK_SIZE + 1; + if(a2> MAX_BLOCKS) { + a2= MAX_BLOCKS; + } + dim3 gridDim( a1, a2); + int max =(input.size(1) + BLOCK_SIZE - 1) / BLOCK_SIZE; + AT_DISPATCH_ALL_TYPES(input.type(), "linear unoptimized kernel", ([&] { + GEMM<<>>( + input.data(), + weight.data(), + bias.data(), + output.data(), + input.size(1), + input.size(0), + weight.size(0), max); + })); +} + +void UNOPTIMIZED_CONV_GPU( + torch::Tensor data_im, + torch::Tensor shift, + torch::Tensor bias, + torch::Tensor output, + torch::IntArrayRef strides, + torch::IntArrayRef padding) +{ + int strides_h; + int strides_w; + if(strides.size() ==1){ + strides_h = strides[0]; + strides_w = strides[0]; + } + else{ + strides_h = strides[0]; + strides_w = strides[1]; + } + int k = shift.size(2) * shift.size(3) * data_im.size(1); + int num_p = output.size(0) * output.size(2) * output.size(3); + + float* data_col; + cudaMalloc(&data_col, num_p * k * sizeof(float)); + + int threads = MAX_THREADS; + int tmp = (k * num_p + threads -1) / threads; + tmp = (tmp > MAX_BLOCKS) ? 
MAX_BLOCKS: tmp; + const dim3 blk(tmp); + AT_DISPATCH_ALL_TYPES(data_im.type(), "IM2COL cuda", ([&] { + IM2COL<<>>( + k * num_p, + data_im.data(), + data_col, + shift.size(2), + shift.size(3), + data_im.size(1), + output.size(2), + output.size(3), + strides_h, + strides_w, + data_im.size(2), + data_im.size(3), + k, num_p); + })); + int filter_p = output.size(1); + dim3 blockDim(BLOCK_SIZE, BLOCK_SIZE); + int a1=filter_p/ BLOCK_SIZE + 1; + if(a1> MAX_BLOCKS){ + a1 = MAX_BLOCKS; + } + int a2=num_p / BLOCK_SIZE + 1; + if(a2> MAX_BLOCKS) { + a2 = MAX_BLOCKS; + } + dim3 gridDim( a1, a2); + + float *out_col; + int max =(k + BLOCK_SIZE - 1) / BLOCK_SIZE; + cudaMalloc(&out_col, num_p * filter_p * sizeof(float)); + AT_DISPATCH_ALL_TYPES(data_im.type(), "GEMM unoptimized kernel", ([&] { + GEMM<<>>( + data_col, + shift.data(), + bias.data(), + out_col, + k, + num_p, + filter_p, max); + })); + tmp = (num_p * output.size(1) + threads -1) / threads; + tmp = (tmp > MAX_BLOCKS) ? MAX_BLOCKS: tmp; + const dim3 block1(tmp); + AT_DISPATCH_ALL_TYPES(data_im.type(), "COL2IM cuda", ([&] { + COL2IM<<>>( + num_p * output.size(1), + out_col, + output.data(), + output.size(2), + output.size(3), + output.size(1)); + })); + cudaFree(data_col); + cudaFree(out_col); +} \ No newline at end of file diff --git a/pytorch/deepshift/kernels/cuda/unoptimized_linear.py b/pytorch/deepshift/kernels/cuda/unoptimized_linear.py new file mode 100644 index 0000000..38dc36e --- /dev/null +++ b/pytorch/deepshift/kernels/cuda/unoptimized_linear.py @@ -0,0 +1,56 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function +from torch.nn.modules.utils import _pair +from torch.nn import init +import unoptimized_cuda_kernel +import math +import numpy as np +import time + + +class UnoptimizedLinear(nn.Module): + def __init__(self, in_features, out_features, bias=True): + + super(UnoptimizedLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_features)) + else: + # You should always register all possible parameters, but the + # optional ones can be None if you want. + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def forward(self, input): + # start_time = time.time() + out = torch.zeros([input.size(0),self.weight.size(0)], dtype=torch.float, device=torch.device('cuda:0')) + if self.bias is not None: + unoptimized_cuda_kernel.UNOPTIMIZED_LINEAR(input, self.weight, self.bias,out) + else: + temp = torch.zeros([self.weight.size(0)], dtype=torch.float, device=torch.device('cuda:0')) + unoptimized_cuda_kernel.UNOPTIMIZED_LINEAR(input, self.weight, temp,out) + # end_time = time.time() + # print("Linear Time:", end_time - start_time ) + return out + + + + def extra_repr(self): + # (Optional)Set the extra information about this module. You can test + # it by printing an object of this class. 
+ return 'in_features={}, out_features={}, bias={}'.format( + self.in_features, self.out_features, self.bias is not None + ) diff --git a/pytorch/deepshift/kernels/kernels.py b/pytorch/deepshift/kernels/kernels.py new file mode 100644 index 0000000..1d802d6 --- /dev/null +++ b/pytorch/deepshift/kernels/kernels.py @@ -0,0 +1,74 @@ +import torch +import time +import torch.nn.functional as F +try: + import deepshift_cuda + import deepshift_cpu +except: + print("Unable to import CPU and/or CUDA bit-wise shift kernels") + +def linear(input, shift, sign, bias=None, conc_weight=None, use_cuda=True): + if(use_cuda): + assert(conc_weight is not None) + # start_time = time.time() + out = torch.zeros([input.size(0), shift.size(0)], dtype=torch.int32, device=torch.device('cuda:0')) + if bias is not None: + deepshift_cuda.DEEP_SHIFT_LINEAR(input, conc_weight.data, bias, out, conc_weight.base, conc_weight.bits, shift.size(0)) + else: + temp = torch.zeros([shift.size(0)], dtype=torch.int32, device=torch.device('cuda:0')) + deepshift_cuda.DEEP_SHIFT_LINEAR(input, conc_weight.data, temp, out, conc_weight.base, conc_weight.bits, shift.size(0)) + # end_time = time.time() + # print("Linear Time:", end_time - start_time ) + else: + out = deepshift_cpu.linear_kernel(input.detach().numpy(), shift.detach().numpy(), sign.detach().numpy(), bias.detach().numpy()) + out = torch.Tensor(out) + + return out + +def conv2d(input, shift, sign, bias=None, conc_weight=None, stride=1, padding=0, dilation=1, groups=1, use_cuda=True): + if(use_cuda): + assert(conc_weight is not None) + start_time = time.time() + if len(padding) == 2: + padding = (padding[0], padding[0], padding[1], padding[1]) + else: + padding = padding + input = F.pad(input = input, pad = padding, mode = 'constant', value = 0) + if len(stride) == 1: + strides_h = stride[0] + strides_w = stride[0] + else: + strides_h = stride[0] + strides_w = stride[1] + kernel_size = shift.shape[2:4] + out_height = int((input.size(2) - kernel_size[0]) / strides_h +1) + out_width = int((input.size(3) - kernel_size[1]) / strides_w +1) + out_channels = shift.size(0) + out = torch.zeros([input.size(0), out_channels, out_height, out_width], dtype=torch.int32, device=torch.device('cuda:0')) + + if bias is not None: + deepshift_cuda.DEEP_SHIFT_CONV(input, conc_weight.data, bias, out, stride, padding, kernel_size[0], kernel_size[1], conc_weight.base, conc_weight.bits) + else: + temp = torch.zeros([out_channels], dtype=torch.int32, device=torch.device('cuda:0')) + deepshift_cuda.DEEP_SHIFT_CONV(input, conc_weight.data, temp, out, stride, padding, kernel_size[0], kernel_size[1], conc_weight.base, conc_weight.bits) + # end_time = time.time() + # print("Conv Time:", end_time - start_time ) + + else: + input = F.pad(input = input, pad = padding, mode = 'constant', value = 0) + out = deepshift_cpu.convolution_kernel(input.cpu().detach().numpy(), + shift.cpu().detach().numpy(), + sign.cpu().detach().numpy(), + bias.cpu().detach().numpy(), stride, padding) + out = torch.Tensor(out) + + #print("out - out1: ", out.cpu() - out1.cpu().int()) + + return out + +def compress_sign_and_shift(shift, sign, comp_size, base, bits, row_length, num): + comp_weight = torch.zeros([comp_size], dtype=torch.int32,device = torch.device('cuda:0')) + + deepshift_cuda.COMPRESS_SIGN_SHIFT(shift, sign, comp_weight, base, bits, shift.shape[0], shift.shape[1], shift.shape[2], shift.shape[3], row_length, num) + + return comp_weight \ No newline at end of file diff --git a/pytorch/shift.py b/pytorch/deepshift/modules.py 
similarity index 64% rename from pytorch/shift.py rename to pytorch/deepshift/modules.py index 3e33511..60f0f64 100644 --- a/pytorch/shift.py +++ b/pytorch/deepshift/modules.py @@ -8,80 +8,41 @@ import numpy as np import time -try: - import shift_kernel - import shift_cuda_kernel -except: - print("Unable to import CPU and/or CUDA bit-wise shift kernels") +import deepshift.utils as utils +import deepshift.kernels +import deepshift.ste as ste log2 = math.log(2) -def round_to_fixed(input, fraction=16, integer=16): - assert integer >= 1, integer - if integer == 1: - return torch.sign(input) - 1 - delta = math.pow(2.0, -(fraction)) - bound = math.pow(2.0, integer-1) - min_val = - bound - max_val = bound - 1 - rounded = torch.floor(input / delta) * delta - - clipped_value = torch.clamp(rounded, min_val, max_val) - return clipped_value - -def get_shift_and_sign(x): - sign = torch.sign(x) - - x_abs = torch.abs(x) - shift = torch.round(torch.log(x_abs) / np.log(2)) - - return shift, sign - -def round_power_of_2(x): - shift, sign = get_shift_and_sign(x) - x_rounded = (2.0 ** shift) * sign - return x_rounded - # Inherit from Function class LinearShiftFunction(Function): # Note that both forward and backward are @staticmethods @staticmethod # bias is an optional argument - def forward(ctx, input, shift, sign, bias=None, use_kernel=False, use_cuda=True): + def forward(ctx, input, shift, sign, bias=None, conc_weight=None, use_kernel=False, use_cuda=True): fraction_bits = 16 integer_bit = 16 - - sign = sign.clamp(-1,1) - + if use_kernel: input_fixed_point = (input * (2 ** fraction_bits)).int() if bias is not None: bias_fixed_point = (bias * (2 ** fraction_bits)).int() - if(use_cuda): - out = torch.zeros([input.size(0), shift.size(0)], dtype=torch.int32, device=torch.device('cuda:0')) - if bias is not None: - shift_cuda_kernel.linear_shift(input_fixed_point, shift.int(), sign.int(), bias_fixed_point, out) - else: - temp = torch.zeros([shift.size(0)], dtype=torch.int32, device=torch.device('cuda:0')) - shift_cuda_kernel.linear_shift(input_fixed_point, shift.int(), sign.int(), temp, out) - out = out.float() - out = out / (2**fraction_bits) - else: - nn = shift_kernel.linear_kernel(input_fixed_point.detach().numpy(), shift.detach().numpy(), sign.detach().numpy(), bias_fixed_point.detach().numpy()) - out = torch.FloatTensor(nn) - out = out / (2**fraction_bits) + out = deepshift.kernels.linear(input_fixed_point, shift, sign, bias_fixed_point, conc_weight, use_cuda) + out = out.float() + out = out / (2**fraction_bits) else: - input.data = round_to_fixed(input.data, fraction_bits, integer_bit) + sign = sign.clamp(-1,1) + input.data = utils.round_to_fixed(input.data, fraction_bits, integer_bit) if bias is not None: - bias.data = round_to_fixed(bias.data, fraction_bits, integer_bit) + bias.data = utils.round_to_fixed(bias.data, fraction_bits, integer_bit) v = 2**shift.round() * sign.round().sign() out = input.mm(v.t()) if bias is not None: out += bias.unsqueeze(0).expand_as(out) - + ctx.save_for_backward(input, shift, sign, bias, v) return out @@ -113,17 +74,17 @@ def backward(ctx, grad_output): if bias is not None and ctx.needs_input_grad[3]: grad_bias = grad_output.sum(0).squeeze(0) - return grad_input, grad_shift, grad_sign, grad_bias, None, None + return grad_input, grad_shift, grad_sign, grad_bias, None, None, None class LinearShift(nn.Module): def __init__(self, in_features, out_features, bias=True, check_grad=False, freeze_sign=False, use_kernel=False, use_cuda=True): - super(LinearShift, self).__init__() 
self.in_features = in_features self.out_features = out_features self.use_kernel = use_kernel self.check_grad = check_grad self.use_cuda = use_cuda + self.conc_weight = None # nn.Parameter is a special kind of Tensor, that will get # automatically registered as Module's parameter once it's assigned # as an attribute. Parameters and buffers need to be registered, or @@ -159,7 +120,7 @@ def reset_parameters(self): init.uniform_(self.bias, -bound, bound) def forward(self, input): - return LinearShiftFunction.apply(input, self.shift, self.sign, self.bias, self.use_kernel, self.use_cuda) + return LinearShiftFunction.apply(input, self.shift, self.sign, self.bias, self.conc_weight, self.use_kernel, self.use_cuda) def extra_repr(self): # (Optional)Set the extra information about this module. You can test @@ -168,82 +129,43 @@ def extra_repr(self): self.in_features, self.out_features, self.bias is not None ) -# check gradient of linear_shift -linear_shift = LinearShift(20, 30, check_grad=True) -#linear_shift = LinearShiftFunction.apply - -from torch.autograd import gradcheck -# gradcheck takes a tuple of tensors as input, check if your gradient -# evaluated with these tensors are close enough to numerical -# approximations and returns True if they all verify this condition. -data = torch.randn(20,20,dtype=torch.double,requires_grad=True) -weight = torch.randn(30,20,dtype=torch.double,requires_grad=True) -input = (data, weight) -# test = gradcheck(linear_shift, data, eps=1e-6, atol=1e-4) -# print("gradcheck result for linear_shift: ", test) - # Inherit from Function class Conv2dShiftFunction(Function): # Note that both forward and backward are @staticmethods @staticmethod # bias is an optional argument - def forward(ctx, input, shift, sign, bias=None, stride=1, padding=0, dilation=1, groups=1, use_kernel=False, use_cuda=False): + def forward(ctx, input, shift, sign, bias=None, conc_weight=None, stride=1, padding=0, dilation=1, groups=1, use_kernel=False, use_cuda=False): fraction_bits = 16 integer_bits = 16 - sign = sign.clamp(-1,1) - + # start_time = time.time() if use_kernel: input_fixed_point = (input * (2 ** fraction_bits)).int() if bias is not None: bias_fixed_point = (bias * (2 ** fraction_bits)).int() - - if(use_cuda): - if len(padding) == 2: - padding = (padding[0], padding[0], padding[1], padding[1]) - else: - padding = padding - input_fixed_point = F.pad(input = input_fixed_point, pad = padding, mode = 'constant', value = 0) - if len(stride) == 1: - strides_h = stride[0] - strides_w = stride[0] - else: - strides_h = stride[0] - strides_w = stride[1] - out_height = int((input_fixed_point.size(2) - shift.size(2)) / strides_h +1) - out_width = int((input_fixed_point.size(3) - shift.size(3)) / strides_w +1) - out = torch.zeros([input_fixed_point.size(0), shift.size(0), out_height, out_width], dtype=torch.int32, device=torch.device('cuda:0')) - - if bias is not None: - shift_cuda_kernel.conv2d_shift(input_fixed_point, shift.int(), sign.int(), bias_fixed_point, out, stride, padding) - else: - temp = torch.zeros([shift.size(0)], dtype=torch.int32, device=torch.device('cuda:0')) - shift_cuda_kernel.conv2d_shift(input_fixed_point, shift.int(), sign.int(), temp, out, stride, padding) - out = out.float() - out = out / (2**fraction_bits) else: - input_fixed_point = F.pad(input = input_fixed_point, pad = padding, mode = 'constant', value = 0) - out = shift_kernel.convolution_kernel(input_fixed_point.detach().numpy(), - shift.detach().numpy(), - sign.detach().numpy(), - bias_fixed_point.detach().numpy(), 
stride, padding) - out = torch.FloatTensor(out) - out = out / (2**fraction_bits) + bias_fixed_point = None + + out = deepshift.kernels.conv2d(input_fixed_point, shift, sign, bias_fixed_point, conc_weight, stride, padding, dilation, groups, use_cuda) + + out = out.float() + out = out / (2**fraction_bits) else: - input.data = round_to_fixed(input.data, fraction_bits, integer_bits) + sign = sign.clamp(-1,1) + input.data = utils.round_to_fixed(input.data, fraction_bits, integer_bits) + if bias is not None: - bias.data = round_to_fixed(bias.data, fraction_bits, integer_bits) + bias.data = utils.round_to_fixed(bias.data, fraction_bits, integer_bits) - with torch.no_grad(): - v = 2**shift.round() * sign.round().sign() - out = F.conv2d(input, v, bias, stride, padding, dilation, groups) + v = 2**shift.round() * sign.round().sign() + out = F.conv2d(input, v, bias, stride, padding, dilation, groups) - ctx.save_for_backward(input, shift, sign, bias, v) - ctx.stride = stride - ctx.padding = padding - ctx.dilation = dilation - ctx.groups = groups + ctx.save_for_backward(input, shift, sign, bias, v) + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.groups = groups return out @@ -278,7 +200,7 @@ def backward(ctx, grad_output): if bias is not None and ctx.needs_input_grad[3]: grad_bias = grad_output.sum((0,2,3)).squeeze(0) - return grad_input, grad_shift, grad_sign, grad_bias, None, None, None, None, None, None + return grad_input, grad_shift, grad_sign, grad_bias, None, None, None, None, None, None, None, None class _ConvNdShift(nn.Module): @@ -330,7 +252,6 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, def reset_parameters(self): self.shift.data.uniform_(-10, -1) # (-0.1, 0.1) self.sign.data.uniform_(-1, 1) # (-0.1, 0.1) - if self.bias is not None: fan_in, _ = init._calculate_fan_in_and_fan_out(self.shift) bound = 1 / math.sqrt(fan_in) @@ -354,29 +275,57 @@ def extra_repr(self): class Conv2dShift(_ConvNdShift): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, - bias=True, padding_mode='zeros', - check_grad=False, freeze_sign=False, use_kernel=False,use_cuda =True): + bias=True, padding_mode='zeros', check_grad=False, freeze_sign=False, use_kernel=False, use_cuda=True): kernel_size = _pair(kernel_size) stride = _pair(stride) padding = _pair(padding) dilation = _pair(dilation) self.use_kernel = use_kernel self.use_cuda = use_cuda + self.conc_weight = None super(Conv2dShift, self).__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, - False, _pair(0), groups, bias, padding_mode, + False, _pair(0), groups, bias, padding_mode, check_grad, freeze_sign) #@weak_script_method def forward(self, input): + shift_rounded = ste.round(self.shift) + sign_rounded_signed = ste.sign(ste.round(self.sign)) + weight_ps = ste.unsym_grad_mul(2**shift_rounded, sign_rounded_signed) # 2**utils.stochastic_rounding(shift) * sign.round().sign() + input_fixed_point = ste.round_fixed_point(input) + if self.bias is not None: + bias_fixed_point = ste.round_fixed_point(self.bias) + else: + bias_fixed_point = None + + if self.padding_mode == 'circular': + expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2, + (self.padding[0] + 1) // 2, self.padding[0] // 2) + + input_padded = F.pad(input_fixed_point, expanded_padding, mode='circular') + padding = _pair(0) + else: + input_padded = input_fixed_point + padding = self.padding + + if self.use_kernel: + return Conv2dShiftFunction.apply(input_padded, 
self.shift, self.sign, bias_fixed_point, self.conc_weight, + self.stride, padding, self.dilation, self.groups, + self.use_kernel, self.use_cuda) + else: + return torch.nn.functional.conv2d(input_padded, weight_ps, bias_fixed_point, + self.stride, padding, self.dilation, self.groups) + ''' if self.padding_mode == 'circular': expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2, (self.padding[0] + 1) // 2, self.padding[0] // 2) return Conv2dShiftFunction.apply(F.pad(input, expanded_padding, mode='circular'), - self.shift, self.sign, self.bias, self.stride, - _pair(0), self.dilation, self.groups, - self.use_kernel, self.use_cuda) + self.shift, self.sign, self.bias, self.conc_weight, + self.stride, _pair(0), self.dilation, self.groups, + self.use_kernel, self.use_cuda, self.conc_weight) else: - return Conv2dShiftFunction.apply(input, self.shift, self.sign, self.bias, self.stride, - self.padding, self.dilation, self.groups, + return Conv2dShiftFunction.apply(input, self.shift, self.sign, self.bias, self.conc_weight, + self.stride, self.padding, self.dilation, self.groups, self.use_kernel, self.use_cuda) + ''' diff --git a/pytorch/shift_q.py b/pytorch/deepshift/modules_q.py similarity index 56% rename from pytorch/shift_q.py rename to pytorch/deepshift/modules_q.py index e658962..56e5b7c 100644 --- a/pytorch/shift_q.py +++ b/pytorch/deepshift/modules_q.py @@ -7,99 +7,9 @@ import math import numpy as np import time -from shift import round_to_fixed, get_shift_and_sign, round_power_of_2 - -try: - import shift_kernel - import shift_cuda_kernel -except: - print("Unable to import CPU and/or CUDA bit-wise shift kernels") - -class STERoundPowerOf2(Function): - @staticmethod - def forward(ctx, input): - return round_power_of_2(input) - - @staticmethod - def backward(ctx, grad_output): - return grad_output - -class STERoundFixedPoint(Function): - @staticmethod - def forward(ctx, input): - return round_to_fixed(input) - - @staticmethod - def backward(ctx, grad_output): - return grad_output - -def ste_round_fixed_point(input): - return STERoundFixedPoint.apply(input) - -def ste_round_power_of_2(input): - return STERoundPowerOf2.apply(input) - -class STERoundFunction(Function): - @staticmethod - def forward(ctx, input): - return torch.round(input) - - @staticmethod - def backward(ctx, grad_output): - return grad_output - -def ste_round(input): - return STERoundFunction.apply(input) - -class STESignFunction(Function): - @staticmethod - def forward(ctx, input): - return torch.sign(input) - - @staticmethod - def backward(ctx, grad_output): - return grad_output - -def ste_sign(input): - return STESignFunction.apply(input) - -class STELogFunction(Function): - @staticmethod - def forward(ctx, input): - return torch.log(input) - - @staticmethod - def backward(ctx, grad_output): - return grad_output - -def ste_log(input): - return STELogFunction.apply(input) - -class STEDivFunction(Function): - @staticmethod - def forward(ctx, input, const): - return torch.div(input, const) - - @staticmethod - def backward(ctx, grad_output): - return grad_output, None - -def ste_div(input, const): - return STEDivFunction.apply(input, const) - - -class STEAbsFunction(Function): - @staticmethod - def forward(ctx, input): - return torch.abs(input) - - @staticmethod - def backward(ctx, grad_output): - return grad_output - -def ste_abs(input): - return STEAbsFunction.apply(input) - +import deepshift.utils as utils +import deepshift.kernels +import deepshift.ste as ste # Inherit from Function class 
LinearShiftQFunction(Function): @@ -107,30 +17,20 @@ class LinearShiftQFunction(Function): # Note that both forward and backward are @staticmethods @staticmethod # bias is an optional argument - def forward(ctx, input, weight, bias=None, use_kernel=False, use_cuda=True): + def forward(ctx, input, weight, bias=None, conc_weight=None, use_kernel=False, use_cuda=True): fraction_bits = 16 integer_bit = 16 - shift, sign = get_shift_and_sign(weight) + shift, sign = utils.get_shift_and_sign(weight) if use_kernel: input_fixed_point = (input * (2 ** fraction_bits)).int() if bias is not None: bias_fixed_point = (bias * (2 ** fraction_bits)).int() - if(use_cuda): - out = torch.zeros([input.size(0), shift.size(0)], dtype=torch.int32, device=torch.device('cuda:0')) - if bias is not None: - shift_cuda_kernel.linear_shift(input_fixed_point, shift.int(), sign.int(), bias_fixed_point, out) - else: - temp = torch.zeros([shift.size(0)], dtype=torch.int32, device=torch.device('cuda:0')) - shift_cuda_kernel.linear_shift(input_fixed_point, shift.int(), sign.int(), temp, out) - out = out.float() - out = out / (2**fraction_bits) - else: - nn = shift_kernel.linear_kernel(input_fixed_point.detach().numpy(), shift.detach().numpy(), sign.detach().numpy(), bias_fixed_point.detach().numpy()) - out = torch.FloatTensor(nn) - out = out / (2**fraction_bits) + out = deepshift.kernels.linear(input_fixed_point, shift, sign, bias_fixed_point, conc_weight, use_cuda) + out = out.float() + out = out / (2**fraction_bits) else: input.data = round_to_fixed(input.data, fraction_bits, integer_bit) if bias is not None: @@ -141,7 +41,7 @@ def forward(ctx, input, weight, bias=None, use_kernel=False, use_cuda=True): if bias is not None: out += bias.unsqueeze(0).expand_as(out) - ctx.save_for_backward(input, weight_s, bias) + ctx.save_for_backward(input, weight_s, bias) return out @@ -178,6 +78,7 @@ def __init__(self, in_features, out_features, bias=True, check_grad=False, use_k self.use_kernel = use_kernel self.check_grad = check_grad self.use_cuda = use_cuda + self.conc_weight = None # nn.Parameter is a special kind of Tensor, that will get # automatically registered as Module's parameter once it's assigned # as an attribute. Parameters and buffers need to be registered, or @@ -211,7 +112,21 @@ def reset_parameters(self): init.uniform_(self.bias, -bound, bound) def forward(self, input): - return LinearShiftQFunction.apply(input, self.weight, self.bias, self.use_kernel, self.use_cuda) + weight_q = ste.round_power_of_2(self.weight) + input_fixed_point = ste.round_fixed_point(input) + if self.bias is not None: + bias_fixed_point = ste.round_fixed_point(self.bias) + else: + bias_fixed_point = None + + if self.use_kernel: + return LinearShiftQFunction.apply(input_fixed_point, weight_q, bias_fixed_point, self.conc_weight, self.use_kernel, self.use_cuda) + else: + out = input_fixed_point.mm(weight_q.t()) + if self.bias is not None: + out += self.bias.unsqueeze(0).expand_as(out) + + return out def extra_repr(self): # (Optional)Set the extra information about this module. 
You can test @@ -220,81 +135,38 @@ def extra_repr(self): self.in_features, self.out_features, self.bias is not None ) -# check gradient of linear_shift -linear_shift = LinearShiftQ(20, 30, check_grad=True) -#linear_shift = LinearShiftFunction.apply - -from torch.autograd import gradcheck -# gradcheck takes a tuple of tensors as input, check if your gradient -# evaluated with these tensors are close enough to numerical -# approximations and returns True if they all verify this condition. -data = torch.randn(20,20,dtype=torch.double,requires_grad=True) -weight = torch.randn(30,20,dtype=torch.double,requires_grad=True) -input = (data, weight) -# test = gradcheck(linear_shift, data, eps=1e-6, atol=1e-4) -# print("gradcheck result for linear_shift: ", test) - # Inherit from Function class Conv2dShiftQFunction(Function): # Note that both forward and backward are @staticmethods @staticmethod # bias is an optional argument - def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, use_kernel=False, use_cuda=False): + def forward(ctx, input, weight, bias=None, conc_weight=None, stride=1, padding=0, dilation=1, groups=1, use_kernel=False, use_cuda=False): fraction_bits = 16 integer_bits = 16 - shift, sign = get_shift_and_sign(weight) + shift, sign = utils.get_shift_and_sign(weight) if use_kernel: input_fixed_point = (input * (2 ** fraction_bits)).int() if bias is not None: bias_fixed_point = (bias * (2 ** fraction_bits)).int() - - if(use_cuda): - if len(padding) == 2: - padding = (padding[0], padding[0], padding[1], padding[1]) - else: - padding = padding - input_fixed_point = F.pad(input = input_fixed_point, pad = padding, mode = 'constant', value = 0) - if len(stride) == 1: - strides_h = stride[0] - strides_w = stride[0] - else: - strides_h = stride[0] - strides_w = stride[1] - out_height = int((input_fixed_point.size(2) - shift.size(2)) / strides_h +1) - out_width = int((input_fixed_point.size(3) - shift.size(3)) / strides_w +1) - out = torch.zeros([input_fixed_point.size(0), shift.size(0), out_height, out_width], dtype=torch.int32, device=torch.device('cuda:0')) - - if bias is not None: - shift_cuda_kernel.conv2d_shift(input_fixed_point, shift.int(), sign.int(), bias_fixed_point, out, stride, padding) - else: - temp = torch.zeros([shift.size(0)], dtype=torch.int32, device=torch.device('cuda:0')) - shift_cuda_kernel.conv2d_shift(input_fixed_point, shift.int(), sign.int(), temp, out, stride, padding) - out = out.float() - out = out / (2**fraction_bits) else: - input_fixed_point = F.pad(input = input_fixed_point, pad = padding, mode = 'constant', value = 0) - out = shift_kernel.convolution_kernel(input_fixed_point.detach().numpy(), - shift.detach().numpy(), - sign.detach().numpy(), - bias_fixed_point.detach().numpy(), stride, padding) - out = torch.FloatTensor(out) - out = out / (2**fraction_bits) - else: - input.data = round_to_fixed(input.data, fraction_bits, integer_bits) - if bias is not None: - bias.data = round_to_fixed(bias.data, fraction_bits, integer_bits) + bias_fixed_point = None + out = deepshift.kernels.conv2d(input_fixed_point, shift, sign, bias_fixed_point, conc_weight, stride, padding, dilation, groups, use_cuda) + + out = out.float() + out = out / (2**fraction_bits) + else: weight_s = (2.0 ** shift) * sign out = F.conv2d(input, weight_s, bias, stride, padding, dilation, groups) - ctx.save_for_backward(input, weight_s, bias) - ctx.stride = stride - ctx.padding = padding - ctx.dilation = dilation - ctx.groups = groups + ctx.save_for_backward(input, 
weight_s, bias) + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.groups = groups return out @@ -387,13 +259,14 @@ class Conv2dShiftQ(_ConvNdShiftQ): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', - check_grad=False, use_kernel=False,use_cuda =True): + check_grad=False, use_kernel=False, use_cuda =True): kernel_size = _pair(kernel_size) stride = _pair(stride) padding = _pair(padding) dilation = _pair(dilation) self.use_kernel = use_kernel self.use_cuda = use_cuda + self.conc_weight = None super(Conv2dShiftQ, self).__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, False, _pair(0), groups, bias, padding_mode, @@ -401,19 +274,27 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, #@weak_script_method def forward(self, input): - weight_q = ste_round_power_of_2(self.weight) - input_fxied_point = ste_round_fixed_point(input) + weight_q = ste.round_power_of_2(self.weight) + input_fixed_point = ste.round_fixed_point(input) if self.bias is not None: - bias_fixed_point = ste_round_fixed_point(self.bias) + bias_fixed_point = ste.round_fixed_point(self.bias) else: bias_fixed_point = None if self.padding_mode == 'circular': expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2, (self.padding[0] + 1) // 2, self.padding[0] // 2) - return torch.nn.functional.conv2d(F.pad(input_fxied_point, expanded_padding, mode='circular'), - weight_q, bias_fixed_point, self.stride, - _pair(0), self.dilation, self.groups) + + input_padded = F.pad(input_fixed_point, expanded_padding, mode='circular') + padding = _pair(0) + else: + input_padded = input_fixed_point + padding = self.padding + + if self.use_kernel: + return Conv2dShiftQFunction.apply(input_padded, weight_q, bias_fixed_point, self.conc_weight, + self.stride, padding, self.dilation, self.groups, + self.use_kernel, self.use_cuda) else: - return torch.nn.functional.conv2d(input_fxied_point, weight_q, bias_fixed_point, self.stride, - self.padding, self.dilation, self.groups) + return torch.nn.functional.conv2d(input_padded, weight_q, bias_fixed_point, + self.stride, padding, self.dilation, self.groups) diff --git a/pytorch/deepshift/ste.py b/pytorch/deepshift/ste.py new file mode 100644 index 0000000..d31e649 --- /dev/null +++ b/pytorch/deepshift/ste.py @@ -0,0 +1,90 @@ +import torch +from torch.autograd import Function +import deepshift.utils as utils + +class RoundPowerOf2(Function): + @staticmethod + def forward(ctx, input): + return utils.round_power_of_2(input) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + +def round_power_of_2(input): + return RoundPowerOf2.apply(input) + +class RoundFixedPoint(Function): + @staticmethod + def forward(ctx, input): + return utils.round_to_fixed(input) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + +def round_fixed_point(input): + return RoundFixedPoint.apply(input) + +class RoundFunction(Function): + @staticmethod + def forward(ctx, input): + return torch.round(input) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + +def round(input): + return RoundFunction.apply(input) + +class SignFunction(Function): + @staticmethod + def forward(ctx, input): + return torch.sign(input) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + +def sign(input): + return SignFunction.apply(input) + +class LogFunction(Function): + @staticmethod + def forward(ctx, input): 
+ return torch.log(input) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + +def log(input): + return LogFunction.apply(input) + +class UnsymmetricGradMulFunction(Function): + @staticmethod + def forward(ctx, input1, input2): + ctx.save_for_backward(input1, input2) + return torch.mul(input1, input2) + + @staticmethod + def backward(ctx, grad_output): + input1, input2 = ctx.saved_tensors + return grad_output*input2, grad_output + +def unsym_grad_mul(input1, input2): + return UnsymmetricGradMulFunction.apply(input1, input2) + + +class AbsFunction(Function): + @staticmethod + def forward(ctx, input): + return torch.abs(input) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + +def abs(input): + return AbsFunction.apply(input) \ No newline at end of file diff --git a/pytorch/deepshift/utils.py b/pytorch/deepshift/utils.py new file mode 100644 index 0000000..e5ebbb8 --- /dev/null +++ b/pytorch/deepshift/utils.py @@ -0,0 +1,67 @@ +import torch +import numpy as np +import math + +import deepshift.kernels + +def round_to_fixed(input, fraction=16, integer=16): + assert integer >= 1, integer + if integer == 1: + return torch.sign(input) - 1 + delta = math.pow(2.0, -(fraction)) + bound = math.pow(2.0, integer-1) + min_val = - bound + max_val = bound - 1 + rounded = torch.floor(input / delta) * delta + + clipped_value = torch.clamp(rounded, min_val, max_val) + return clipped_value + +def get_shift_and_sign(x): + sign = torch.sign(x) + + x_abs = torch.abs(x) + shift = torch.round(torch.log(x_abs) / np.log(2)) + + return shift, sign + +def round_power_of_2(x): + shift, sign = get_shift_and_sign(x) + x_rounded = (2.0 ** shift) * sign + return x_rounded + +class ConcWeight(): + def __init__(self, data=None, base=0, bits=8): + self.data = data + self.base = base + self.bits = bits + +##concatenate shift and sign together +def compress_bits(shift, sign): + conc_weight = ConcWeight() + + if len(shift.shape) == 2: + shift = shift.unsqueeze(-1).unsqueeze(-1) + + # if sign is ternary, then use a big shift value that is equivalent to multiplying by zero + zero_sign_indices = (sign == 0).nonzero() + shift[zero_sign_indices] = -32 + sign[zero_sign_indices] = +1 + + conc_weight.bits = math.ceil(torch.log( - torch.min(shift) + 1)/ np.log(2)) + # treat shift to the right as the default + shift = shift * -1 + minimum = int(torch.min(shift)) + if minimum < 0: + conc_weight.base = minimum + shift = shift - minimum + else: + conc_weight.base = 0 + + num = int(32 / (conc_weight.bits + 1)) + row_length = int((shift.shape[1] * shift.shape[2] * shift.shape[3] + num -1) / num ) + size = row_length * shift.shape[0] + + conc_weight.data = deepshift.kernels.compress_sign_and_shift(shift.int().cuda(), sign.int().cuda(), size, conc_weight.base, conc_weight.bits, row_length, num) + + return conc_weight \ No newline at end of file diff --git a/pytorch/imagenet.py b/pytorch/imagenet.py index c4e47f4..9cefb08 100644 --- a/pytorch/imagenet.py +++ b/pytorch/imagenet.py @@ -26,7 +26,9 @@ import optim import copy -from convert_to_shift import convert_to_shift, count_layer_type, round_shift_weights +from deepshift.convert import convert_to_shift, round_shift_weights, count_layer_type +from unoptimized.convert import convert_to_unoptimized + import customized_models default_model_names = sorted(name for name in models.__dict__ @@ -57,8 +59,8 @@ help='path to file to load its weights (default: none)') parser.add_argument('-s', '--shift-depth', type=int, default=0, help='how many layers to 
convert to shift') -parser.add_argument('-st', '--shift-type', default='Q', choices=['Q', 'PS'], - help='type of DeepShift method for training and representing weights (default: Q)') +parser.add_argument('-st', '--shift-type', default='PS', choices=['Q', 'PS'], + help='type of DeepShift method for training and representing weights (default: PS)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=90, type=int, metavar='N', @@ -94,7 +96,7 @@ help='path to latest checkpoint (default: none)') parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='only evaluate model on validation set') -parser.add_argument('--pretrained', dest='pretrained', default=True, type=lambda x:bool(distutils.util.strtobool(x)), +parser.add_argument('--pretrained', dest='pretrained', default=False, type=lambda x:bool(distutils.util.strtobool(x)), help='use pre-trained model') parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training') @@ -116,7 +118,7 @@ parser.add_argument('--save-model', default=True, type=lambda x:bool(distutils.util.strtobool(x)), help='For Saving the current Model (default: True)') -parser.add_argument('--print-weights', default=True, type=lambda x:bool(distutils.util.strtobool(x)), +parser.add_argument('--print-weights', default=False, type=lambda x:bool(distutils.util.strtobool(x)), help='For printing the weights of Model (default: True)') parser.add_argument('--desc', type=str, default=None, help='description to append to model directory name') @@ -129,6 +131,9 @@ def main(): args = parser.parse_args() + if(args.evaluate is False and args.use_kernel is True): + raise ValueError('Our custom kernel currently supports inference only, not training.') + if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) @@ -199,6 +204,8 @@ def main_worker(gpu, ngpus_per_node, args): print("=> creating model '{}'".format(args.arch)) model = models.__dict__[args.arch]() + model_rounded = None + if args.weights: saved_weights = torch.load(args.weights) if isinstance(saved_weights, nn.Module): @@ -221,6 +228,8 @@ def main_worker(gpu, ngpus_per_node, args): if args.shift_depth > 0: model, _ = convert_to_shift(model, args.shift_depth, args.shift_type, convert_weights = args.pretrained or args.weights, freeze_sign = (args.lr_sign == 0), use_kernel = args.use_kernel) + elif args.use_kernel and args.shift_depth == 0: + model = convert_to_unoptimized(model) if args.distributed: # For multiprocessing distributed, DistributedDataParallel constructor @@ -350,12 +359,12 @@ def main_worker(gpu, ngpus_per_node, args): cudnn.benchmark = True - model_tmp_copy = copy.deepcopy(model) # we noticed calling summary() on original model degrades it's accuracy. So we will call summary() on a copy of the model - try: - summary(model_tmp_copy, input_size=(3, 224, 224)) - print("WARNING: The summary function reports duplicate parameters for multi-GPU case") - except: - print("WARNING: Unable to obtain summary of model") + # model_tmp_copy = copy.deepcopy(model) # we noticed calling summary() on original model degrades it's accuracy. 
So we will call summary() on a copy of the model + # try: + # summary(model_tmp_copy, input_size=(3, 224, 224)) + # print("WARNING: The summary function reports duplicate parameters for multi-GPU case") + # except: + # print("WARNING: Unable to obtain summary of model") # name model sub-directory "shift_all" if all layers are converted to shift layers conv2d_layers_count = count_layer_type(model, nn.Conv2d) @@ -384,16 +393,16 @@ def main_worker(gpu, ngpus_per_node, args): for arg, value in sorted(vars(args).items()): command_args_file.write(arg + ": " + str(value) + "\n") - with open(os.path.join(model_dir, 'model_summary.txt'), 'w') as summary_file: - with redirect_stdout(summary_file): - try: - # TODO: make this summary function deal with parameters that are not named "weight" and "bias" - summary(model_tmp_copy, input_size=(3, 224, 224)) - print("WARNING: The summary function reports duplicate parameters for multi-GPU case") - except: - print("WARNING: Unable to obtain summary of model") + # with open(os.path.join(model_dir, 'model_summary.txt'), 'w') as summary_file: + # with redirect_stdout(summary_file): + # try: + # # TODO: make this summary function deal with parameters that are not named "weight" and "bias" + # summary(model_tmp_copy, input_size=(3, 224, 224)) + # print("WARNING: The summary function reports duplicate parameters for multi-GPU case") + # except: + # print("WARNING: Unable to obtain summary of model") - del model_tmp_copy # to save memory + # del model_tmp_copy # to save memory # Data loading code traindir = os.path.join(args.data, 'train') @@ -488,9 +497,10 @@ def main_worker(gpu, ngpus_per_node, args): if is_best: try: if (args.save_model): - torch.save(model.state_dict(), os.path.join(model_dir, "weights.pth")) - torch.save(optimizer.state_dict(), os.path.join(model_dir, "optimizer.pth")) - torch.save(model, os.path.join(model_dir, "model.pth")) + model_rounded = round_shift_weights(model, clone=True) + + torch.save(model_rounded.state_dict(), os.path.join(model_dir, "weights.pth")) + torch.save(model_rounded, os.path.join(model_dir, "model.pth")) except: print("WARNING: Unable to save model.pth") @@ -507,14 +517,17 @@ def main_worker(gpu, ngpus_per_node, args): print("Total Time:", end_time - start_time ) if (args.print_weights): + if(model_rounded is None): + model_rounded = round_shift_weights(model, clone=True) + with open(os.path.join(model_dir, 'weights_log.txt'), 'w') as weights_log_file: with redirect_stdout(weights_log_file): # Log model's state_dict print("Model's state_dict:") # TODO: Use checkpoint above - for param_tensor in model.state_dict(): - print(param_tensor, "\t", model.state_dict()[param_tensor].size()) - print(model.state_dict()[param_tensor]) + for param_tensor in model_rounded.state_dict(): + print(param_tensor, "\t", model_rounded.state_dict()[param_tensor].size()) + print(model_rounded.state_dict()[param_tensor]) print("") diff --git a/pytorch/install_kernels.sh b/pytorch/install_kernels.sh new file mode 100755 index 0000000..88399de --- /dev/null +++ b/pytorch/install_kernels.sh @@ -0,0 +1,12 @@ +#/usr/bin/sh +cd ./unoptimized/kernels/cuda +python setup.py install +cd - + +cd ./deepshift/kernels/cpu +python setup.py install +cd - + +cd ./deepshift/kernels/cuda +python setup.py install +cd - diff --git a/pytorch/mnist.py b/pytorch/mnist.py index 89632e8..df89a2e 100644 --- a/pytorch/mnist.py +++ b/pytorch/mnist.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -import torch.optim import optim from 
torchvision import datasets, transforms import csv @@ -15,8 +14,9 @@ import mnist import copy -import shift -from convert_to_shift import convert_to_shift, round_shift_weights +import deepshift +from deepshift.convert import convert_to_shift, round_shift_weights +from unoptimized.convert import convert_to_unoptimized class LinearMNIST(nn.Module): def __init__(self): @@ -81,9 +81,7 @@ def test(args, model, device, test_loader, loss_fn): test_loss += loss_fn(output, target, reduction='sum').item() # sum up batch loss pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability correct += pred.eq(target.view_as(pred)).sum().item() - test_loss /= len(test_loader.dataset) - print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( test_loss, correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset))) @@ -104,8 +102,8 @@ def main(): help='path to file to load its weights (default: none)') parser.add_argument('--shift-depth', type=int, default=0, help='how many layers to convert to shift') - parser.add_argument('-st', '--shift-type', default='Q', choices=['Q', 'PS'], - help='type of DeepShift method for training and representing weights (default: Q)') + parser.add_argument('-st', '--shift-type', default='PS', choices=['Q', 'PS'], + help='type of DeepShift method for training and representing weights (default: PS)') parser.add_argument('-j', '--workers', default=1, type=int, metavar='N', help='number of data loading workers (default: 1)') parser.add_argument('--batch-size', type=int, default=64, metavar='N', @@ -143,6 +141,9 @@ def main(): help='whether using custom shift kernel') args = parser.parse_args() use_cuda = not args.no_cuda and torch.cuda.is_available() + + if(args.evaluate is False and args.use_kernel is True): + raise ValueError('Our custom kernel currently supports inference only, not training.') torch.manual_seed(args.seed) @@ -180,8 +181,10 @@ def main(): model = ConvMNIST().to(device) if args.pretrained: - model.load_state_dict(torch.load("./models/mnist/simple_" + args.type + "/shift_0/weights.pt")) + model.load_state_dict(torch.load("./models/mnist/simple_" + args.type + "/shift_0/weights.pth")) model = model.to(device) + + model_rounded = None if args.weights: saved_weights = torch.load(args.weights) @@ -197,6 +200,12 @@ def main(): if args.shift_depth > 0: model, _ = convert_to_shift(model, args.shift_depth, args.shift_type, convert_all_linear=(args.type != 'linear'), convert_weights=True, use_kernel = args.use_kernel, use_cuda = use_cuda) model = model.to(device) + elif args.use_kernel and args.shift_depth == 0: + model = convert_to_unoptimized(model) + model = model.to(device) + elif args.use_kernel and args.shift_depth == 0: + model = convert_to_unoptimized(model) + model = model.to(device) loss_fn = F.cross_entropy # F.nll_loss # define optimizer @@ -265,13 +274,12 @@ def main(): except: print("WARNING: Unable to obtain summary of model") - del model_tmp_copy + # del model_tmp_copy start_time = time.time() if args.evaluate: test_loss, correct = test(args, model, device, test_loader, loss_fn) test_log = [(test_loss, correct/1e4)] - with open(os.path.join(model_dir, "test_log.csv"), "w") as test_log_file: test_log_csv = csv.writer(test_log_file) test_log_csv.writerow(['test_loss', 'correct']) @@ -300,24 +308,29 @@ def main(): train_log_csv.writerow(['epoch', 'train_loss', 'test_loss', 'test_accuracy']) train_log_csv.writerows(train_log) - if (args.save_model): - torch.save(model, os.path.join(model_dir, 
"model.pt")) - torch.save(model.state_dict(), os.path.join(model_dir, "weights.pt")) - torch.save(optimizer.state_dict(), os.path.join(model_dir, "optimizer.pt")) + if (args.save_model): + model_rounded = round_shift_weights(model, clone=True) + + torch.save(model_rounded, os.path.join(model_dir, "model.pth")) + torch.save(model_rounded.state_dict(), os.path.join(model_dir, "weights.pth")) end_time = time.time() print("Total Time:", end_time - start_time ) if (args.print_weights): + if(model_rounded is None): + model_rounded = round_shift_weights(model, clone=True) + with open(os.path.join(model_dir, 'weights_log.txt'), 'w') as weights_log_file: with redirect_stdout(weights_log_file): # Log model's state_dict print("Model's state_dict:") # TODO: Use checkpoint above - for param_tensor in model.state_dict(): - print(param_tensor, "\t", model.state_dict()[param_tensor].size()) - print(model.state_dict()[param_tensor]) + for param_tensor in model_rounded.state_dict(): + print(param_tensor, "\t", model_rounded.state_dict()[param_tensor].size()) + print(model_rounded.state_dict()[param_tensor]) print("") if __name__ == '__main__': main() + torch.cuda.empty_cache() diff --git a/pytorch/test.py b/pytorch/test.py deleted file mode 100644 index c1cdf41..0000000 --- a/pytorch/test.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch -import torchvision - -N, C, H, W = 256, 3, 224, 224 #256, 3, 224, 224 -k, c, h, w = 3, 3, 7, 7 #64, 3, 7, 7 -assert c==C - -X = torch.rand((N, C, H, W)) -A = torch.rand((k, c, h, w)) -B = torch.rand((k, c, h, w)) - -conv = torch.nn.Conv2d(c, k, (h,w), stride=1, padding=0, dilation=1, groups=1, bias=False, padding_mode='zeros') -conv.weight.data = A+B -Y = conv(X) - -conv1 = torch.nn.Conv2d(c, k, (h,w), stride=1, padding=0, dilation=1, groups=1, bias=False, padding_mode='zeros') -conv2 = torch.nn.Conv2d(c, k, (h,w), stride=1, padding=0, dilation=1, groups=1, bias=False, padding_mode='zeros') -conv1.weight.data = A -conv2.weight.data = B -Y_new = conv1(X) + conv2(X) - -MSE = torch.sum((Y - Y_new)**2) -print(MSE.detach().numpy()) \ No newline at end of file diff --git a/pytorch/unoptimized/convert.py b/pytorch/unoptimized/convert.py new file mode 100644 index 0000000..7a9ea99 --- /dev/null +++ b/pytorch/unoptimized/convert.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn +import numpy as np + +from unoptimized.modules.linear import UnoptimizedLinear +from unoptimized.modules.conv import UnoptimizedConv2d + +def convert_to_unoptimized(model): + for name, module in model._modules.items(): + if len(list(module.children())) > 0: + # recurse + model._modules[name] = convert_to_unoptimized(model=module) + if type(module) == nn.Linear: + linear = module + unoptimized_linear = UnoptimizedLinear(module.in_features, module.out_features, module.bias is not None) + unoptimized_linear.weight = linear.weight + unoptimized_linear.bias = linear.bias + + model._modules[name] = unoptimized_linear + if type(module) == nn.Conv2d: + conv2d = module + unoptimized_conv = UnoptimizedConv2d(module.in_channels, module.out_channels, module.kernel_size, module.stride, + module.padding, module.dilation, module.groups, + module.bias is not None, module.padding_mode) + unoptimized_conv.bias = conv2d.bias + unoptimized_conv.weight = conv2d.weight + + model._modules[name] = unoptimized_conv + + return model + + +if __name__ == '__main__': + # this test will be run if you type in the command: + # > python convert_to_unoptimized + import torchvision.models as models + model = 
models.__dict__['resnet18'](pretrained=True) + model = model.to("cuda:0") + input = torch.rand((32, 3, 224, 224)).to("cuda:0") + output1 = model(input) + + + model = convert_to_unoptimized(model).to("cuda:0") + output2 = model(input) + + max_error = torch.max(torch.abs(output1 - output2)) + print(max_error.detach().cpu().numpy()) + diff --git a/pytorch/unoptimized/kernels/__init__.py b/pytorch/unoptimized/kernels/__init__.py new file mode 100644 index 0000000..2cdd6b4 --- /dev/null +++ b/pytorch/unoptimized/kernels/__init__.py @@ -0,0 +1 @@ +from .kernels import * \ No newline at end of file diff --git a/pytorch/cuda_kernel/setup.py b/pytorch/unoptimized/kernels/cuda/setup.py similarity index 59% rename from pytorch/cuda_kernel/setup.py rename to pytorch/unoptimized/kernels/cuda/setup.py index 6d6e675..97a48e0 100644 --- a/pytorch/cuda_kernel/setup.py +++ b/pytorch/unoptimized/kernels/cuda/setup.py @@ -2,11 +2,11 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension setup( - name='shift_cuda_kernel', + name='unoptimized_cuda', ext_modules=[ - CUDAExtension('shift_cuda_kernel', [ - 'shift_cuda.cpp', - 'shift_cuda_kernel.cu', + CUDAExtension('unoptimized_cuda', [ + 'unoptimized_cuda.cpp', + 'unoptimized.cu', ],extra_compile_args=['-O3']) ], cmdclass={ diff --git a/pytorch/unoptimized/kernels/cuda/unoptimized.cu b/pytorch/unoptimized/kernels/cuda/unoptimized.cu new file mode 100644 index 0000000..28a10d0 --- /dev/null +++ b/pytorch/unoptimized/kernels/cuda/unoptimized.cu @@ -0,0 +1,219 @@ + +#include +#include +#include +#include +#include +#define BLOCK_SIZE 16 +#define MAX_THREADS 1024 +#define MAX_BLOCKS 65535 +__global__ void IM2COL( + const int total, + const float* __restrict__ im, + float* __restrict__ col, + const int filter_height, + const int filter_width, + const int input_features, + const int out_height, + const int out_width, + const int strides_h, + const int strides_w, + const int in_height, + const int in_width, + const int k, const int num) +{ + for(int index = blockIdx.x * blockDim.x + threadIdx.x; index < total; index = index + gridDim.x * blockDim.x) { + const int h = index / k; + const int w = index % k; + const int n = h / (out_height * out_width); + const int out_idx = h % (out_height * out_width); + const int h_out = out_idx / out_width; + const int w_out = out_idx % out_width; + const int ic = w / (filter_height * filter_width); + const int hh_f = (w % (filter_height * filter_width)) / filter_width; + const int ww_f = (w % (filter_height * filter_width)) % filter_width; + + col[index] = im[ww_f + strides_w * w_out + + (hh_f + strides_h * h_out) * in_width + + ic * in_width * in_height + + n * in_width * in_height * input_features]; + } +} + + +__global__ void COL2IM( + const int total, + const float* __restrict__ col, + float* __restrict__ im, + const int out_height, + const int out_width, + const int oc) +{ + for(int index = blockIdx.x * blockDim.x + threadIdx.x; index < total; index = index + gridDim.x * blockDim.x){ + const int h = index / oc; + const int w = index % oc; + const int n = h / (out_height * out_width); + const int out_idx = h % (out_height * out_width); + const int h_out = out_idx / out_width; + const int w_out = out_idx % out_width; + im[w_out + h_out * out_width + out_width * out_height * w + n * oc * out_width * out_height] = col[index]; + } +} + +__global__ void GEMM( + const float* __restrict__ input, + const float* __restrict__ shift, + const float* __restrict__ bias, + float* __restrict__ output, + const int n, + const int m, 
+ const int k, + const int max) +{ + + const int row = threadIdx.y; + const int col = threadIdx.x; + for(int blockRow = blockIdx.y;blockDim.y * blockRow < m; blockRow = blockRow + gridDim.y){ + for(int blockCol = blockIdx.x;blockDim.x * blockCol < k; blockCol = blockCol + gridDim.x){ + float* Csub = &output[BLOCK_SIZE * k * blockRow + BLOCK_SIZE * blockCol]; + __shared__ float As[BLOCK_SIZE*BLOCK_SIZE]; + __shared__ float Bs[BLOCK_SIZE*BLOCK_SIZE]; + float Cvalue = 0; + for (int i = 0; i < max; ++i) { + const float* Asub = &input[BLOCK_SIZE * blockRow * n + BLOCK_SIZE * i ]; + const int original_index = BLOCK_SIZE * blockCol * n + BLOCK_SIZE * i + row * n + col; + As[row * BLOCK_SIZE + col] = Asub[row*n+col]; + Bs[row * BLOCK_SIZE + col] = shift[(original_index)]; + __syncthreads(); + + #pragma unroll + for (int j = 0; j < BLOCK_SIZE ; ++j){ + if(col + blockCol* BLOCK_SIZE< k + && row + blockRow* BLOCK_SIZE< m + && i * BLOCK_SIZE + j < n){ + Cvalue += (As[row * BLOCK_SIZE + j] * Bs[col * BLOCK_SIZE + j]); + } + } + __syncthreads(); + } + if(col + blockCol* BLOCK_SIZE< k && row + blockRow* BLOCK_SIZE< m) Csub[row*k+col] = Cvalue + bias[col + blockCol* BLOCK_SIZE]; + } + __syncthreads(); + } +} + + +void UNOPTIMIZED_LINEAR_GPU( + torch::Tensor input, + torch::Tensor weight, + torch::Tensor bias, + torch::Tensor output) +{ + + dim3 blockDim(BLOCK_SIZE, BLOCK_SIZE); + int a1=weight.size(0)/ BLOCK_SIZE + 1; + if(a1> MAX_BLOCKS){ + a1 = MAX_BLOCKS; + } + int a2=input.size(0) / BLOCK_SIZE + 1; + if(a2> MAX_BLOCKS) { + a2= MAX_BLOCKS; + } + dim3 gridDim( a1, a2); + int max =(input.size(1) + BLOCK_SIZE - 1) / BLOCK_SIZE; + AT_DISPATCH_ALL_TYPES(input.type(), "linear unoptimized kernel", ([&] { + GEMM<<>>( + input.data(), + weight.data(), + bias.data(), + output.data(), + input.size(1), + input.size(0), + weight.size(0), max); + })); +} + +void UNOPTIMIZED_CONV_GPU( + torch::Tensor data_im, + torch::Tensor shift, + torch::Tensor bias, + torch::Tensor output, + torch::IntArrayRef strides, + torch::IntArrayRef padding) +{ + int strides_h; + int strides_w; + if(strides.size() ==1){ + strides_h = strides[0]; + strides_w = strides[0]; + } + else{ + strides_h = strides[0]; + strides_w = strides[1]; + } + int k = shift.size(2) * shift.size(3) * data_im.size(1); + int num_p = output.size(0) * output.size(2) * output.size(3); + + float* data_col; + cudaMalloc(&data_col, num_p * k * sizeof(float)); + + int threads = MAX_THREADS; + int tmp = (k * num_p + threads -1) / threads; + tmp = (tmp > MAX_BLOCKS) ? MAX_BLOCKS: tmp; + const dim3 blk(tmp); + AT_DISPATCH_ALL_TYPES(data_im.type(), "IM2COL cuda", ([&] { + IM2COL<<>>( + k * num_p, + data_im.data(), + data_col, + shift.size(2), + shift.size(3), + data_im.size(1), + output.size(2), + output.size(3), + strides_h, + strides_w, + data_im.size(2), + data_im.size(3), + k, num_p); + })); + int filter_p = output.size(1); + dim3 blockDim(BLOCK_SIZE, BLOCK_SIZE); + int a1=filter_p/ BLOCK_SIZE + 1; + if(a1> MAX_BLOCKS){ + a1 = MAX_BLOCKS; + } + int a2=num_p / BLOCK_SIZE + 1; + if(a2> MAX_BLOCKS) { + a2 = MAX_BLOCKS; + } + dim3 gridDim( a1, a2); + + float *out_col; + int max =(k + BLOCK_SIZE - 1) / BLOCK_SIZE; + cudaMalloc(&out_col, num_p * filter_p * sizeof(float)); + AT_DISPATCH_ALL_TYPES(data_im.type(), "GEMM unoptimized kernel", ([&] { + GEMM<<>>( + data_col, + shift.data(), + bias.data(), + out_col, + k, + num_p, + filter_p, max); + })); + tmp = (num_p * output.size(1) + threads -1) / threads; + tmp = (tmp > MAX_BLOCKS) ? 
MAX_BLOCKS: tmp; + const dim3 block1(tmp); + AT_DISPATCH_ALL_TYPES(data_im.type(), "COL2IM cuda", ([&] { + COL2IM<<>>( + num_p * output.size(1), + out_col, + output.data(), + output.size(2), + output.size(3), + output.size(1)); + })); + cudaFree(data_col); + cudaFree(out_col); +} \ No newline at end of file diff --git a/pytorch/unoptimized/kernels/cuda/unoptimized_cuda.cpp b/pytorch/unoptimized/kernels/cuda/unoptimized_cuda.cpp new file mode 100644 index 0000000..bd041dc --- /dev/null +++ b/pytorch/unoptimized/kernels/cuda/unoptimized_cuda.cpp @@ -0,0 +1,56 @@ +#include +#include +#include + +// CUDA forward declarations + +void UNOPTIMIZED_LINEAR_GPU( + torch::Tensor input, + torch::Tensor weight, + torch::Tensor bias, + torch::Tensor output); +void UNOPTIMIZED_CONV_GPU( + torch::Tensor input, + torch::Tensor weight, + torch::Tensor bias, + torch::Tensor output, + torch::IntArrayRef strides, + torch::IntArrayRef padding); +// C++ interface + +#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +void UNOPTIMIZED_LINEAR( + torch::Tensor input, + torch::Tensor weight, + torch::Tensor bias, + torch::Tensor output) +{ + + CHECK_INPUT(input); + CHECK_INPUT(weight); + CHECK_INPUT(bias); + UNOPTIMIZED_LINEAR_GPU(input, weight, bias, output); +} + +void UNOPTIMIZED_CONV( + torch::Tensor input, + torch::Tensor weight, + torch::Tensor bias, + torch::Tensor output, + torch::IntArrayRef strides, + torch::IntArrayRef padding) +{ + CHECK_INPUT(input); + CHECK_INPUT(weight); + CHECK_INPUT(bias); + + UNOPTIMIZED_CONV_GPU(input, weight, bias, output,strides ,padding ); +} + +PYBIND11_MODULE(unoptimized_cuda, m) { + m.def("UNOPTIMIZED_LINEAR", &UNOPTIMIZED_LINEAR, "UNOPTIMIZED_LINEAR kernel(CUDA)"); + m.def("UNOPTIMIZED_CONV", &UNOPTIMIZED_CONV, "UNOPTIMIZED_CONV kernel(CUDA)"); +} diff --git a/pytorch/unoptimized/kernels/kernels.py b/pytorch/unoptimized/kernels/kernels.py new file mode 100644 index 0000000..4f57bce --- /dev/null +++ b/pytorch/unoptimized/kernels/kernels.py @@ -0,0 +1,34 @@ +import torch +try: + import unoptimized_cuda +except: + print("Unable to import CUDA unoptimized kernels") + +def linear(input, weight, bias): + out = torch.zeros([input.size(0), weight.size(0)], dtype=torch.float, device=torch.device('cuda:0')) + if bias is not None: + unoptimized_cuda.UNOPTIMIZED_LINEAR(input, weight, bias, out) + else: + temp = torch.zeros([weight.size(0)], dtype=torch.float, device=torch.device('cuda:0')) + unoptimized_cuda.UNOPTIMIZED_LINEAR(input, weight, temp, out) + + return out + +def conv2d(input, weight, bias, stride, padding): + if len(stride) == 1: + strides_h = stride[0] + strides_w = stride[0] + else: + strides_h = stride[0] + strides_w = stride[1] + out_height = int((input.size(2) - weight.size(2)) / strides_h +1) + out_width = int((input.size(3) - weight.size(3)) / strides_w +1) + out = torch.zeros([input.size(0), weight.size(0), out_height, out_width], dtype=torch.float, device=torch.device('cuda:0')) + + if bias is not None: + unoptimized_cuda.UNOPTIMIZED_CONV(input, weight, bias, out, stride, padding ) + else: + temp = torch.zeros([weight.size(0)], dtype=torch.float, device=torch.device('cuda:0')) + unoptimized_cuda.UNOPTIMIZED_CONV(input, weight, temp, out, stride, padding ) + + return out \ No newline at end of file diff --git a/pytorch/unoptimized/modules/conv.py b/pytorch/unoptimized/modules/conv.py new 
file mode 100644 index 0000000..863e9b2 --- /dev/null +++ b/pytorch/unoptimized/modules/conv.py @@ -0,0 +1,93 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function +from torch.nn.modules.utils import _pair +from torch.nn import init +import unoptimized.kernels +import math +import numpy as np +import time + +class _UnoptimizedConvNd(nn.Module): + + __constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias', 'padding_mode'] + + def __init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, transposed, output_padding, + groups, bias, padding_mode): + super(_UnoptimizedConvNd, self).__init__() + if in_channels % groups != 0: + raise ValueError('in_channels must be divisible by groups') + if out_channels % groups != 0: + raise ValueError('out_channels must be divisible by groups') + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = transposed + self.output_padding = output_padding + self.groups = groups + self.padding_mode = padding_mode + + if transposed: + self.weight = nn.Parameter(torch.Tensor( + in_channels, out_channels // groups, *kernel_size)) + else: + self.weight = nn.Parameter(torch.Tensor( + out_channels, in_channels // groups, *kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self): + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def extra_repr(self): + s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' + ', stride={stride}') + if self.padding != (0,) * len(self.padding): + s += ', padding={padding}' + if self.dilation != (1,) * len(self.dilation): + s += ', dilation={dilation}' + if self.output_padding != (0,) * len(self.output_padding): + s += ', output_padding={output_padding}' + if self.groups != 1: + s += ', groups={groups}' + if self.bias is None: + s += ', bias=False' + return s.format(**self.__dict__) + +class UnoptimizedConv2d(_UnoptimizedConvNd): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, + bias=True, padding_mode='zeros'): + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + super(UnoptimizedConv2d, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + False, _pair(0), groups, bias, padding_mode) + + #@weak_script_method + def forward(self, input): + # start_time = time.time() + if self.padding_mode == 'circular': + print('circular') + if len(self.padding) == 2: + padding = (self.padding[0],self.padding[0],self.padding[1],self.padding[1]) + else: + padding = self.padding + input = F.pad(input = input, pad = padding, mode = 'constant', value = 0) + + return unoptimized.kernels.conv2d(input, self.weight, self.bias, self.stride, padding) diff --git a/pytorch/unoptimized/modules/linear.py b/pytorch/unoptimized/modules/linear.py new file mode 100644 index 0000000..41b8325 --- /dev/null +++ b/pytorch/unoptimized/modules/linear.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function +from 
torch.nn.modules.utils import _pair +from torch.nn import init +import unoptimized.kernels +import math +import numpy as np +import time + + +class UnoptimizedLinear(nn.Module): + def __init__(self, in_features, out_features, bias=True): + + super(UnoptimizedLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_features)) + else: + # You should always register all possible parameters, but the + # optional ones can be None if you want. + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def forward(self, input): + # start_time = time.time() + out = unoptimized.kernels.linear(input, self.weight, self.bias) + # end_time = time.time() + # print("Linear Time:", end_time - start_time ) + return out + + + + def extra_repr(self): + # (Optional) Set the extra information about this module. You can test + # it by printing an object of this class. + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features, self.out_features, self.bias is not None + ) diff --git a/requirements.txt b/requirements.txt index 2d8fd52..5d2ef12 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,31 +1,6 @@ -absl-py==0.7.1 -astor==0.7.1 -certifi==2019.3.9 -chardet==3.0.4 -future==0.17.1 -gast==0.2.2 -googleapis-common-protos==1.5.9 -grpcio==1.19.0 -h5py==2.9.0 -idna==2.8 -Markdown==3.1 -mock==2.0.0 numpy==1.16.2 opencv-python==4.1.0.25 -pbr==5.1.3 Pillow==6.0.0 -pkg-resources==0.0.0 -promise==2.2.1 -protobuf==3.7.1 -PyYAML==5.1 -requests==2.21.0 -scipy==1.2.1 -six==1.12.0 -tensorboard==1.13.1 termcolor==1.1.0 torch==1.1.0 torchvision==0.2.2.post3 -tqdm==4.31.1 -urllib3==1.24.1 -Werkzeug==0.15.2 -wrapt==1.11.1
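
Note on the PS forward path introduced above (deepshift/modules.py together with deepshift/ste.py): when use_kernel is False, the layers materialize their effective weight as a signed power of two, with ste.round and ste.sign behaving as identity in the backward pass. The snippet below is a minimal numeric sketch of that forward computation only; the tensor values are hypothetical and the code is illustrative, not part of the patch.

    # Illustrative sketch: how a DeepShift-PS weight is materialized from its
    # (shift, sign) parameters, mirroring the non-kernel branch of modules.py.
    import torch

    shift = torch.tensor([[-3., -1.], [-2., -4.]])   # log2 magnitudes (hypothetical)
    sign  = torch.tensor([[ 1., -1.], [ 1.,  1.]])   # ternary sign parameter (hypothetical)

    # Value used by LinearShift / Conv2dShift when use_kernel=False:
    weight_ps = (2.0 ** shift.round()) * sign.round().sign()
    print(weight_ps)   # every entry is a signed power of two, e.g. 0.125, -0.5, ...

During training the straight-through estimators in ste.py pass gradients through the round and sign operations unchanged, so shift and sign can still be updated by the optimizer.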
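
On the packed weight format used by the faster kernels: utils.compress_bits stores each weight's sign plus its (negated) shift in bits + 1 bits, so num = 32 // (bits + 1) weights share one 32-bit word and each output channel occupies row_length words. The worked example below reproduces that arithmetic under assumed shift ranges and an assumed filter shape; the concrete numbers are illustrative only.

    # Worked example (assumptions marked): packing arithmetic of utils.compress_bits.
    import math

    max_right_shift = 15                                # assumed: shifts lie in [-15, 0]
    bits = math.ceil(math.log2(max_right_shift + 1))    # -> 4 bits per shift magnitude
    num = int(32 / (bits + 1))                          # +1 sign bit -> 6 weights per 32-bit word

    kernel_elements = 3 * 3 * 64                        # assumed conv filter: kh * kw * in_channels
    row_length = (kernel_elements + num - 1) // num     # words needed per output channel
    print(bits, num, row_length)                        # 4 6 96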
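
A short usage sketch of the refactored conversion API, following how cifar10.py, imagenet.py and mnist.py wire it up after this patch. The shift depth, the 'PS' shift type, and the pretrained ResNet-18 are illustrative assumptions; the custom CUDA/CPU kernels must first be built with install_kernels.sh if use_kernel is enabled.

    # Minimal sketch, assuming the repo is on PYTHONPATH and torchvision is installed.
    import torch.nn as nn
    import torchvision.models as models

    from deepshift.convert import convert_to_shift, round_shift_weights, count_layer_type
    from unoptimized.convert import convert_to_unoptimized

    model = models.resnet18(pretrained=True)

    # Replace the last 5 Conv2d/Linear layers with DeepShift-PS layers (depth and
    # type are illustrative; use_kernel=True requires the installed kernels).
    model, _ = convert_to_shift(model, 5, 'PS', convert_weights=True, use_kernel=False)

    # Alternatively, keep float weights but route through the unoptimized CUDA path:
    # model = convert_to_unoptimized(model)

    print(count_layer_type(model, nn.Conv2d))        # Conv2d layers left unconverted
    model_rounded = round_shift_weights(model, clone=True)   # snapshot with rounded shift/sign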