Variable length intermediate conv padding check #2231

Open · wants to merge 1 commit into base: develop
@@ -110,14 +110,29 @@ def _examine_intermediate_padding_in_op_subset(op_subset):
         If both 1st and 2nd conv have paddings, add conv2 to inter_pad_node_list.
         """
         if len(op_subset) == 4:
-            conv1, _, _, conv2 = op_subset
+            conv1, _, _, next_node = op_subset
         else:
-            conv1, _, conv2 = op_subset
+            conv1, _, next_node = op_subset

-        conv1_padding = sum(conv1.get_module().padding)
-        conv2_padding = sum(conv2.get_module().padding)
-        if conv1_padding and conv2_padding:
-            inter_pad_op_list.append(conv2)
+        previous_padding = sum(conv1.get_module().padding)
+
+        # Examine all following nodes, ignoring activations; break when a non-conv, non-activation node is reached.
+        while next_node:
+            if next_node.type in _support_conv_op_type:
+                current_padding = sum(next_node.get_module().padding)
+
+                if previous_padding and current_padding:
+                    inter_pad_op_list.append(next_node)
+
+                previous_padding = previous_padding or current_padding
+
+            next_outputs = next_node.output_ops
+
+            # Break if next_outputs has more than one output, or if the next op is neither a supported activation nor a conv.
+            if len(next_outputs) != 1 or (next_outputs[0].type not in _support_activation_op_type and next_outputs[0].type not in _support_conv_op_type):
+                break
+            else:
+                next_node = next_outputs[0]

     _support_activation_op_type = ("Relu", "Tanh", "HardSwish")
     _support_conv_op_type = ("Conv", "Conv2D")
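The new loop generalizes the old fixed-length (conv, [bn,] activation, conv) check: starting from the first padded conv, it walks forward through any run of supported activations and convs, flags every later conv that also pads, and stops at a branch or an unsupported op. Below is a minimal, self-contained sketch of that traversal. The Node class, its padding attribute, and find_intermediate_padding are hypothetical stand-ins for AIMET's connected-graph ops and the checker above, not the library's API.

class Node:
    """Hypothetical stand-in for a connected-graph op (not AIMET's Op class)."""
    def __init__(self, name, op_type, padding=0):
        self.name, self.type, self.padding = name, op_type, padding
        self.output_ops = []

    def __repr__(self):
        return self.name


_SUPPORTED_ACTIVATIONS = ("Relu", "Tanh", "HardSwish")
_SUPPORTED_CONVS = ("Conv", "Conv2D")


def find_intermediate_padding(first_conv):
    """Return every downstream conv that pads again after an already-padded conv."""
    flagged = []
    previous_padding = first_conv.padding
    next_node = first_conv.output_ops[0] if len(first_conv.output_ops) == 1 else None

    while next_node:
        if next_node.type in _SUPPORTED_CONVS:
            if previous_padding and next_node.padding:
                flagged.append(next_node)
            # Once any conv in the chain pads, keep treating the chain as padded.
            previous_padding = previous_padding or next_node.padding

        next_outputs = next_node.output_ops
        # Stop at branches, graph outputs, or ops that are neither a supported activation nor a conv.
        if len(next_outputs) != 1 or next_outputs[0].type not in _SUPPORTED_ACTIVATIONS + _SUPPORTED_CONVS:
            break
        next_node = next_outputs[0]

    return flagged


# conv_a(pad=1) -> Relu -> conv_b(pad=1) -> Tanh -> conv_c(pad=0) -> PRelu (unsupported, chain ends)
conv_a, relu, conv_b, tanh, conv_c, prelu = (
    Node("conv_a", "Conv", 1), Node("relu", "Relu"), Node("conv_b", "Conv", 1),
    Node("tanh", "Tanh"), Node("conv_c", "Conv", 0), Node("prelu", "PRelu"),
)
for src, dst in zip((conv_a, relu, conv_b, tanh, conv_c), (relu, conv_b, tanh, conv_c, prelu)):
    src.output_ops.append(dst)

print(find_intermediate_padding(conv_a))  # -> [conv_b]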
TrainingExtensions/torch/test/python/test_arch_checker.py (22 changes: 18 additions & 4 deletions)
@@ -74,6 +74,7 @@ def __init__(self):
         self.bn2 = torch.nn.BatchNorm2d(32)
         self.conv4 = torch.nn.Conv2d(32, 32, kernel_size=2, stride=2, padding=2, bias=False)

+        # conv5 has intermediate paddings when considering (conv3, conv4, conv5)
         # conv6 has no intermediate paddings
         self.conv5 = torch.nn.Conv2d(32, 32, kernel_size=2, padding=2, bias=False)
         self.relu3 = torch.nn.ReLU()
@@ -129,17 +130,19 @@ def __init__(self):
         self.relu1 = torch.nn.ReLU()
         self.conv2 = torch.nn.Conv2d(32, 32, kernel_size=2, stride=2, padding=2, bias=False)

-        # conv4 has no intermediate paddings since conv3 has no paddings
+        # conv4 has intermediate paddings when considering (conv1, conv2, conv3, conv4)
         self.conv3 = torch.nn.Conv2d(32, 32, kernel_size=2, stride=2, padding=0, bias=False)
         self.relu2 = torch.nn.ReLU()
         self.conv4 = torch.nn.Conv2d(32, 32, kernel_size=2, stride=2, padding=2, bias=False)

+        # PReLU is not a supported activation, so examination stops after the (conv1, relu1, conv2) pattern.
+        self.prelu = torch.nn.PReLU()
         # conv6 has no intermediate paddings
         self.conv5 = torch.nn.Conv2d(32, 32, kernel_size=2, padding=2, bias=False)
         self.relu3 = torch.nn.ReLU()
         self.conv6 = torch.nn.Conv2d(32, 32, kernel_size=2, padding=0, bias=False)

-        # conv7 has no intermediate paddings
+        # conv7, conv8 have no intermediate paddings
         self.conv7 = torch.nn.Conv2d(32, 32, kernel_size=2, padding=0, bias=False)
         self.relu4 = torch.nn.ReLU()
         self.conv8 = torch.nn.Conv2d(32, 32, kernel_size=2, padding=0, bias=False)
@@ -153,6 +156,7 @@ def forward(self, x):
         x = self.relu2(x)
         x = self.conv4(x)

+        x = self.prelu(x)
         x = self.conv5(x)
         x = self.relu3(x)
         x = self.conv6(x)
@@ -198,14 +202,17 @@ def forward(self, x):
         x = self.conv1(x)
         x = self.relu(x)
         x = self.conv2(x)
+        x = self.prelu(x)  # Break point for variable length

         x = self.conv3(x)
         x = self.tanh(x)
         x = self.conv4(x)
+        x = self.prelu(x)  # Break point for variable length

         x = self.conv5(x)
         x = self.hardswich(x)
         x = self.conv6(x)
+        x = self.prelu(x)  # Break point for variable length

         x = self.conv7(x)
         x = self.prelu(x)
@@ -324,6 +331,7 @@ def test_intermediate_padding(self):
         ArchChecker.check_model_arch(model, self.dummy_input)
         arch_checker_report = ArchChecker._arch_checker_report
         assert "_check_intermediate_padding" in arch_checker_report.raw_report["Model_inter_pad_with_BN.conv2"].failed_checks
+        assert "_check_intermediate_padding" in arch_checker_report.raw_report["Model_inter_pad_with_BN.conv5"].failed_checks
         assert "Model_inter_pad_with_BN.conv4" not in arch_checker_report.raw_report
         assert "Model_inter_pad_with_BN.conv6" not in arch_checker_report.raw_report
         assert "Model_inter_pad_with_BN.conv8" not in arch_checker_report.raw_report
@@ -334,18 +342,24 @@ def test_intermediate_padding(self):
         ArchChecker.check_model_arch(model, self.dummy_input)
         arch_checker_report = ArchChecker._arch_checker_report
         assert "_check_intermediate_padding" in arch_checker_report.raw_report["Model_inter_pad_without_BN.conv2"].failed_checks
-        assert "Model_inter_pad_without_BN.conv4" not in arch_checker_report.raw_report
+        assert "_check_intermediate_padding" in arch_checker_report.raw_report["Model_inter_pad_without_BN.conv4"].failed_checks
         assert "Model_inter_pad_without_BN.conv6" not in arch_checker_report.raw_report
         assert "Model_inter_pad_without_BN.conv8" not in arch_checker_report.raw_report
         arch_checker_report.reset_raw_report()

         model = Model_inter_pad_act_type()
         ArchChecker.check_model_arch(model, self.dummy_input)
         arch_checker_report = ArchChecker._arch_checker_report

+        assert "_check_intermediate_padding" not in arch_checker_report.raw_report["Model_inter_pad_act_type.conv1"].failed_checks
         assert "_check_intermediate_padding" in arch_checker_report.raw_report["Model_inter_pad_act_type.conv2"].failed_checks
+        assert "Model_inter_pad_act_type.conv3" not in arch_checker_report.raw_report
+        assert "_check_intermediate_padding" in arch_checker_report.raw_report["Model_inter_pad_act_type.conv4"].failed_checks
+        assert "Model_inter_pad_act_type.conv5" not in arch_checker_report.raw_report
+        assert "_check_intermediate_padding" in arch_checker_report.raw_report["Model_inter_pad_act_type.conv6"].failed_checks
+        assert "Model_inter_pad_act_type.conv7" not in arch_checker_report.raw_report
         assert "Model_inter_pad_act_type.conv8" not in arch_checker_report.raw_report

         arch_checker_report.reset_raw_report()

         filepath = ArchChecker._arch_checker_report._get_write_path(".html")
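The assertions above drive the checker through its public entry point; for reference, a hedged sketch of the same flow outside the test suite follows. The import path is an assumption inferred from how the tests reference ArchChecker (verify it against your AIMET version), and the raw_report / failed_checks access mirrors the test code rather than documented API.

import torch
# Assumed module path; the tests only show the ArchChecker name, not its import.
from aimet_torch.arch_checker.arch_checker import ArchChecker

class PaddedChain(torch.nn.Module):
    """Toy model: two padded convs separated by a supported activation."""
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False)  # expected to be flagged

    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))

model = PaddedChain()
ArchChecker.check_model_arch(model, torch.randn(1, 3, 32, 32))

# Same report access pattern as the assertions in test_intermediate_padding above.
report = ArchChecker._arch_checker_report.raw_report
print([name for name, entry in report.items()
       if "_check_intermediate_padding" in entry.failed_checks])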