Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transformer counter to profile #149

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
26 changes: 26 additions & 0 deletions benchmark/evaluate_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

import torch.nn as nn
from thop import profile
import torch

src = torch.rand((1, 1, 10)) # S,N,x


class ModelTransformer(nn.Module):
def __init__(self):
super(ModelTransformer, self).__init__()
self.linear1 = nn.Linear(10, 512)
self.linear2 = nn.Linear(10, 512)
self.transform = nn.Transformer(
d_model=512, nhead=8, num_encoder_layers=6)

def forward(self, input):
input1 = self.linear1(input)
input2 = self.linear2(input)
output = self.transform(input1, input2)
return output


model = ModelTransformer()
macs, params = profile(model, inputs=(src, ))
print(macs, params)
2 changes: 2 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@
m = torch.nn.Conv2d(128, 128, 1)
x = torch.randn(1, 128, 16, 16)


flops = thop.profile(m, inputs=(x,), verbose=True)
fprint(flops)

2 changes: 1 addition & 1 deletion thop/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def prYellow(skk): fprint("\033[93m{}\033[00m".format(skk))
nn.RNN: count_rnn,
nn.GRU: count_gru,
nn.LSTM: count_lstm,

nn.Transformer: count_transformer,
nn.Sequential: zero_ops,
}

Expand Down
81 changes: 81 additions & 0 deletions thop/rnn_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,84 @@ def count_lstm(m: nn.LSTM, x, y):
total_ops *= batch_size

m.total_ops += torch.DoubleTensor([int(total_ops)])


def count_transformer(m: nn.Transformer, x, y):
total_ops = 0
src, tgt = x
if m.batch_first:
num_steps = src.shape[0]
target = tgt.shape[1]
sequence = src.shape[1]
embedding = src.shape[2]
else:
target = tgt.shape[0]
sequence = src.shape[0]
num_steps = src.shape[1]
embedding = src.shape[2]
num_head = m.nhead
encoder_layers = m.encoder.num_layers
decoder_layers = m.decoder.num_layers
# dim_forward(default = 2048)
forward = m.encoder.layers[0].linear1.out_features
total_ops = 0

def multihead_attention(bool1, num_head, num_steps, target, sequence, embedding):
if bool1 == 0:
# linear_q,linear_k,linear_v all N,S,E
total_multi = 3 * sequence * embedding ** 2
# self_attn softmax(Q*K_T/sqrt(dk))*V
total_multi += (sequence ** 4 * (embedding/num_head) ** 2 +
sequence ** 2 + sequence * (3 * sequence - 1) + 1) * num_head
# linear
total_multi += sequence * embedding ** 2
# layernorm
total_multi += 2 * sequence * embedding
elif bool1 == 1:
# linear_q,linear_k,linear_v
total_multi = 3 * target * embedding ** 2
# self_attn softmax(Q*K_T/sqrt(dk))*V
total_multi += (target ** 4 * (embedding/num_head) ** 2 +
target ** 2 + target * (3 * target-1) + 1) * num_head
total_multi += target * embedding ** 2
total_multi += 2 * target * embedding
elif bool1 == 2:
# linear_q,linear_k,linear_v
total_multi = embedding ** 2 * (2 * sequence + target)
# self_attn softmax(Q*K_T/sqrt(dk))*V
total_multi += (target ** 2 * sequence ** 2 * (embedding/num_head) ** 2 +
target * sequence + target * (3 * sequence - 1)+1) * num_head
total_multi += target * embedding ** 2
total_multi += 2 * target * embedding
# number of heads and batchsize
total_multi *= num_steps
return total_multi

def transformer_encoder_layer(num_head, num_steps, target, sequence, embedding):
total_en = 0
total_en += multihead_attention(0, num_head,
num_steps, target, sequence, embedding)
# fed_forward(2 conv1d)
total_en += num_steps * sequence * forward * embedding
total_en += num_steps * sequence * embedding * forward
# norm1
total_en += 2 * num_steps * embedding * sequence
return total_en

def transformer_decoder_layer(num_head, num_steps, target, sequence, embedding):
total_de = 0
total_de += multihead_attention(1, num_head,
num_steps, target, sequence, embedding)
total_de += multihead_attention(2, num_head,
num_steps, target, sequence, embedding)
# linear1 linear2 fft
total_de += num_steps * target * forward * embedding
total_de += num_steps * target * embedding * forward
# layernorm
total_de += 2 * num_steps * embedding * target
return total_de
total_ops = encoder_layers * transformer_encoder_layer(num_head, num_steps, target, sequence, embedding) + \
decoder_layers * \
transformer_decoder_layer(num_head, num_steps,
target, sequence, embedding)
m.total_ops += torch.DoubleTensor([int(total_ops)])