From b25be5eec18667f6634742a68fb73f4a3c9a7105 Mon Sep 17 00:00:00 2001 From: BetaCat <60107867+timmywanttolearn@users.noreply.github.com> Date: Fri, 10 Sep 2021 15:06:17 +0800 Subject: [PATCH 01/17] Add hook for nn.Transformer --- thop/profile.py | 1 + 1 file changed, 1 insertion(+) diff --git a/thop/profile.py b/thop/profile.py index 3a2eda6..422b34d 100644 --- a/thop/profile.py +++ b/thop/profile.py @@ -67,6 +67,7 @@ def prYellow(skk): print("\033[93m{}\033[00m".format(skk)) nn.RNN: count_rnn, nn.GRU: count_gru, nn.LSTM: count_lstm, + nn.Transformer: count_Transformer } if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"): From 03064d28c7b53a64df29d7c34bda07a1a9ed2331 Mon Sep 17 00:00:00 2001 From: BetaCat <60107867+timmywanttolearn@users.noreply.github.com> Date: Fri, 10 Sep 2021 15:09:06 +0800 Subject: [PATCH 02/17] Add a function to count macs of nn.Transformer count_Transformer for nn.Transformer --- thop/rnn_hooks.py | 83 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/thop/rnn_hooks.py b/thop/rnn_hooks.py index c00fd47..cf591a7 100644 --- a/thop/rnn_hooks.py +++ b/thop/rnn_hooks.py @@ -196,3 +196,86 @@ def count_lstm(m: nn.LSTM, x, y): total_ops *= batch_size m.total_ops += torch.DoubleTensor([int(total_ops)]) + +def count_Transformer(m: nn.Transformer, x, y): + total_ops = 0 + src, tgt = x + if m.batch_first: + num_steps = src.shape[0] + target = tgt.shape[1] + sequence = src.shape[1] + embedding = src.shape[2] + else: + target = tgt.shape[0] + sequence = src.shape[0] + num_steps = src.shape[1] + embedding = src.shape[2] + num_head = m.nhead + encoder_layers = m.encoder.num_layers + decoder_layers = m.decoder.num_layers + # dim_forward(default = 2048) + forward = m.encoder.layers[0].linear1.out_features + total_ops = 0 + + def MultiheadAttention(bool1, num_head, num_steps, target, sequence, embedding): + if bool1 == 0: + # linear_q,linear_k,linear_v all N,S,E + total_multi = 3 * sequence * embedding ** 2 + # self_attn softmax(Q*K_T/sqrt(dk))*V + total_multi += sequence ** 4 * embedding ** 2 + \ + sequence ** 2 + sequence * (3 * sequence - 1) + 1 + elif bool1 == 1: + # linear_q,linear_k,linear_v + total_multi = 3 * target * embedding ** 2 + # self_attn softmax(Q*K_T/sqrt(dk))*V + total_multi += target ** 4 * embedding ** 2 + \ + target ** 2 + target * (3 * target-1) + 1 + elif bool1 == 2: + # linear_q,linear_k,linear_v + total_multi = embedding ** 2 * (2 * sequence + target) + # self_attn softmax(Q*K_T/sqrt(dk))*V + total_multi += target * (sequence ** 2) * embedding ** 3 + \ + target * sequence + target * (3 * sequence - 1)+1 + # number of heads and batchsize + total_multi *= total_multi * num_head*num_steps + # print(total_multi) + # concat + if bool1 == 0: + total_multi += num_steps * (sequence ** 2 * num_head * embedding) + # print(total_multi) + else: + total_multi += num_steps * (target ** 2 * num_head * embedding) + # output-> (N,S,E) or (S,N,E) + return total_multi + + def TransformerEncoderLayer(num_head, num_steps, target, sequence, embedding): + total_en = 0 + total_en += MultiheadAttention(0, num_head, + num_steps, target, sequence, embedding) + # linear1 in_features= embedding, outfeatures= dim_forward + total_en += num_steps * sequence * forward * (embedding ** 2) + # linear2 + total_en += num_steps * sequence * embedding * (forward ** 2) + # norm1 norm2 + total_en += 2 * 2 * num_steps * embedding * sequence + # droup out 2,3 + return total_en + + def TransformerDecoderLayer(num_head, num_steps, target, sequence, 
embedding):
+        total_de = 0
+        total_de += MultiheadAttention(1, num_head,
+                                       num_steps, target, sequence, embedding)
+        total_de += MultiheadAttention(2, num_head,
+                                       num_steps, target, sequence, embedding)
+        # linear1 linear2 fft
+        total_de += num_steps * target * forward * (embedding ** 2)
+        total_de += num_steps * target * embedding * (forward ** 2)
+        # 3* norm
+        total_de += 3 * 2 * num_steps * embedding * target
+        return total_de
+    total_ops = encoder_layers * TransformerEncoderLayer(num_head, num_steps, target, sequence, embedding) + \
+        decoder_layers * \
+        TransformerDecoderLayer(num_head, num_steps,
+                                target, sequence, embedding)
+    m.total_ops += torch.DoubleTensor([int(total_ops)])
+

From fd9f2ecea157f317312b37be8a35a3e28e59a631 Mon Sep 17 00:00:00 2001
From: BetaCat <60107867+timmywanttolearn@users.noreply.github.com>
Date: Fri, 10 Sep 2021 16:05:31 +0800
Subject: [PATCH 03/17] Add an example to test nn.Transformer

The profiled result is 2.76945964418343e+25 MACs and 44151808.0 parameters.

---
 benchmark/evaluate_transformer.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)
 create mode 100644 benchmark/evaluate_transformer.py

diff --git a/benchmark/evaluate_transformer.py b/benchmark/evaluate_transformer.py
new file mode 100644
index 0000000..f8c01a4
--- /dev/null
+++ b/benchmark/evaluate_transformer.py
@@ -0,0 +1,23 @@
+
+import torch.nn as nn
+from thop import profile
+import torch
+layers = []
+layers.append(nn.Transformer())
+
+src=torch.rand((10, 32, 10))
+class model1(nn.Module):
+    def __init__(self):
+        super(model1, self).__init__()
+        self.linear1 = nn.Linear(10,512)
+        self.linear2 = nn.Linear(10,512)
+        self.transform = nn.Transformer()
+    def forward(self,input):
+        input1 = self.linear1(input)
+        input2 = self.linear2(input)
+        output = self.transform(input1,input2)
+        return output
+model2 = nn.Sequential(model1())
+macs, params = profile(model2, inputs=(src, ))
+print(macs,params)
+
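Note on reading the numbers in the commit message above: profile() returns raw floats, and thop already ships a clever_format helper (it is imported in thop/__init__.py alongside profile). A minimal sketch, reusing the macs and params variables from the script above:

    from thop import clever_format

    macs, params = clever_format([macs, params], "%.3f")
    print(macs, params)  # params renders as "44.152M"

The 2.76945964418343e+25 MAC count is implausibly large for this model; it comes from the accidental self-multiplication total_multi *= total_multi * num_head*num_steps in patch 02, which patch 07 below corrects.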
From d33dbe24764d1504b5a84d8a3a35010cac7a0be6 Mon Sep 17 00:00:00 2001
From: BetaCat <60107867+timmywanttolearn@users.noreply.github.com>
Date: Fri, 10 Sep 2021 16:43:11 +0800
Subject: [PATCH 04/17] Add some hooks and edit the softmax hook

---
 thop/vision/basic_hooks.py | 39 +++++++++++++++++++++++++++++++++-----
 1 file changed, 34 insertions(+), 5 deletions(-)

diff --git a/thop/vision/basic_hooks.py b/thop/vision/basic_hooks.py
index e3d7d7d..6fad224 100644
--- a/thop/vision/basic_hooks.py
+++ b/thop/vision/basic_hooks.py
@@ -55,6 +55,37 @@ def count_bn(m, x, y):
 
     m.total_ops += torch.DoubleTensor([int(total_ops)])
 
+
+def count_ln(m, x, y):
+    x = x[0]
+
+    nelements = x.numel()
+    if not m.training:
+        # same as count_bn
+        total_ops = 2 * nelements
+
+    m.total_ops += torch.DoubleTensor([int(total_ops)])
+
+
+def count_in(m, x, y):
+    x = x[0]
+
+    nelements = x.numel()
+    if not m.training:
+        # same as count_bn
+        total_ops = 2 * nelements
+
+    m.total_ops += torch.DoubleTensor([int(total_ops)])
+
+
+def count_prelu(m, x, y):
+    x = x[0]
+
+    nelements = x.numel()
+    if not m.training:
+        total_ops = nelements
+
+    m.total_ops += torch.DoubleTensor([int(total_ops)])
 
 def count_relu(m, x, y):
     x = x[0]
@@ -63,18 +94,16 @@ def count_relu(m, x, y):
 
     m.total_ops += torch.DoubleTensor([int(nelements)])
 
-
+    
 def count_softmax(m, x, y):
     x = x[0]
 
     batch_size, nfeatures = x.size()
+    nfeatures = x.size()[m.dim]
+    batch_size = x.numel()//nfeatures
 
     total_exp = nfeatures
     total_add = nfeatures - 1
-    total_div = nfeatures
-    total_ops = batch_size * (total_exp + total_add + total_div)
-
-    m.total_ops += torch.DoubleTensor([int(total_ops)])
 
 def count_avgpool(m, x, y):

From d632a80b23bc9e112647724ad181cafddbeb80cb Mon Sep 17 00:00:00 2001
From: BetaCat <60107867+timmywanttolearn@users.noreply.github.com>
Date: Fri, 10 Sep 2021 16:47:55 +0800
Subject: [PATCH 05/17] A lot of changes

---
 thop/profile.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/thop/profile.py b/thop/profile.py
index 422b34d..9a47594 100644
--- a/thop/profile.py
+++ b/thop/profile.py
@@ -35,6 +35,12 @@ def prYellow(skk): print("\033[93m{}\033[00m".format(skk))
     nn.BatchNorm1d: count_bn,
     nn.BatchNorm2d: count_bn,
     nn.BatchNorm3d: count_bn,
+    nn.LayerNorm: count_ln,
+    nn.InstanceNorm1d: count_in,
+    nn.InstanceNorm2d: count_in,
+    nn.InstanceNorm3d: count_in,
+    nn.PReLU: count_prelu,
+    nn.Softmax: count_softmax,
 
     nn.ReLU: zero_ops,
     nn.ReLU6: zero_ops,

From e1633e6168994ebd33e96a56c336caa30c5eb534 Mon Sep 17 00:00:00 2001
From: BetaCat <60107867+timmywanttolearn@users.noreply.github.com>
Date: Fri, 10 Sep 2021 16:54:36 +0800
Subject: [PATCH 06/17] Edit the softmax hook

---
 thop/vision/basic_hooks.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/thop/vision/basic_hooks.py b/thop/vision/basic_hooks.py
index 6fad224..933c341 100644
--- a/thop/vision/basic_hooks.py
+++ b/thop/vision/basic_hooks.py
@@ -104,6 +104,9 @@ def count_softmax(m, x, y):
 
     total_exp = nfeatures
     total_add = nfeatures - 1
+    total_div = nfeatures
+    total_ops = batch_size * (total_exp + total_add + total_div)
+    m.total_ops += torch.DoubleTensor([int(total_ops)])
 
 def count_avgpool(m, x, y):
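With patches 04 through 06 applied, the softmax hook counts, for a softmax over n features, n exponentials, n - 1 additions for the normalizing sum, and n divisions, i.e. batch_size * (3n - 1) ops in total. A quick check of that arithmetic, assuming only that the hooks above are registered as in patch 05:

    import torch
    import torch.nn as nn
    from thop import profile

    m = nn.Softmax(dim=1)
    x = torch.randn(2, 10)            # batch_size = 2, nfeatures = 10
    macs, params = profile(m, inputs=(x,))
    print(macs)                       # 2 * (10 + 9 + 10) = 58.0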
From f1a7805166c798fd2b8aac31f683ae6e63bdb2f4 Mon Sep 17 00:00:00 2001
From: BetaCat <60107867+timmywanttolearn@users.noreply.github.com>
Date: Sat, 11 Sep 2021 00:54:58 +0800
Subject: [PATCH 07/17] Update nn.Transformer

---
 thop/rnn_hooks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/thop/rnn_hooks.py b/thop/rnn_hooks.py
index cf591a7..b5f28f4 100644
--- a/thop/rnn_hooks.py
+++ b/thop/rnn_hooks.py
@@ -237,7 +237,7 @@ def MultiheadAttention(bool1, num_head, num_steps, target, sequence, embedding):
             total_multi += target * (sequence ** 2) * embedding ** 3 + \
                 target * sequence + target * (3 * sequence - 1)+1
         # number of heads and batchsize
-        total_multi *= total_multi * num_head*num_steps
+        total_multi *= num_head*num_steps
         # print(total_multi)
         # concat

From 9194cd215a633b4fd8ee24232e1a41f2c4bb1406 Mon Sep 17 00:00:00 2001
From: BetaCat <60107867+timmywanttolearn@users.noreply.github.com>
Date: Sat, 18 Sep 2021 10:29:02 +0800
Subject: [PATCH 08/17] Fix bugs in nn.Transformer

No embedding ** 2.

---
 thop/rnn_hooks.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/thop/rnn_hooks.py b/thop/rnn_hooks.py
index b5f28f4..ee08fa8 100644
--- a/thop/rnn_hooks.py
+++ b/thop/rnn_hooks.py
@@ -220,21 +220,21 @@ def count_Transformer(m: nn.Transformer, x, y):
     def MultiheadAttention(bool1, num_head, num_steps, target, sequence, embedding):
         if bool1 == 0:
             # linear_q,linear_k,linear_v all N,S,E
-            total_multi = 3 * sequence * embedding ** 2
+            total_multi = 3 * sequence ** 2 * embedding
             # self_attn softmax(Q*K_T/sqrt(dk))*V
             total_multi += sequence ** 4 * embedding ** 2 + \
                 sequence ** 2 + sequence * (3 * sequence - 1) + 1
         elif bool1 == 1:
             # linear_q,linear_k,linear_v
-            total_multi = 3 * target * embedding ** 2
+            total_multi = 3 * target ** 2 * embedding
             # self_attn softmax(Q*K_T/sqrt(dk))*V
             total_multi += target ** 4 * embedding ** 2 + \
                 target ** 2 + target * (3 * target-1) + 1
         elif bool1 == 2:
             # linear_q,linear_k,linear_v
-            total_multi = embedding ** 2 * (2 * sequence + target)
+            total_multi = embedding * (2 * sequence ** 2 + target ** 2)
             # self_attn softmax(Q*K_T/sqrt(dk))*V
-            total_multi += target * (sequence ** 2) * embedding ** 3 + \
+            total_multi += target ** 2 * sequence ** 2 * embedding ** 2 + \
                 target * sequence + target * (3 * sequence - 1)+1
         # number of heads and batchsize
         total_multi *= num_head*num_steps
@@ -253,9 +253,9 @@ def TransformerEncoderLayer(num_head, num_steps, target, sequence, embedding):
         total_en += MultiheadAttention(0, num_head,
                                        num_steps, target, sequence, embedding)
         # linear1 in_features= embedding, outfeatures= dim_forward
-        total_en += num_steps * sequence * forward * (embedding ** 2)
+        total_en += num_steps * sequence ** 2 * forward * embedding
         # linear2
-        total_en += num_steps * sequence * embedding * (forward ** 2)
+        total_en += num_steps * sequence * embedding * forward ** 2
         # norm1 norm2
         total_en += 2 * 2 * num_steps * embedding * sequence
         # droup out 2,3
         return total_en
@@ -268,7 +268,7 @@ def TransformerDecoderLayer(num_head, num_steps, target, sequence, embedding):
         total_de += MultiheadAttention(2, num_head,
                                        num_steps, target, sequence, embedding)
         # linear1 linear2 fft
-        total_de += num_steps * target * forward * (embedding ** 2)
+        total_de += num_steps * target ** 2 * forward * embedding
         total_de += num_steps * target * embedding * (forward ** 2)
         # 3* norm
         total_de += 3 * 2 * num_steps * embedding * target
         return total_de
@@ -278,4 +278,3 @@ def TransformerDecoderLayer(num_head, num_steps, target, sequence, embedding):
         TransformerDecoderLayer(num_head, num_steps,
                                 target, sequence, embedding)
     m.total_ops += torch.DoubleTensor([int(total_ops)])
-

From f6fa0b03719fb5364b0ca1e947c80042483d7188 Mon Sep 17 00:00:00 2001
From: BetaCat <60107867+timmywanttolearn@users.noreply.github.com>
Date: Mon, 20 Sep 2021 11:37:03 +0800
Subject: [PATCH 09/17] Change the example

For one word, about 25 GFLOPs and 44M parameters.

---
 benchmark/evaluate_transformer.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/benchmark/evaluate_transformer.py b/benchmark/evaluate_transformer.py
index f8c01a4..f0df780 100644
--- a/benchmark/evaluate_transformer.py
+++ b/benchmark/evaluate_transformer.py
@@ -2,22 +2,20 @@
 import torch.nn as nn
 from thop import profile
 import torch
-layers = []
-layers.append(nn.Transformer())
 
-src=torch.rand((10, 32, 10))
+src=torch.rand((1, 1, 10))# S,N,x
-class model1(nn.Module):
+class Model_transformer(nn.Module):
     def __init__(self):
-        super(model1, self).__init__()
+        super(Model_transformer, self).__init__()
         self.linear1 = nn.Linear(10,512)
         self.linear2 = nn.Linear(10,512)
-        self.transform = nn.Transformer()
+        self.transform = nn.Transformer(d_model = 512 ,nhead = 8, num_encoder_layers = 6)
     def forward(self,input):
         input1 = self.linear1(input)
         input2 = self.linear2(input)
         output = self.transform(input1,input2)
         return output
-model2 = nn.Sequential(model1())
-macs, params = profile(model2, inputs=(src, ))
+model = Model_transformer()
+macs, params = profile(model, inputs=(src, ))
 print(macs,params)

From dd9f41d1eeef9772031746e2bd9877a04868fb51 Mon Sep 17 00:00:00 2001
From: BetaCat <60107867+timmywanttolearn@users.noreply.github.com>
Date: Mon, 27 Sep 2021 14:02:25 +0800
Subject: [PATCH 10/17] Changes to count_Transformer

---
 thop/rnn_hooks.py | 58 +++++++++++++++++++++++------------------------
 1 file changed, 29 insertions(+), 29 deletions(-)

diff --git a/thop/rnn_hooks.py b/thop/rnn_hooks.py
index ee08fa8..0a9cb37 100644
--- a/thop/rnn_hooks.py
+++ b/thop/rnn_hooks.py
@@ -220,45 +220,45 @@ def count_Transformer(m:
nn.Transformer, x, y): def MultiheadAttention(bool1, num_head, num_steps, target, sequence, embedding): if bool1 == 0: # linear_q,linear_k,linear_v all N,S,E - total_multi = 3 * sequence ** 2 * embedding + total_multi = 3 * sequence * embedding ** 2 # self_attn softmax(Q*K_T/sqrt(dk))*V - total_multi += sequence ** 4 * embedding ** 2 + \ - sequence ** 2 + sequence * (3 * sequence - 1) + 1 + total_multi += (sequence ** 4 * (embedding/num_head) ** 2 + \ + sequence ** 2 + sequence * (3 * sequence - 1) + 1) * num_head + #linear + total_multi += sequence * embedding ** 2 + #layernorm + total_multi += 2 * sequence * embedding elif bool1 == 1: # linear_q,linear_k,linear_v - total_multi = 3 * target ** 2 * embedding + total_multi = 3 * target * embedding ** 2 # self_attn softmax(Q*K_T/sqrt(dk))*V - total_multi += target ** 4 * embedding ** 2 + \ - target ** 2 + target * (3 * target-1) + 1 + total_multi += (target ** 4 * (embedding/num_head) ** 2 + \ + target ** 2 + target * (3 * target-1) + 1) * num_head + total_multi += target * embedding ** 2 + total_multi += 2 * target * embedding elif bool1 == 2: # linear_q,linear_k,linear_v - total_multi = embedding * (2 * sequence ** 2 + target ** 2) + total_multi = embedding ** 2 * (2 * sequence + target) # self_attn softmax(Q*K_T/sqrt(dk))*V - total_multi += target ** 2 * sequence ** 2 * embedding ** 2 + \ - target * sequence + target * (3 * sequence - 1)+1 + total_multi += (target ** 2 * sequence ** 2 * (embedding/num_head) ** 2 + \ + target * sequence + target * (3 * sequence - 1)+1) * num_head + total_multi += target * embedding ** 2 + total_multi += 2 * target * embedding # number of heads and batchsize - total_multi *= num_head*num_steps - # print(total_multi) - # concat - if bool1 == 0: - total_multi += num_steps * (sequence ** 2 * num_head * embedding) - # print(total_multi) - else: - total_multi += num_steps * (target ** 2 * num_head * embedding) - # output-> (N,S,E) or (S,N,E) + total_multi *= num_steps return total_multi def TransformerEncoderLayer(num_head, num_steps, target, sequence, embedding): total_en = 0 total_en += MultiheadAttention(0, num_head, num_steps, target, sequence, embedding) - # linear1 in_features= embedding, outfeatures= dim_forward - total_en += num_steps * sequence ** 2 * forward * embedding - # linear2 - total_en += num_steps * sequence * embedding * forward ** 2 - # norm1 norm2 - total_en += 2 * 2 * num_steps * embedding * sequence - # droup out 2,3 + print("multi",total_en) + # fed_forward(2 conv1d) + total_en += num_steps * sequence * forward * embedding + total_en += num_steps * sequence * embedding * forward + # norm1 + total_en += 2 * num_steps * embedding * sequence + print(total_en) return total_en def TransformerDecoderLayer(num_head, num_steps, target, sequence, embedding): @@ -268,10 +268,10 @@ def TransformerDecoderLayer(num_head, num_steps, target, sequence, embedding): total_de += MultiheadAttention(2, num_head, num_steps, target, sequence, embedding) # linear1 linear2 fft - total_de += num_steps * target ** 2 * forward * embedding - total_de += num_steps * target * embedding * (forward ** 2) - # 3* norm - total_de += 3 * 2 * num_steps * embedding * target + total_de += num_steps * target * forward * embedding + total_de += num_steps * target * embedding * forward + # layernorm + total_de += 2 * num_steps * embedding * target return total_de total_ops = encoder_layers * TransformerEncoderLayer(num_head, num_steps, target, sequence, embedding) + \ decoder_layers * \ From 
9d1d0d3b721c762a94596ef7417d07a790df3bd7 Mon Sep 17 00:00:00 2001
From: BetaCat <60107867+timmywanttolearn@users.noreply.github.com>
Date: Mon, 27 Sep 2021 14:44:49 +0800
Subject: [PATCH 11/17] Delete print information

---
 thop/rnn_hooks.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/thop/rnn_hooks.py b/thop/rnn_hooks.py
index 0a9cb37..2126c77 100644
--- a/thop/rnn_hooks.py
+++ b/thop/rnn_hooks.py
@@ -252,13 +252,11 @@ def TransformerEncoderLayer(num_head, num_steps, target, sequence, embedding):
         total_en = 0
         total_en += MultiheadAttention(0, num_head,
                                        num_steps, target, sequence, embedding)
-        print("multi",total_en)
         # fed_forward(2 conv1d)
         total_en += num_steps * sequence * forward * embedding
         total_en += num_steps * sequence * embedding * forward
         # norm1
         total_en += 2 * num_steps * embedding * sequence
-        print(total_en)
         return total_en
 
     def TransformerDecoderLayer(num_head, num_steps, target, sequence, embedding):
@@ -278,3 +276,5 @@ def TransformerDecoderLayer(num_head, num_steps, target, sequence, embedding):
         TransformerDecoderLayer(num_head, num_steps,
                                 target, sequence, embedding)
     m.total_ops += torch.DoubleTensor([int(total_ops)])
+
+
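The next patch profiles ONNX graphs directly; the conv.onnx and linear.onnx files it adds are exported models. The patch does not show how they were produced, so the Conv2d below is only an assumption for illustration, but torch.onnx.export is the standard way to create such a file:

    import torch
    import torch.nn as nn

    model = nn.Conv2d(3, 16, kernel_size=3, padding=1)  # hypothetical model
    dummy = torch.randn(1, 3, 32, 32)
    torch.onnx.export(model, dummy, "conv.onnx")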
From 2034b862f3120fab40109d76341363ccfb0ee012 Mon Sep 17 00:00:00 2001
From: Cat Beta
Date: Tue, 5 Oct 2021 15:07:44 +0800
Subject: [PATCH 12/17] onnx fixed

---
 conv.onnx                   | Bin 0 -> 4926 bytes
 test.py                     |  13 +++---
 test1.py                    |   8 ++++
 thop/__init__.py            |   2 +-
 thop/linear.onnx            | Bin 0 -> 4625 bytes
 thop/onnx_profile.py        |  77 +++++++++++++++++++++++++++++++
 thop/vision/basic_hooks.py  |  88 ++++++++++--------------------------
 thop/vision/counter.py      |  78 ++++++++++++++++++++++++++++++++
 thop/vision/onnx_counter.py |  81 +++++++++++++++++++++++++++++++++
 9 files changed, 274 insertions(+), 73 deletions(-)
 create mode 100644 conv.onnx
 create mode 100644 test1.py
 create mode 100644 thop/linear.onnx
 create mode 100644 thop/onnx_profile.py
 create mode 100644 thop/vision/counter.py
 create mode 100644 thop/vision/onnx_counter.py

diff --git a/conv.onnx b/conv.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..533b6fbf55764fff2bf2fa2627e6c9783558f5ed
GIT binary patch
literal 4926
[binary data omitted]

literal 0
HcmV?d00001

diff --git a/test.py b/test.py
index 5e72bde..5f67409 100644
--- a/test.py
+++ b/test.py
@@ -1,8 +1,7 @@
+from torchvision.models import resnet50
+from thop import profile
 import torch
-import thop
-
-m = torch.nn.Conv2d(128, 128, 1)
-x = torch.randn(1, 128, 16, 16)
-
-flops = thop.profile(m, inputs=(x,), verbose=True)
-print(flops)
+model = resnet50()
+input = torch.randn(1, 3, 224, 224)
+macs, params = profile(model, inputs=(input, ))
+print("result",macs,params)
\ No newline at end of file
diff --git a/test1.py b/test1.py
new file mode 100644
index 0000000..39bf19c
--- /dev/null
+++ b/test1.py
@@ -0,0 +1,8 @@
+import onnx
+from thop import onnx_profile
+from onnx import numpy_helper
+import numpy as np
+model = onnx.load("conv.onnx")
+#print(onnx.helper.printable_graph(model.graph))
+onnx_profile = onnx_profile()
+print(onnx_profile.calculate_macs(model))
diff --git a/thop/__init__.py b/thop/__init__.py
index 3362b4c..89d7a62 100644
--- a/thop/__init__.py
+++ b/thop/__init__.py
@@ -1,5 +1,5 @@
 from .utils import clever_format
 from .profile import profile, profile_origin
-
+from .onnx_profile import onnx_profile
 import torch
 default_dtype = torch.float64
\ No newline at end of file
diff --git a/thop/linear.onnx b/thop/linear.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..d0a0e05b65f3f81f4cc961d84dc5346886212a9b
GIT binary patch
literal 4625
[binary data omitted]

literal 0
HcmV?d00001
diff --git a/thop/onnx_profile.py b/thop/onnx_profile.py
new file mode 100644
index 0000000..2fe5cfe
--- /dev/null
+++ b/thop/onnx_profile.py
@@ -0,0 +1,77 @@
+import torch
+import torch.nn
+import onnx
+from onnx import numpy_helper
+import numpy as np
+from thop.vision.counter import *
+from thop.vision.onnx_counter import *
+
+class onnx_profile():
+    def __init__(self) -> None:
+        pass
+    def calculate_params(self,model: onnx.ModelProto):
+        onnx_weights = model.graph.initializer
+        params = 0
+
+        for onnx_w in onnx_weights:
+            try:
+                weight = numpy_helper.to_array(onnx_w)
+                params += np.prod(weight.shape)
+            except Exception as _:
+                pass
+
+        return params
+
+    def create_dic(self,weight, input , output):
+        diction = {}
+        for w in weight:
+            dim = np.array(w.dims)
+            diction[str(w.name)] = dim
+            if (dim.size == 1):
+                diction[str(w.name)] = np.append(1, dim)
+        for i in input:
+            # print(i.type.tensor_type.shape.dim[0].dim_value)
+            dim = np.array(i.type.tensor_type.shape.dim[0].dim_value)
+            # print(i.type.tensor_type.shape.dim.__sizeof__())
+            #name2dims[str(i.name)] = [dim]
+            dim = []
+            for key in i.type.tensor_type.shape.dim:
+                dim = np.append(dim, int(key.dim_value))
+                # print(key.dim_value)
+            # print(dim)
+            diction[str(i.name)] = dim
+            if(dim.size == 1):
+                diction[str(i.name)] = np.append(1, dim)
+        for o in output:
+            dim = np.array(o.type.tensor_type.shape.dim[0].dim_value)
+            diction[str(o.name)] = [dim]
+            if(dim.size == 1):
+                diction[str(o.name)] = np.append(1, dim)
+        return diction
+    def nodes_counter(self, diction, node):
+        if node.op_type not in onnx_operators:
+            print("Sorry, we haven't add ",node.op_type,"into dictionary.")
+            return
+        else:
+            fn = onnx_operators[node.op_type]
+            return fn(diction,node)
+
+
+
+
+    def calculate_macs(self,model: onnx.ModelProto) -> torch.DoubleTensor:
+        macs = 0
+        name2dims = {}
+        weight = model.graph.initializer
+        nodes = model.graph.node
+        input = model.graph.input
+        output = model.graph.output
+        name2dims = self.create_dic(weight,input,output)
+        macs = 0
+        for n in nodes:
+            macs_adding, out_size,outname = self.nodes_counter(name2dims, n)
+            name2dims[outname] = out_size
+            macs += macs_adding
+        return np.array(macs[0])
+
+
diff --git a/thop/vision/basic_hooks.py b/thop/vision/basic_hooks.py
index 933c341..0d0b669 100644
--- a/thop/vision/basic_hooks.py
+++ b/thop/vision/basic_hooks.py
@@ -1,6 +1,6 @@
 import argparse
 import logging
-
+from .counter import *
 import torch
 import torch.nn as nn
 from torch.nn.modules.conv import _ConvNd
@@ -12,11 +12,11 @@ def count_parameters(m, x, y):
     total_params = 0
     for p in m.parameters():
         total_params += torch.DoubleTensor([p.numel()])
-    m.total_params[0] = total_params
+    m.total_params[0] = counter_parameters(m.parameters())
 
 
 def zero_ops(m, x, y):
-    m.total_ops += torch.DoubleTensor([int(0)])
+    m.total_ops += counter_zero_ops()
 
 
 def count_convNd(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
@@ -36,46 +36,31 @@ def count_convNd_ver2(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
     # N x H x W (exclude Cout)
     output_size = torch.zeros((y.size()[:1] + y.size()[2:])).numel()
-    # Cout x Cin x Kw x Kh
-    kernel_ops = m.weight.nelement()
-    if m.bias is not None:
-        # Cout x 1
-        kernel_ops += + m.bias.nelement()
-    # x N x H x W x Cout x (Cin x Kw x Kh + bias)
-    m.total_ops += torch.DoubleTensor([int(output_size * kernel_ops)])
-
+    # # Cout x Cin x Kw x Kh
+    # kernel_ops = m.weight.nelement()
+    # if m.bias is not None:
+    #     # Cout x 1
+    #     kernel_ops += +
m.bias.nelement() + # # x N x H x W x Cout x (Cin x Kw x Kh + bias) + # m.total_ops += torch.DoubleTensor([int(output_size * kernel_ops)]) + m.total_ops += counter_conv(m.bias.nelement(), m.weight.nelement(), output_size) def count_bn(m, x, y): x = x[0] - - nelements = x.numel() if not m.training: - # subtract, divide, gamma, beta - total_ops = 2 * nelements - - m.total_ops += torch.DoubleTensor([int(total_ops)]) + m.total_ops += counter_norm(x.numel()) def count_ln(m, x, y): x = x[0] - - nelements = x.numel() if not m.training: - # same as count_bn - total_ops = 2 * nelements - - m.total_ops += torch.DoubleTensor([int(total_ops)]) + m.total_ops += counter_norm(x.numel()) def count_in(m, x, y): x = x[0] - - nelements = x.numel() if not m.training: - # same as count_bn - total_ops = 2 * nelements - - m.total_ops += torch.DoubleTensor([int(total_ops)]) + m.total_ops += counter_norm(x.numel()) def count_prelu(m, x, y): @@ -83,16 +68,15 @@ def count_prelu(m, x, y): nelements = x.numel() if not m.training: - total_ops = nelements + m.total_ops += counter_relu(nelements) - m.total_ops += torch.DoubleTensor([int(total_ops)]) def count_relu(m, x, y): x = x[0] nelements = x.numel() - m.total_ops += torch.DoubleTensor([int(nelements)]) + m.total_ops += counter_relu(nelements) def count_softmax(m, x, y): @@ -102,62 +86,37 @@ def count_softmax(m, x, y): nfeatures = x.size()[m.dim] batch_size = x.numel()//nfeatures - total_exp = nfeatures - total_add = nfeatures - 1 - total_div = nfeatures - total_ops = batch_size * (total_exp + total_add + total_div) - m.total_ops += torch.DoubleTensor([int(total_ops)]) + m.total_ops += counter_softmax(batch_size, nfeatures) def count_avgpool(m, x, y): # total_add = torch.prod(torch.Tensor([m.kernel_size])) # total_div = 1 # kernel_ops = total_add + total_div - kernel_ops = 1 num_elements = y.numel() - total_ops = kernel_ops * num_elements - m.total_ops += torch.DoubleTensor([int(total_ops)]) + m.total_ops += counter_avgpool(num_elements) def count_adap_avgpool(m, x, y): kernel = torch.DoubleTensor([*(x[0].shape[2:])]) // torch.DoubleTensor([*(y.shape[2:])]) total_add = torch.prod(kernel) - total_div = 1 - kernel_ops = total_add + total_div num_elements = y.numel() - total_ops = kernel_ops * num_elements - m.total_ops += torch.DoubleTensor([int(total_ops)]) + m.total_ops += counter_adap_avg(total_add, num_elements) # TODO: verify the accuracy def count_upsample(m, x, y): if m.mode not in ("nearest", "linear", "bilinear", "bicubic",): # "trilinear" logging.warning("mode %s is not implemented yet, take it a zero op" % m.mode) - return zero_ops(m, x, y) + return counter_zero_ops() if m.mode == "nearest": - return zero_ops(m, x, y) + return counter_zero_ops() x = x[0] - if m.mode == "linear": - total_ops = y.nelement() * 5 # 2 muls + 3 add - elif m.mode == "bilinear": - # https://en.wikipedia.org/wiki/Bilinear_interpolation - total_ops = y.nelement() * 11 # 6 muls + 5 adds - elif m.mode == "bicubic": - # https://en.wikipedia.org/wiki/Bicubic_interpolation - # Product matrix [4x4] x [4x4] x [4x4] - ops_solve_A = 224 # 128 muls + 96 adds - ops_solve_p = 35 # 16 muls + 12 adds + 4 muls + 3 adds - total_ops = y.nelement() * (ops_solve_A + ops_solve_p) - elif m.mode == "trilinear": - # https://en.wikipedia.org/wiki/Trilinear_interpolation - # can viewed as 2 bilinear + 1 linear - total_ops = y.nelement() * (13 * 2 + 5) - - m.total_ops += torch.DoubleTensor([int(total_ops)]) + m.total_ops += counter_upsample(m.mode, y.nelement()) # nn.Linear @@ -167,6 +126,5 @@ def 
count_linear(m, x, y): # total_add = m.in_features - 1 # total_add += 1 if m.bias is not None else 0 num_elements = y.numel() - total_ops = total_mul * num_elements - m.total_ops += torch.DoubleTensor([int(total_ops)]) + m.total_ops += counter_linear(total_mul, num_elements) diff --git a/thop/vision/counter.py b/thop/vision/counter.py new file mode 100644 index 0000000..c985bfd --- /dev/null +++ b/thop/vision/counter.py @@ -0,0 +1,78 @@ +import torch +import numpy as np + + +def counter_parameters(para_list): + total_params = 0 + for p in para_list: + total_params += torch.DoubleTensor([p.nelement()]) + return total_params + +def counter_zero_ops(): + return torch.DoubleTensor([int(0)]) + +def counter_conv(bias, kernel_size, output_size): + """inputs are all numbers!""" + kernel_ops = 0 + kernel_ops = kernel_size + if bias is not None: + kernel_ops += bias + return torch.DoubleTensor([int(output_size * kernel_ops)]) + +def counter_norm(input_size): + """input is a number not a array or tensor""" + return torch.DoubleTensor([2 * input_size]) + +def counter_relu(input_size: torch.Tensor): + return torch.DoubleTensor([int(input_size)]) + +def counter_softmax(batch_size, nfeatures): + total_exp = nfeatures + total_add = nfeatures - 1 + total_div = nfeatures + total_ops = batch_size * (total_exp + total_add + total_div) + return torch.DoubleTensor([int(total_ops)]) + +def counter_avgpool(input_size): + return torch.DoubleTensor([int(input_size)]) + +def counter_adap_avg(kernel_size, output_size): + total_div = 1 + kernel_op = kernel_size + total_div + return torch.DoubleTensor([int(kernel_op * output_size)]) + +def counter_upsample(mode: str, output_size): + total_ops = output_size + if mode == "linear": + total_ops *= 5 + elif mode == "bilinear": + total_ops *= 11 + elif mode == "bicubic": + ops_solve_A = 224 # 128 muls + 96 adds + ops_solve_p = 35 # 16 muls + 12 adds + 4 muls + 3 adds + total_ops *= (ops_solve_A + ops_solve_p) + elif mode == "trilinear": + total_ops *= (13 * 2 + 5) + return torch.DoubleTensor([int(total_ops)]) + +def counter_linear(in_feature, num_elements): + return torch.DoubleTensor([int(in_feature * num_elements)]) +def counter_onnx_MatMul(diction,node): + input1 = node.input[0] + input2 = node.input[0] + input1_dim = diction[input1] + input2_dim = diction[input2] + if (input1_dim.size >= input2_dim.size): + out_size = np.append(input1_dim[0:-1], input2_dim[-1]) + else: + out_size = np.append(input2_dim[0:-1], input1_dim[-1]) + input1_dim = np.array(input1_dim) + input2_dim = np.array(input2_dim) + macs = np.prod(input1_dim)/input1_dim[-1]*np.prod(input2_dim) + output_name = diction[node.output[0]] + return macs, out_size, output_name +#def count_onnx_ +def counter_MatMul(input_size, output_size): + input_size = np.array(input_size) + output_size = np.array(output_size) + return np.prod(np.append(input_size[0:-1],output_size[-1])) diff --git a/thop/vision/onnx_counter.py b/thop/vision/onnx_counter.py new file mode 100644 index 0000000..fe9c846 --- /dev/null +++ b/thop/vision/onnx_counter.py @@ -0,0 +1,81 @@ +import torch +import numpy as np + +from thop.vision.basic_hooks import zero_ops +from .counter import * + + +def onnx_counter_MatMul(diction, node): + input1 = node.input[0] + input2 = node.input[1] + input1_dim = diction[input1] + input2_dim = diction[input2] + out_size = np.append(input1_dim[0:-1], input2_dim[-1]) + output_name = node.output[0] + macs = counter_MatMul(input1_dim, out_size) + return macs, out_size, output_name + + +def onnx_counter_Add(diction, 
node):
+    out_size = diction[node.input[1]]
+    output_name = node.output[0]
+    macs = counter_zero_ops()
+    return macs, out_size, output_name
+
+
+def onnx_counter_Conv(diction, node):
+    #print(node)
+    # bias,kernelsize,outputsize
+    for i in node.input:
+        if('bias' in i):
+            dim_bias = diction[i]
+        if('weight' in i):
+            dim_weight = diction[i]  # cout, cin,kw,kh
+    # print(dim_weight,dim_bias)
+    for attr in node.attribute:
+        # print(attr)
+        if(attr.name == 'kernel_shape'):
+            dim_kernel = attr.ints  # kw,kh
+        if(attr.name == 'strides'):
+            dim_stride = attr.ints
+        if(attr.name == 'pads'):
+            dim_pad = attr.ints
+        if(attr.name == 'dilations'):
+            dim_dil = attr.ints
+    # print(dim_dil)
+    dim_input = diction[node.input[0]]
+    output_size = np.append(
+        dim_input[0:-np.array(dim_kernel).size-1], dim_weight[0])
+    hw = dim_input[-np.array(dim_kernel).size:]
+    for i in range(hw.size):
+        hw[i] = int((hw[i]+2*dim_pad[i]-dim_dil[i] *
+                     (dim_kernel[i]-1))/dim_stride[i])
+    output_size = np.append(output_size,hw)
+    #print(output_size)
+    #print(np.prod(dim_bias), np.prod(dim_kernel), np.prod(output_size))
+    macs = counter_conv(np.prod(dim_bias), np.prod(dim_kernel), np.prod(output_size))
+    output_name = node.output[0]
+    return macs, output_size, output_name
+
+def onnx_counter_Constant(diction,node):
+    #print(node)
+    macs = counter_zero_ops()
+    output_name = node.output[0]
+    output_size = [1]
+    print(macs, output_size, output_name)
+    return macs, output_size, output_name
+
+def onnx_counter_Mul(diction, node):
+    print(node)
+
+    pass
+
+
+onnx_operators = {
+    'MatMul': onnx_counter_MatMul,
+    'Add': onnx_counter_Add,
+    'Conv': onnx_counter_Conv,
+    'Mul' : onnx_counter_Mul,
+    'Constant' : onnx_counter_Constant,
+    None: None,
+}
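A worked check for counter_MatMul, whose corrected form in the next patch is np.prod(input_size) * output_size[-1]: multiplying A of shape (m, k) by B of shape (k, n) takes one multiply-accumulate per (row, column, inner) triple, i.e. m * k * n MACs. Plain arithmetic, no thop API involved:

    import numpy as np

    input_size = np.array([4, 8])    # A: (m, k) = (4, 8)
    out_size = np.array([4, 16])     # A @ B with B: (8, 16) gives (4, 16)
    macs = np.prod(input_size) * out_size[-1]
    print(macs)                      # 4 * 8 * 16 = 512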
From 9492d969ecd098eedb5ba110ba60492f487f6abc Mon Sep 17 00:00:00 2001
From: Cat Beta
Date: Tue, 5 Oct 2021 17:03:34 +0800
Subject: [PATCH 13/17] onnx_basic_fixed

---
 conv.onnx                   | Bin 4926 -> 0 bytes
 test1.py => test1_onnx.py   |   2 +-
 thop/linear.onnx            | Bin 4625 -> 0 bytes
 thop/onnx_profile.py        |  22 +++---
 thop/vision/basic_hooks.py  |  16 +++--
 thop/vision/counter.py      |  44 +++++-----
 thop/vision/onnx_counter.py | 137 ++++++++++++++++++++++++++++++++----
 7 files changed, 170 insertions(+), 51 deletions(-)
 delete mode 100644 conv.onnx
 rename test1.py => test1_onnx.py (79%)
 delete mode 100644 thop/linear.onnx

diff --git a/conv.onnx b/conv.onnx
deleted file mode 100644
index 533b6fbf55764fff2bf2fa2627e6c9783558f5ed..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 4926
[binary data omitted]

diff --git a/test1.py b/test1_onnx.py
similarity index 79%
rename from test1.py
rename to test1_onnx.py
index 39bf19c..dd78c20 100644
--- a/test1.py
+++ b/test1_onnx.py
@@ -3,6 +3,6 @@
 from onnx import numpy_helper
 import numpy as np
 model = onnx.load("conv.onnx")
-#print(onnx.helper.printable_graph(model.graph))
+print(onnx.helper.printable_graph(model.graph))
 onnx_profile = onnx_profile()
 print(onnx_profile.calculate_macs(model))
diff --git a/thop/linear.onnx b/thop/linear.onnx
deleted file mode 100644
index d0a0e05b65f3f81f4cc961d84dc5346886212a9b..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 4625
[binary data omitted]
diff --git a/thop/onnx_profile.py b/thop/onnx_profile.py
index 2fe5cfe..b6d702c 100644
--- a/thop/onnx_profile.py
+++ b/thop/onnx_profile.py
@@ -6,10 +6,12 @@ import numpy as np
 from thop.vision.counter import *
 from thop.vision.onnx_counter import *
 
+
 class onnx_profile():
     def __init__(self) -> None:
         pass
-    def calculate_params(self,model: onnx.ModelProto):
+
+    def calculate_params(self, model: onnx.ModelProto):
         onnx_weights = model.graph.initializer
         params = 0
 
@@ -22,7 +24,7 @@ def calculate_params(self,model: onnx.ModelProto):
 
         return params
 
-    def create_dic(self,weight, input , output):
+    def create_dic(self, weight, input, output):
         diction = {}
         for w in weight:
             dim = np.array(w.dims)
@@ -48,30 +50,26 @@ def create_dic(self,weight, input , output):
             if(dim.size == 1):
                 diction[str(o.name)] = np.append(1, dim)
         return diction
+
     def nodes_counter(self, diction, node):
         if node.op_type not in onnx_operators:
-            print("Sorry, we haven't add ",node.op_type,"into dictionary.")
+            print("Sorry, we haven't add ", node.op_type, "into dictionary.")
             return
         else:
             fn = onnx_operators[node.op_type]
-            return fn(diction,node)
-
-
-
-
-    def calculate_macs(self,model: onnx.ModelProto) -> torch.DoubleTensor:
+            return fn(diction, node)
+
+    def calculate_macs(self, model: onnx.ModelProto) -> torch.DoubleTensor:
         macs = 0
         name2dims = {}
         weight = model.graph.initializer
         nodes = model.graph.node
         input = model.graph.input
         output = model.graph.output
-        name2dims = self.create_dic(weight,input,output)
+        name2dims = self.create_dic(weight, input, output)
         macs = 0
         for n in nodes:
-            macs_adding, out_size,outname = self.nodes_counter(name2dims, n)
+            macs_adding, out_size, outname = self.nodes_counter(name2dims, n)
             name2dims[outname] = out_size
             macs += macs_adding
         return np.array(macs[0])
-
-
diff --git a/thop/vision/basic_hooks.py b/thop/vision/basic_hooks.py
index 0d0b669..ba859ad 100644
--- a/thop/vision/basic_hooks.py
+++ b/thop/vision/basic_hooks.py
@@ -43,14 +43,16 @@ def count_convNd_ver2(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
     #     kernel_ops += + m.bias.nelement()
     # # x N x H x W x Cout x (Cin x Kw x Kh + bias)
     # m.total_ops += torch.DoubleTensor([int(output_size * kernel_ops)])
-    m.total_ops += counter_conv(m.bias.nelement(), m.weight.nelement(), output_size)
+    m.total_ops += counter_conv(m.bias.nelement(),
+                                m.weight.nelement(), output_size)
+
 
 def count_bn(m, x, y):
     x = x[0]
     if not m.training:
         m.total_ops += counter_norm(x.numel())
 
-
+    
 def count_ln(m, x, y):
     x = x[0]
     if not m.training:
@@ -78,11 +80,9 @@ def count_relu(m, x, y):
 
     m.total_ops += counter_relu(nelements)
 
-
+    
 def count_softmax(m, x, y):
     x = x[0]
-
-    batch_size, nfeatures = x.size()
     nfeatures = x.size()[m.dim]
     batch_size = x.numel()//nfeatures
 
@@ -99,7 +99,8 @@ def count_avgpool(m, x, y):
 
 def count_adap_avgpool(m, x, y):
-    kernel = torch.DoubleTensor([*(x[0].shape[2:])]) // torch.DoubleTensor([*(y.shape[2:])])
+    kernel = torch.DoubleTensor(
+        [*(x[0].shape[2:])]) // torch.DoubleTensor([*(y.shape[2:])])
     total_add = torch.prod(kernel)
     num_elements = y.numel()
 
@@ -109,7 +110,8 @@ def count_adap_avgpool(m, x, y):
 # TODO: verify the accuracy
 def count_upsample(m, x, y):
     if m.mode not in ("nearest", "linear", "bilinear", "bicubic",):  # "trilinear"
-        logging.warning("mode %s is not implemented yet, take it a zero op" % m.mode)
+        logging.warning(
+            "mode %s is not implemented yet, take it a zero op" % m.mode)
         return counter_zero_ops()
 
     if m.mode == "nearest":
diff --git a/thop/vision/counter.py b/thop/vision/counter.py
index
c985bfd..0975619 100644 --- a/thop/vision/counter.py +++ b/thop/vision/counter.py @@ -8,9 +8,11 @@ def counter_parameters(para_list): total_params += torch.DoubleTensor([p.nelement()]) return total_params + def counter_zero_ops(): return torch.DoubleTensor([int(0)]) + def counter_conv(bias, kernel_size, output_size): """inputs are all numbers!""" kernel_ops = 0 @@ -19,13 +21,16 @@ def counter_conv(bias, kernel_size, output_size): kernel_ops += bias return torch.DoubleTensor([int(output_size * kernel_ops)]) + def counter_norm(input_size): """input is a number not a array or tensor""" return torch.DoubleTensor([2 * input_size]) + def counter_relu(input_size: torch.Tensor): return torch.DoubleTensor([int(input_size)]) + def counter_softmax(batch_size, nfeatures): total_exp = nfeatures total_add = nfeatures - 1 @@ -33,14 +38,17 @@ def counter_softmax(batch_size, nfeatures): total_ops = batch_size * (total_exp + total_add + total_div) return torch.DoubleTensor([int(total_ops)]) + def counter_avgpool(input_size): return torch.DoubleTensor([int(input_size)]) + def counter_adap_avg(kernel_size, output_size): total_div = 1 kernel_op = kernel_size + total_div return torch.DoubleTensor([int(kernel_op * output_size)]) + def counter_upsample(mode: str, output_size): total_ops = output_size if mode == "linear": @@ -55,24 +63,28 @@ def counter_upsample(mode: str, output_size): total_ops *= (13 * 2 + 5) return torch.DoubleTensor([int(total_ops)]) + def counter_linear(in_feature, num_elements): return torch.DoubleTensor([int(in_feature * num_elements)]) -def counter_onnx_MatMul(diction,node): - input1 = node.input[0] - input2 = node.input[0] - input1_dim = diction[input1] - input2_dim = diction[input2] - if (input1_dim.size >= input2_dim.size): - out_size = np.append(input1_dim[0:-1], input2_dim[-1]) - else: - out_size = np.append(input2_dim[0:-1], input1_dim[-1]) - input1_dim = np.array(input1_dim) - input2_dim = np.array(input2_dim) - macs = np.prod(input1_dim)/input1_dim[-1]*np.prod(input2_dim) - output_name = diction[node.output[0]] - return macs, out_size, output_name -#def count_onnx_ + + def counter_MatMul(input_size, output_size): input_size = np.array(input_size) output_size = np.array(output_size) - return np.prod(np.append(input_size[0:-1],output_size[-1])) + return np.prod(input_size) * output_size[-1] + + +def counter_Mul(input_size): + return input_size + + +def counter_pow(input_size): + return input_size + + +def counter_sqrt(input_size): + return input_size + + +def counter_div(input_size): + return input_size diff --git a/thop/vision/onnx_counter.py b/thop/vision/onnx_counter.py index fe9c846..5a5d3ab 100644 --- a/thop/vision/onnx_counter.py +++ b/thop/vision/onnx_counter.py @@ -12,19 +12,22 @@ def onnx_counter_MatMul(diction, node): input2_dim = diction[input2] out_size = np.append(input1_dim[0:-1], input2_dim[-1]) output_name = node.output[0] - macs = counter_MatMul(input1_dim, out_size) + macs = counter_MatMul(input1_dim, out_size[-2:]) return macs, out_size, output_name def onnx_counter_Add(diction, node): - out_size = diction[node.input[1]] + if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size: + out_size = diction[node.input[1]] + else: + out_size = diction[node.input[0]] output_name = node.output[0] macs = counter_zero_ops() return macs, out_size, output_name def onnx_counter_Conv(diction, node): - #print(node) + # print(node) # bias,kernelsize,outputsize for i in node.input: if('bias' in i): @@ -49,33 +52,137 @@ def onnx_counter_Conv(diction, node): 
hw = dim_input[-np.array(dim_kernel).size:] for i in range(hw.size): hw[i] = int((hw[i]+2*dim_pad[i]-dim_dil[i] * - (dim_kernel[i]-1))/dim_stride[i]) - output_size = np.append(output_size,hw) - #print(output_size) + (dim_kernel[i]-1))/dim_stride[i]) + output_size = np.append(output_size, hw) + # print(output_size) #print(np.prod(dim_bias), np.prod(dim_kernel), np.prod(output_size)) - macs = counter_conv(np.prod(dim_bias), np.prod(dim_kernel), np.prod(output_size)) + macs = counter_conv(np.prod(dim_bias), np.prod( + dim_kernel), np.prod(output_size)) output_name = node.output[0] return macs, output_size, output_name -def onnx_counter_Constant(diction,node): - #print(node) + +def onnx_counter_Constant(diction, node): + # print(node) macs = counter_zero_ops() output_name = node.output[0] output_size = [1] - print(macs, output_size, output_name) + #print(macs, output_size, output_name) return macs, output_size, output_name + def onnx_counter_Mul(diction, node): - print(node) - - pass + if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size: + input_size = diction[node.input[1]] + else: + input_size = diction[node.input[0]] + macs = counter_Mul(np.prod(input_size)) + output_size = diction[node.input[0]] + output_name = node.output[0] + return macs, output_size, output_name + + +def onnx_counter_bn(diction, node): + input_size = diction[node.input[0]] + macs = counter_norm(np.prod(input_size)) + output_name = node.output[0] + output_size = input_size + return macs, output_size, output_name + + +def onnx_counter_relu(diction, node): + input_size = diction[node.input[0]] + macs = counter_relu(np.prod(input_size)) + output_name = node.output[0] + output_size = input_size + return macs, output_size, output_name + + +def onnx_counter_reducemean(diction, node): + input_size = diction[node.input[0]] + macs = counter_zero_ops() + output_name = node.output[0] + output_size = input_size + #print("reduce",macs, output_size, output_name) + return macs, output_size, output_name + + +def onnx_counter_sub(diction, node): + input_size = diction[node.input[0]] + macs = counter_zero_ops() + output_name = node.output[0] + output_size = input_size + #print("sub",macs, output_size, output_name) + return macs, output_size, output_name + + +def onnx_counter_pow(diction, node): + if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size: + input_size = diction[node.input[1]] + else: + input_size = diction[node.input[0]] + macs = counter_pow(np.prod(input_size)) + output_name = node.output[0] + output_size = input_size + #print("pow",macs, output_size, output_name) + return macs, output_size, output_name + + +def onnx_counter_sqrt(diction, node): + input_size = diction[node.input[0]] + macs = counter_sqrt(np.prod(input_size)) + output_name = node.output[0] + output_size = input_size + #print("sqrt",macs, output_size, output_name) + return macs, output_size, output_name + + +def onnx_counter_div(diction, node): + if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size: + input_size = diction[node.input[1]] + else: + input_size = diction[node.input[0]] + macs = counter_div(np.prod(input_size)) + output_name = node.output[0] + output_size = input_size + #print("div",macs, output_size, output_name) + return macs, output_size, output_name + + +def onnx_counter_instance(diction, node): + input_size = diction[node.input[0]] + macs = counter_norm(np.prod(input_size)) + output_name = node.output[0] + output_size = input_size + return macs, output_size, 
output_name
+
+
+def onnx_counter_softmax(diction, node):
+    input_size = diction[node.input[0]]
+    dim = node.attribute[0].i
+    nfeatures = input_size[dim]
+    batch_size = np.prod(input_size) / nfeatures
+    macs = counter_softmax(nfeatures, batch_size)
+    output_name = node.output[0]
+    output_size = input_size
+    #print("soft",macs, output_size, output_name)
+    return macs, output_size, output_name
 
 
 onnx_operators = {
     'MatMul': onnx_counter_MatMul,
     'Add': onnx_counter_Add,
     'Conv': onnx_counter_Conv,
-    'Mul' : onnx_counter_Mul,
-    'Constant' : onnx_counter_Constant,
+    'Mul': onnx_counter_Mul,
+    'Constant': onnx_counter_Constant,
+    'BatchNormalization': onnx_counter_bn,
+    'Relu': onnx_counter_relu,
+    'ReduceMean': onnx_counter_reducemean,
+    'Sub': onnx_counter_sub,
+    'Pow': onnx_counter_pow,
+    'Sqrt': onnx_counter_sqrt,
+    'Div': onnx_counter_div,
+    'InstanceNormalization': onnx_counter_instance,
+    'Softmax': onnx_counter_softmax,
     None: None,
 }

From d0b58ca91213b7b5a9dff2dc051c92268a6744d0 Mon Sep 17 00:00:00 2001
From: Cat Beta
Date: Tue, 5 Oct 2021 20:00:35 +0800
Subject: [PATCH 14/17] reset to lg_opcounter

---
 .DS_Store                         | Bin 0 -> 6148 bytes
 benchmark/evaluate_transformer.py |  21 ----
 test.py                           |  13 ++-
 test1_onnx.py                     |   8 --
 thop/__init__.py                  |   2 +-
 thop/onnx_profile.py              |  75 ------------
 thop/profile.py                   |   9 +-
 thop/rnn_hooks.py                 |  82 -------------
 thop/vision/basic_hooks.py        |  98 +++++++++-------
 thop/vision/counter.py            |  90 --------------
 thop/vision/onnx_counter.py       | 188 ------------------------------
 11 files changed, 62 insertions(+), 524 deletions(-)
 create mode 100644 .DS_Store
 delete mode 100644 benchmark/evaluate_transformer.py
 delete mode 100644 test1_onnx.py
 delete mode 100644 thop/onnx_profile.py
 delete mode 100644 thop/vision/counter.py
 delete mode 100644 thop/vision/onnx_counter.py

diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6
GIT binary patch
literal 6148
[binary payload omitted; the patch is truncated here and resumes partway
through the deletion of thop/onnx_profile.py]
-    def __init__(self) -> None:
-        pass
-
-    def calculate_params(self, model: onnx.ModelProto):
-        onnx_weights = model.graph.initializer
-        params = 0
-
-        for onnx_w in onnx_weights:
-            try:
-                weight = numpy_helper.to_array(onnx_w)
-                params += np.prod(weight.shape)
-            except Exception as _:
-                pass
-
-        return params
-
-    def create_dic(self, weight, input, output):
-        diction = {}
-        for w in weight:
-            dim = np.array(w.dims)
-            diction[str(w.name)] = dim
-            if (dim.size == 1):
-                diction[str(w.name)] = np.append(1, dim)
-        for i in input:
-            # print(i.type.tensor_type.shape.dim[0].dim_value)
-            dim = np.array(i.type.tensor_type.shape.dim[0].dim_value)
-            # print(i.type.tensor_type.shape.dim.__sizeof__())
-            #name2dims[str(i.name)] = [dim]
-            dim = []
-            for key in i.type.tensor_type.shape.dim:
-                dim = np.append(dim, int(key.dim_value))
-                # print(key.dim_value)
-            # print(dim)
-            diction[str(i.name)] = dim
-            if(dim.size == 1):
-                diction[str(i.name)] = np.append(1, dim)
-        for o in output:
-            dim = np.array(o.type.tensor_type.shape.dim[0].dim_value)
-            diction[str(o.name)] = [dim]
-            if(dim.size == 1):
-                diction[str(o.name)] = np.append(1, dim)
-        return diction
-
-    def nodes_counter(self, diction, node):
-        if node.op_type not in onnx_operators:
-            print("Sorry, we haven't add ", node.op_type, "into dictionary.")
-            return
-        else:
-            fn = onnx_operators[node.op_type]
-            return fn(diction, node)
-
-    def calculate_macs(self, model:
onnx.ModelProto) -> torch.DoubleTensor: - macs = 0 - name2dims = {} - weight = model.graph.initializer - nodes = model.graph.node - input = model.graph.input - output = model.graph.output - name2dims = self.create_dic(weight, input, output) - macs = 0 - for n in nodes: - macs_adding, out_size, outname = self.nodes_counter(name2dims, n) - name2dims[outname] = out_size - macs += macs_adding - return np.array(macs[0]) diff --git a/thop/profile.py b/thop/profile.py index 642bc8d..656a869 100644 --- a/thop/profile.py +++ b/thop/profile.py @@ -35,12 +35,6 @@ def prYellow(skk): print("\033[93m{}\033[00m".format(skk)) nn.BatchNorm1d: count_bn, nn.BatchNorm2d: count_bn, nn.BatchNorm3d: count_bn, - nn.LayerNorm: count_ln, - nn.InstanceNorm1d: count_in, - nn.InstanceNorm2d: count_in, - nn.InstanceNorm3d: count_in, - nn.PReLU: count_prelu, - nn.Softmax: count_softmax, nn.ReLU: zero_ops, nn.ReLU6: zero_ops, @@ -73,9 +67,8 @@ def prYellow(skk): print("\033[93m{}\033[00m".format(skk)) nn.RNN: count_rnn, nn.GRU: count_gru, nn.LSTM: count_lstm, - nn.Transformer: count_Transformer, - nn.Sequential: zero_ops, + nn.Sequential: zero_ops, } if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"): diff --git a/thop/rnn_hooks.py b/thop/rnn_hooks.py index 2126c77..c00fd47 100644 --- a/thop/rnn_hooks.py +++ b/thop/rnn_hooks.py @@ -196,85 +196,3 @@ def count_lstm(m: nn.LSTM, x, y): total_ops *= batch_size m.total_ops += torch.DoubleTensor([int(total_ops)]) - -def count_Transformer(m: nn.Transformer, x, y): - total_ops = 0 - src, tgt = x - if m.batch_first: - num_steps = src.shape[0] - target = tgt.shape[1] - sequence = src.shape[1] - embedding = src.shape[2] - else: - target = tgt.shape[0] - sequence = src.shape[0] - num_steps = src.shape[1] - embedding = src.shape[2] - num_head = m.nhead - encoder_layers = m.encoder.num_layers - decoder_layers = m.decoder.num_layers - # dim_forward(default = 2048) - forward = m.encoder.layers[0].linear1.out_features - total_ops = 0 - - def MultiheadAttention(bool1, num_head, num_steps, target, sequence, embedding): - if bool1 == 0: - # linear_q,linear_k,linear_v all N,S,E - total_multi = 3 * sequence * embedding ** 2 - # self_attn softmax(Q*K_T/sqrt(dk))*V - total_multi += (sequence ** 4 * (embedding/num_head) ** 2 + \ - sequence ** 2 + sequence * (3 * sequence - 1) + 1) * num_head - #linear - total_multi += sequence * embedding ** 2 - #layernorm - total_multi += 2 * sequence * embedding - elif bool1 == 1: - # linear_q,linear_k,linear_v - total_multi = 3 * target * embedding ** 2 - # self_attn softmax(Q*K_T/sqrt(dk))*V - total_multi += (target ** 4 * (embedding/num_head) ** 2 + \ - target ** 2 + target * (3 * target-1) + 1) * num_head - total_multi += target * embedding ** 2 - total_multi += 2 * target * embedding - elif bool1 == 2: - # linear_q,linear_k,linear_v - total_multi = embedding ** 2 * (2 * sequence + target) - # self_attn softmax(Q*K_T/sqrt(dk))*V - total_multi += (target ** 2 * sequence ** 2 * (embedding/num_head) ** 2 + \ - target * sequence + target * (3 * sequence - 1)+1) * num_head - total_multi += target * embedding ** 2 - total_multi += 2 * target * embedding - # number of heads and batchsize - total_multi *= num_steps - return total_multi - - def TransformerEncoderLayer(num_head, num_steps, target, sequence, embedding): - total_en = 0 - total_en += MultiheadAttention(0, num_head, - num_steps, target, sequence, embedding) - # fed_forward(2 conv1d) - total_en += num_steps * sequence * forward * embedding - total_en += num_steps * sequence * embedding * forward 
- # norm1 - total_en += 2 * num_steps * embedding * sequence - return total_en - - def TransformerDecoderLayer(num_head, num_steps, target, sequence, embedding): - total_de = 0 - total_de += MultiheadAttention(1, num_head, - num_steps, target, sequence, embedding) - total_de += MultiheadAttention(2, num_head, - num_steps, target, sequence, embedding) - # linear1 linear2 fft - total_de += num_steps * target * forward * embedding - total_de += num_steps * target * embedding * forward - # layernorm - total_de += 2 * num_steps * embedding * target - return total_de - total_ops = encoder_layers * TransformerEncoderLayer(num_head, num_steps, target, sequence, embedding) + \ - decoder_layers * \ - TransformerDecoderLayer(num_head, num_steps, - target, sequence, embedding) - m.total_ops += torch.DoubleTensor([int(total_ops)]) - - diff --git a/thop/vision/basic_hooks.py b/thop/vision/basic_hooks.py index ba859ad..e3d7d7d 100644 --- a/thop/vision/basic_hooks.py +++ b/thop/vision/basic_hooks.py @@ -1,6 +1,6 @@ import argparse import logging -from .counter import * + import torch import torch.nn as nn from torch.nn.modules.conv import _ConvNd @@ -12,11 +12,11 @@ def count_parameters(m, x, y): total_params = 0 for p in m.parameters(): total_params += torch.DoubleTensor([p.numel()]) - m.total_params[0] = counter_parameters(m.parameters()) + m.total_params[0] = total_params def zero_ops(m, x, y): - m.total_ops += counter_zero_ops() + m.total_ops += torch.DoubleTensor([int(0)]) def count_convNd(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor): @@ -36,41 +36,24 @@ def count_convNd_ver2(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor): # N x H x W (exclude Cout) output_size = torch.zeros((y.size()[:1] + y.size()[2:])).numel() - # # Cout x Cin x Kw x Kh - # kernel_ops = m.weight.nelement() - # if m.bias is not None: - # # Cout x 1 - # kernel_ops += + m.bias.nelement() - # # x N x H x W x Cout x (Cin x Kw x Kh + bias) - # m.total_ops += torch.DoubleTensor([int(output_size * kernel_ops)]) - m.total_ops += counter_conv(m.bias.nelement(), - m.weight.nelement(), output_size) + # Cout x Cin x Kw x Kh + kernel_ops = m.weight.nelement() + if m.bias is not None: + # Cout x 1 + kernel_ops += + m.bias.nelement() + # x N x H x W x Cout x (Cin x Kw x Kh + bias) + m.total_ops += torch.DoubleTensor([int(output_size * kernel_ops)]) def count_bn(m, x, y): x = x[0] - if not m.training: - m.total_ops += counter_norm(x.numel()) - - -def count_ln(m, x, y): - x = x[0] - if not m.training: - m.total_ops += counter_norm(x.numel()) - - -def count_in(m, x, y): - x = x[0] - if not m.training: - m.total_ops += counter_norm(x.numel()) - - -def count_prelu(m, x, y): - x = x[0] nelements = x.numel() if not m.training: - m.total_ops += counter_relu(nelements) + # subtract, divide, gamma, beta + total_ops = 2 * nelements + + m.total_ops += torch.DoubleTensor([int(total_ops)]) def count_relu(m, x, y): @@ -78,47 +61,71 @@ def count_relu(m, x, y): nelements = x.numel() - m.total_ops += counter_relu(nelements) + m.total_ops += torch.DoubleTensor([int(nelements)]) def count_softmax(m, x, y): x = x[0] - nfeatures = x.size()[m.dim] - batch_size = x.numel()//nfeatures - m.total_ops += counter_softmax(batch_size, nfeatures) + batch_size, nfeatures = x.size() + + total_exp = nfeatures + total_add = nfeatures - 1 + total_div = nfeatures + total_ops = batch_size * (total_exp + total_add + total_div) + + m.total_ops += torch.DoubleTensor([int(total_ops)]) def count_avgpool(m, x, y): # total_add = torch.prod(torch.Tensor([m.kernel_size])) # total_div = 1 
# kernel_ops = total_add + total_div + kernel_ops = 1 num_elements = y.numel() + total_ops = kernel_ops * num_elements - m.total_ops += counter_avgpool(num_elements) + m.total_ops += torch.DoubleTensor([int(total_ops)]) def count_adap_avgpool(m, x, y): - kernel = torch.DoubleTensor( - [*(x[0].shape[2:])]) // torch.DoubleTensor([*(y.shape[2:])]) + kernel = torch.DoubleTensor([*(x[0].shape[2:])]) // torch.DoubleTensor([*(y.shape[2:])]) total_add = torch.prod(kernel) + total_div = 1 + kernel_ops = total_add + total_div num_elements = y.numel() + total_ops = kernel_ops * num_elements - m.total_ops += counter_adap_avg(total_add, num_elements) + m.total_ops += torch.DoubleTensor([int(total_ops)]) # TODO: verify the accuracy def count_upsample(m, x, y): if m.mode not in ("nearest", "linear", "bilinear", "bicubic",): # "trilinear" - logging.warning( - "mode %s is not implemented yet, take it a zero op" % m.mode) - return counter_zero_ops() + logging.warning("mode %s is not implemented yet, take it a zero op" % m.mode) + return zero_ops(m, x, y) if m.mode == "nearest": - return counter_zero_ops() + return zero_ops(m, x, y) x = x[0] - m.total_ops += counter_upsample(m.mode, y.nelement()) + if m.mode == "linear": + total_ops = y.nelement() * 5 # 2 muls + 3 add + elif m.mode == "bilinear": + # https://en.wikipedia.org/wiki/Bilinear_interpolation + total_ops = y.nelement() * 11 # 6 muls + 5 adds + elif m.mode == "bicubic": + # https://en.wikipedia.org/wiki/Bicubic_interpolation + # Product matrix [4x4] x [4x4] x [4x4] + ops_solve_A = 224 # 128 muls + 96 adds + ops_solve_p = 35 # 16 muls + 12 adds + 4 muls + 3 adds + total_ops = y.nelement() * (ops_solve_A + ops_solve_p) + elif m.mode == "trilinear": + # https://en.wikipedia.org/wiki/Trilinear_interpolation + # can viewed as 2 bilinear + 1 linear + total_ops = y.nelement() * (13 * 2 + 5) + + m.total_ops += torch.DoubleTensor([int(total_ops)]) # nn.Linear @@ -128,5 +135,6 @@ def count_linear(m, x, y): # total_add = m.in_features - 1 # total_add += 1 if m.bias is not None else 0 num_elements = y.numel() + total_ops = total_mul * num_elements - m.total_ops += counter_linear(total_mul, num_elements) + m.total_ops += torch.DoubleTensor([int(total_ops)]) diff --git a/thop/vision/counter.py b/thop/vision/counter.py deleted file mode 100644 index 0975619..0000000 --- a/thop/vision/counter.py +++ /dev/null @@ -1,90 +0,0 @@ -import torch -import numpy as np - - -def counter_parameters(para_list): - total_params = 0 - for p in para_list: - total_params += torch.DoubleTensor([p.nelement()]) - return total_params - - -def counter_zero_ops(): - return torch.DoubleTensor([int(0)]) - - -def counter_conv(bias, kernel_size, output_size): - """inputs are all numbers!""" - kernel_ops = 0 - kernel_ops = kernel_size - if bias is not None: - kernel_ops += bias - return torch.DoubleTensor([int(output_size * kernel_ops)]) - - -def counter_norm(input_size): - """input is a number not a array or tensor""" - return torch.DoubleTensor([2 * input_size]) - - -def counter_relu(input_size: torch.Tensor): - return torch.DoubleTensor([int(input_size)]) - - -def counter_softmax(batch_size, nfeatures): - total_exp = nfeatures - total_add = nfeatures - 1 - total_div = nfeatures - total_ops = batch_size * (total_exp + total_add + total_div) - return torch.DoubleTensor([int(total_ops)]) - - -def counter_avgpool(input_size): - return torch.DoubleTensor([int(input_size)]) - - -def counter_adap_avg(kernel_size, output_size): - total_div = 1 - kernel_op = kernel_size + total_div - return 
torch.DoubleTensor([int(kernel_op * output_size)]) - - -def counter_upsample(mode: str, output_size): - total_ops = output_size - if mode == "linear": - total_ops *= 5 - elif mode == "bilinear": - total_ops *= 11 - elif mode == "bicubic": - ops_solve_A = 224 # 128 muls + 96 adds - ops_solve_p = 35 # 16 muls + 12 adds + 4 muls + 3 adds - total_ops *= (ops_solve_A + ops_solve_p) - elif mode == "trilinear": - total_ops *= (13 * 2 + 5) - return torch.DoubleTensor([int(total_ops)]) - - -def counter_linear(in_feature, num_elements): - return torch.DoubleTensor([int(in_feature * num_elements)]) - - -def counter_MatMul(input_size, output_size): - input_size = np.array(input_size) - output_size = np.array(output_size) - return np.prod(input_size) * output_size[-1] - - -def counter_Mul(input_size): - return input_size - - -def counter_pow(input_size): - return input_size - - -def counter_sqrt(input_size): - return input_size - - -def counter_div(input_size): - return input_size diff --git a/thop/vision/onnx_counter.py b/thop/vision/onnx_counter.py deleted file mode 100644 index 5a5d3ab..0000000 --- a/thop/vision/onnx_counter.py +++ /dev/null @@ -1,188 +0,0 @@ -import torch -import numpy as np - -from thop.vision.basic_hooks import zero_ops -from .counter import * - - -def onnx_counter_MatMul(diction, node): - input1 = node.input[0] - input2 = node.input[1] - input1_dim = diction[input1] - input2_dim = diction[input2] - out_size = np.append(input1_dim[0:-1], input2_dim[-1]) - output_name = node.output[0] - macs = counter_MatMul(input1_dim, out_size[-2:]) - return macs, out_size, output_name - - -def onnx_counter_Add(diction, node): - if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size: - out_size = diction[node.input[1]] - else: - out_size = diction[node.input[0]] - output_name = node.output[0] - macs = counter_zero_ops() - return macs, out_size, output_name - - -def onnx_counter_Conv(diction, node): - # print(node) - # bias,kernelsize,outputsize - for i in node.input: - if('bias' in i): - dim_bias = diction[i] - if('weight' in i): - dim_weight = diction[i] # cout, cin,kw,kh - # print(dim_weight,dim_bias) - for attr in node.attribute: - # print(attr) - if(attr.name == 'kernel_shape'): - dim_kernel = attr.ints # kw,kh - if(attr.name == 'strides'): - dim_stride = attr.ints - if(attr.name == 'pads'): - dim_pad = attr.ints - if(attr.name == 'dilations'): - dim_dil = attr.ints - # print(dim_dil) - dim_input = diction[node.input[0]] - output_size = np.append( - dim_input[0:-np.array(dim_kernel).size-1], dim_weight[0]) - hw = dim_input[-np.array(dim_kernel).size:] - for i in range(hw.size): - hw[i] = int((hw[i]+2*dim_pad[i]-dim_dil[i] * - (dim_kernel[i]-1))/dim_stride[i]) - output_size = np.append(output_size, hw) - # print(output_size) - #print(np.prod(dim_bias), np.prod(dim_kernel), np.prod(output_size)) - macs = counter_conv(np.prod(dim_bias), np.prod( - dim_kernel), np.prod(output_size)) - output_name = node.output[0] - return macs, output_size, output_name - - -def onnx_counter_Constant(diction, node): - # print(node) - macs = counter_zero_ops() - output_name = node.output[0] - output_size = [1] - #print(macs, output_size, output_name) - return macs, output_size, output_name - - -def onnx_counter_Mul(diction, node): - if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size: - input_size = diction[node.input[1]] - else: - input_size = diction[node.input[0]] - macs = counter_Mul(np.prod(input_size)) - output_size = diction[node.input[0]] - 
output_name = node.output[0] - return macs, output_size, output_name - - -def onnx_counter_bn(diction, node): - input_size = diction[node.input[0]] - macs = counter_norm(np.prod(input_size)) - output_name = node.output[0] - output_size = input_size - return macs, output_size, output_name - - -def onnx_counter_relu(diction, node): - input_size = diction[node.input[0]] - macs = counter_relu(np.prod(input_size)) - output_name = node.output[0] - output_size = input_size - return macs, output_size, output_name - - -def onnx_counter_reducemean(diction, node): - input_size = diction[node.input[0]] - macs = counter_zero_ops() - output_name = node.output[0] - output_size = input_size - #print("reduce",macs, output_size, output_name) - return macs, output_size, output_name - - -def onnx_counter_sub(diction, node): - input_size = diction[node.input[0]] - macs = counter_zero_ops() - output_name = node.output[0] - output_size = input_size - #print("sub",macs, output_size, output_name) - return macs, output_size, output_name - - -def onnx_counter_pow(diction, node): - if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size: - input_size = diction[node.input[1]] - else: - input_size = diction[node.input[0]] - macs = counter_pow(np.prod(input_size)) - output_name = node.output[0] - output_size = input_size - #print("pow",macs, output_size, output_name) - return macs, output_size, output_name - - -def onnx_counter_sqrt(diction, node): - input_size = diction[node.input[0]] - macs = counter_sqrt(np.prod(input_size)) - output_name = node.output[0] - output_size = input_size - #print("sqrt",macs, output_size, output_name) - return macs, output_size, output_name - - -def onnx_counter_div(diction, node): - if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size: - input_size = diction[node.input[1]] - else: - input_size = diction[node.input[0]] - macs = counter_div(np.prod(input_size)) - output_name = node.output[0] - output_size = input_size - #print("div",macs, output_size, output_name) - return macs, output_size, output_name - - -def onnx_counter_instance(diction, node): - input_size = diction[node.input[0]] - macs = counter_norm(np.prod(input_size)) - output_name = node.output[0] - output_size = input_size - return macs, output_size, output_name - - -def onnx_counter_softmax(diction, node): - input_size = diction[node.input[0]] - dim = node.attribute[0].i - nfeatures = input_size[dim] - batch_size = np.prod(input_size) / nfeatures - macs = counter_softmax(nfeatures, batch_size) - output_name = node.output[0] - output_size = input_size - #print("soft",macs, output_size, output_name) - return macs, output_size, output_name - - -onnx_operators = { - 'MatMul': onnx_counter_MatMul, - 'Add': onnx_counter_Add, - 'Conv': onnx_counter_Conv, - 'Mul': onnx_counter_Mul, - 'Constant': onnx_counter_Constant, - 'BatchNormalization': onnx_counter_bn, - 'Relu': onnx_counter_relu, - 'ReduceMean': onnx_counter_reducemean, - 'Sub': onnx_counter_sub, - 'Pow': onnx_counter_pow, - 'Sqrt': onnx_counter_sqrt, - 'Div': onnx_counter_div, - 'InstanceNormalization': onnx_counter_instance, - 'Softmax': onnx_counter_softmax, - None: None, -} From 0afd8aac1903a7a7e592f2fa8c8e1a7bf8bd34f2 Mon Sep 17 00:00:00 2001 From: Cat Beta Date: Tue, 5 Oct 2021 20:10:19 +0800 Subject: [PATCH 15/17] adding transformer --- benchmark/evaluate_transformer.py | 26 ++++++++++ thop/profile.py | 2 +- thop/rnn_hooks.py | 81 +++++++++++++++++++++++++++++++ 3 files changed, 108 insertions(+), 1 deletion(-) 
create mode 100644 benchmark/evaluate_transformer.py diff --git a/benchmark/evaluate_transformer.py b/benchmark/evaluate_transformer.py new file mode 100644 index 0000000..2f08bb4 --- /dev/null +++ b/benchmark/evaluate_transformer.py @@ -0,0 +1,26 @@ + +import torch.nn as nn +from thop import profile +import torch + +src = torch.rand((1, 1, 10)) # S,N,x + + +class Model_transformer(nn.Module): + def __init__(self): + super(Model_transformer, self).__init__() + self.linear1 = nn.Linear(10, 512) + self.linear2 = nn.Linear(10, 512) + self.transform = nn.Transformer( + d_model=512, nhead=8, num_encoder_layers=6) + + def forward(self, input): + input1 = self.linear1(input) + input2 = self.linear2(input) + output = self.transform(input1, input2) + return output + + +model = Model_transformer() +macs, params = profile(model, inputs=(src, )) +print(macs, params) diff --git a/thop/profile.py b/thop/profile.py index 656a869..55d1db9 100644 --- a/thop/profile.py +++ b/thop/profile.py @@ -67,7 +67,7 @@ def prYellow(skk): print("\033[93m{}\033[00m".format(skk)) nn.RNN: count_rnn, nn.GRU: count_gru, nn.LSTM: count_lstm, - + nn.Transformer: count_Transformer, nn.Sequential: zero_ops, } diff --git a/thop/rnn_hooks.py b/thop/rnn_hooks.py index c00fd47..4998df9 100644 --- a/thop/rnn_hooks.py +++ b/thop/rnn_hooks.py @@ -196,3 +196,84 @@ def count_lstm(m: nn.LSTM, x, y): total_ops *= batch_size m.total_ops += torch.DoubleTensor([int(total_ops)]) + + +def count_Transformer(m: nn.Transformer, x, y): + total_ops = 0 + src, tgt = x + if m.batch_first: + num_steps = src.shape[0] + target = tgt.shape[1] + sequence = src.shape[1] + embedding = src.shape[2] + else: + target = tgt.shape[0] + sequence = src.shape[0] + num_steps = src.shape[1] + embedding = src.shape[2] + num_head = m.nhead + encoder_layers = m.encoder.num_layers + decoder_layers = m.decoder.num_layers + # dim_forward(default = 2048) + forward = m.encoder.layers[0].linear1.out_features + total_ops = 0 + + def MultiheadAttention(bool1, num_head, num_steps, target, sequence, embedding): + if bool1 == 0: + # linear_q,linear_k,linear_v all N,S,E + total_multi = 3 * sequence * embedding ** 2 + # self_attn softmax(Q*K_T/sqrt(dk))*V + total_multi += (sequence ** 4 * (embedding/num_head) ** 2 + + sequence ** 2 + sequence * (3 * sequence - 1) + 1) * num_head + # linear + total_multi += sequence * embedding ** 2 + # layernorm + total_multi += 2 * sequence * embedding + elif bool1 == 1: + # linear_q,linear_k,linear_v + total_multi = 3 * target * embedding ** 2 + # self_attn softmax(Q*K_T/sqrt(dk))*V + total_multi += (target ** 4 * (embedding/num_head) ** 2 + + target ** 2 + target * (3 * target-1) + 1) * num_head + total_multi += target * embedding ** 2 + total_multi += 2 * target * embedding + elif bool1 == 2: + # linear_q,linear_k,linear_v + total_multi = embedding ** 2 * (2 * sequence + target) + # self_attn softmax(Q*K_T/sqrt(dk))*V + total_multi += (target ** 2 * sequence ** 2 * (embedding/num_head) ** 2 + + target * sequence + target * (3 * sequence - 1)+1) * num_head + total_multi += target * embedding ** 2 + total_multi += 2 * target * embedding + # number of heads and batchsize + total_multi *= num_steps + return total_multi + + def TransformerEncoderLayer(num_head, num_steps, target, sequence, embedding): + total_en = 0 + total_en += MultiheadAttention(0, num_head, + num_steps, target, sequence, embedding) + # fed_forward(2 conv1d) + total_en += num_steps * sequence * forward * embedding + total_en += num_steps * sequence * embedding * forward + # norm1 
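+        # (reviewer note: MultiheadAttention above already includes one
+        #  LayerNorm in its total, so this 2 * num_steps * embedding * sequence
+        #  term presumably covers the layer's second norm at two ops per
+        #  element, matching count_bn)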
+        total_en += 2 * num_steps * embedding * sequence
+        return total_en
+
+    def TransformerDecoderLayer(num_head, num_steps, target, sequence, embedding):
+        total_de = 0
+        total_de += MultiheadAttention(1, num_head,
+                                       num_steps, target, sequence, embedding)
+        total_de += MultiheadAttention(2, num_head,
+                                       num_steps, target, sequence, embedding)
+        # linear1 linear2 fft
+        total_de += num_steps * target * forward * embedding
+        total_de += num_steps * target * embedding * forward
+        # layernorm
+        total_de += 2 * num_steps * embedding * target
+        return total_de
+    total_ops = encoder_layers * TransformerEncoderLayer(num_head, num_steps, target, sequence, embedding) + \
+        decoder_layers * \
+        TransformerDecoderLayer(num_head, num_steps,
+                                target, sequence, embedding)
+    m.total_ops += torch.DoubleTensor([int(total_ops)])

From e90599da5a3befa86aa6539b832b609740ac2384 Mon Sep 17 00:00:00 2001
From: Cat Beta
Date: Wed, 13 Oct 2021 20:46:03 +0800
Subject: [PATCH 16/17] fix function names

---
 thop/profile.py   |  2 +-
 thop/rnn_hooks.py | 18 +++++++++---------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/thop/profile.py b/thop/profile.py
index 2603a76..6944781 100644
--- a/thop/profile.py
+++ b/thop/profile.py
@@ -67,7 +67,7 @@ def prYellow(skk): print("\033[93m{}\033[00m".format(skk))
     nn.RNN: count_rnn,
     nn.GRU: count_gru,
     nn.LSTM: count_lstm,
-    nn.Transformer: count_Transformer,
+    nn.Transformer: count_transformer,
     nn.Sequential: zero_ops,
 }

diff --git a/thop/rnn_hooks.py b/thop/rnn_hooks.py
index 4998df9..6238bf6 100644
--- a/thop/rnn_hooks.py
+++ b/thop/rnn_hooks.py
@@ -198,7 +198,7 @@ def count_lstm(m: nn.LSTM, x, y):
     m.total_ops += torch.DoubleTensor([int(total_ops)])
 
 
-def count_Transformer(m: nn.Transformer, x, y):
+def count_transformer(m: nn.Transformer, x, y):
     total_ops = 0
     src, tgt = x
     if m.batch_first:
@@ -218,7 +218,7 @@ def count_Transformer(m: nn.Transformer, x, y):
     forward = m.encoder.layers[0].linear1.out_features
     total_ops = 0
 
-    def MultiheadAttention(bool1, num_head, num_steps, target, sequence, embedding):
+    def multihead_attention(bool1, num_head, num_steps, target, sequence, embedding):
         if bool1 == 0:
             # linear_q,linear_k,linear_v all N,S,E
             total_multi = 3 * sequence * embedding ** 2
@@ -249,9 +249,9 @@ def MultiheadAttention(bool1, num_head, num_steps, target, sequence, embedding):
         total_multi *= num_steps
         return total_multi
 
-    def TransformerEncoderLayer(num_head, num_steps, target, sequence, embedding):
+    def transformer_encoder_layer(num_head, num_steps, target, sequence, embedding):
         total_en = 0
-        total_en += MultiheadAttention(0, num_head,
+        total_en += multihead_attention(0, num_head,
                                        num_steps, target, sequence, embedding)
         # fed_forward(2 conv1d)
         total_en += num_steps * sequence * forward * embedding
@@ -260,11 +260,11 @@ def TransformerEncoderLayer(num_head, num_steps, target, sequence, embedding):
         total_en += 2 * num_steps * embedding * sequence
         return total_en
 
-    def TransformerDecoderLayer(num_head, num_steps, target, sequence, embedding):
+    def transformer_decoder_layer(num_head, num_steps, target, sequence, embedding):
         total_de = 0
-        total_de += MultiheadAttention(1, num_head,
+        total_de += multihead_attention(1, num_head,
                                        num_steps, target, sequence, embedding)
-        total_de += MultiheadAttention(2, num_head,
+        total_de += multihead_attention(2, num_head,
                                        num_steps, target, sequence, embedding)
         # linear1 linear2 fft
         total_de += num_steps * target * forward * embedding
@@ -272,8 +272,8 @@ def TransformerDecoderLayer(num_head, num_steps, target, sequence,
embedding): # layernorm total_de += 2 * num_steps * embedding * target return total_de - total_ops = encoder_layers * TransformerEncoderLayer(num_head, num_steps, target, sequence, embedding) + \ + total_ops = encoder_layers * transformer_encoder_layer(num_head, num_steps, target, sequence, embedding) + \ decoder_layers * \ - TransformerDecoderLayer(num_head, num_steps, + transformer_decoder_layer(num_head, num_steps, target, sequence, embedding) m.total_ops += torch.DoubleTensor([int(total_ops)]) From a9181d0c405c09dcc31ade59b4df1b18ee829d2c Mon Sep 17 00:00:00 2001 From: Cat Beta Date: Wed, 13 Oct 2021 20:52:56 +0800 Subject: [PATCH 17/17] fix function names --- benchmark/evaluate_transformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmark/evaluate_transformer.py b/benchmark/evaluate_transformer.py index 2f08bb4..4a83f27 100644 --- a/benchmark/evaluate_transformer.py +++ b/benchmark/evaluate_transformer.py @@ -6,9 +6,9 @@ src = torch.rand((1, 1, 10)) # S,N,x -class Model_transformer(nn.Module): +class ModelTransformer(nn.Module): def __init__(self): - super(Model_transformer, self).__init__() + super(ModelTransformer, self).__init__() self.linear1 = nn.Linear(10, 512) self.linear2 = nn.Linear(10, 512) self.transform = nn.Transformer( @@ -21,6 +21,6 @@ def forward(self, input): return output -model = Model_transformer() +model = ModelTransformer() macs, params = profile(model, inputs=(src, )) print(macs, params)
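
A quick end-to-end check of the finished series: point the profiler at a stock
nn.Transformer and confirm the renamed hook fires. This is a minimal sketch
assuming the final state of the patches above (count_transformer registered
for nn.Transformer in thop/profile.py); the input shapes and the expected
parameter magnitude are illustrative assumptions, not values taken from the
patches.

    import torch
    import torch.nn as nn
    from thop import profile

    # PyTorch defaults: d_model=512, nhead=8, 6 encoder and 6 decoder layers,
    # dim_feedforward=2048, batch_first=False, so inputs are (S, N, E)/(T, N, E).
    model = nn.Transformer()
    src = torch.rand((10, 32, 512))  # source: S=10, batch N=32, E=512
    tgt = torch.rand((20, 32, 512))  # target: T=20, batch N=32, E=512

    # count_transformer unpacks (src, tgt) from the hook's inputs, so both
    # tensors go through `inputs`; thop treats hooked modules as leaves, so
    # the reported MACs come from count_transformer alone.
    macs, params = profile(model, inputs=(src, tgt))
    print(macs, params)  # params should land around 4.4e7 for the defaults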