diff --git a/README.md b/README.md index 2767e9b..4f00e49 100644 --- a/README.md +++ b/README.md @@ -156,7 +156,7 @@ image classifier implement in pytoch. type: string, default: acc, choices:['loss', 'acc', 'mean_acc'] 根据metrice选择的指标来进行保存best.pt. - **patience** - type: int, default:30 + type: int, default:30 早停法中的patience.(设置为0即为不使用早停法) - **imagenet_meanstd** default:False @@ -288,7 +288,7 @@ image classifier implement in pytoch. Example: transforms.Compose([transforms.RandomHorizontalFlip(p=0.5),transforms.RandomRotation(degrees=20),]) 自定义的数据增强. - **export.py** - 导出模型的文件.目前支持torchscript,onnx. + 导出模型的文件.目前支持torchscript,onnx,tensorrt. 参数解释: - **save_path** type: string, default: runs/exp @@ -334,7 +334,7 @@ image classifier implement in pytoch. | densenet | densenet121,densenet161,densenet169,densenet201 | | vgg | vgg11,vgg11_bn,vgg13,vgg13_bn,vgg16,vgg16_bn,vgg19,vgg19_bn | | efficientnet | efficientnet_b0,efficientnet_b1,efficientnet_b2,efficientnet_b3,efficientnet_b4,efficientnet_b5,efficientnet_b6,efficientnet_b7
efficientnet_v2_s,efficientnet_v2_m,efficientnet_v2_l | - | nasnet | mnasnet0_5,mnasnet1_0 | + | nasnet | mnasnet1_0 | | vovnet | vovnet39,vovnet59 | | convnext | convnext_tiny,convnext_small,convnext_base,convnext_large,convnext_xlarge | | ghostnet | ghostnet | @@ -343,6 +343,7 @@ image classifier implement in pytoch. | darknet | darknet53,darknetaa53 | | cspnet | cspresnet50,cspresnext50,cspdarknet53,cs3darknet_m,cs3darknet_l,cs3darknet_x,cs3darknet_focus_m,cs3darknet_focus_l
cs3sedarknet_l,cs3sedarknet_x,cs3edgenet_x,cs3se_edgenet_x | | dpn | dpn68,dpn68b,dpn92,dpn98,dpn107,dpn131 | + | repghost | repghostnet_0_5x,repghostnet_0_58x,repghostnet_0_8x,repghostnet_1_0x,repghostnet_1_11x
repghostnet_1_3x,repghostnet_1_5x,repghostnet_2_0x | @@ -610,7 +611,7 @@ image classifier implement in pytoch. - [ ] Accumulation Gradient - [ ] Model Ensembling - [ ] Freeze Training -- [ ] Support Fuse Conv and Bn +- [x] Support Fuse Conv and Bn - [x] Early Stop diff --git a/commad.txt b/commad.txt new file mode 100644 index 0000000..6f31bef --- /dev/null +++ b/commad.txt @@ -0,0 +1,19 @@ +python main.py --model_name efficientnet_v2_l --config config/config.py --batch_size 32 --Augment AutoAugment --save_path runs/efficientnet_v2_l --device 0 \ + --pretrained --amp --warmup --ema --imagenet_meanstd + +python main.py --model_name resnext50 --config config/config.py --batch_size 128 --Augment AutoAugment --save_path runs/resnext50 --device 1 \ + --pretrained --amp --warmup --ema --imagenet_meanstd + +python metrice.py --task fps --save_path runs/efficientnet_v2_l --batch_size 1 --device 0 +python metrice.py --task fps --save_path runs/efficientnet_v2_l --batch_size 1 --device 0 --half + +python metrice.py --task fps --save_path runs/resnext50 --batch_size 32 --device 0 +python metrice.py --task fps --save_path runs/resnext50 --batch_size 32 --device 0 --half + +python export.py --save_path runs/efficientnet_v2_l --export onnx --simplify --batch_size 1 +python metrice.py --task fps --save_path runs/efficientnet_v2_l --batch_size 1 --device 0 --model_type onnx + +python export.py --save_path runs/resnext50 --export onnx --simplify --batch_size 1 +python metrice.py --task fps --save_path runs/resnext50 --batch_size 1 --device 0 --model_type onnx + +python predict.py --source dataset/test/0000 --save_path runs/resnext50 --half --device 0 \ No newline at end of file diff --git a/config/__pycache__/config.cpython-38.pyc b/config/__pycache__/config.cpython-38.pyc index bcc2bcf..f6c4821 100644 Binary files a/config/__pycache__/config.cpython-38.pyc and b/config/__pycache__/config.cpython-38.pyc differ diff --git a/export.py b/export.py index 38339be..8dbe92d 100644 --- a/export.py +++ b/export.py @@ -3,7 +3,7 @@ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" import torch import torch.nn as nn -from utils.utils import select_device +from utils.utils import select_device, model_fuse def export_torchscript(opt, model, img, prefix='TorchScript'): print('Starting TorchScript export with pytorch %s...' % torch.__version__) @@ -104,6 +104,7 @@ def parse_opt(): assert not opt.dynamic, '--half not compatible with --dynamic' ckpt = torch.load(os.path.join(opt.save_path, 'best.pt')) model = ckpt['model'].float().to(DEVICE) + model_fuse(model) img = torch.rand((opt.batch_size, opt.image_channel, opt.image_size, opt.image_size)).to(DEVICE) return opt, (model.half() if opt.half else model), (img.half() if opt.half else img), DEVICE diff --git a/metrice.py b/metrice.py index 457091a..fe7aba9 100644 --- a/metrice.py +++ b/metrice.py @@ -6,7 +6,7 @@ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" import numpy as np from utils import utils_aug -from utils.utils import classification_metrice, Metrice_Dataset, visual_predictions, visual_tsne, dict_to_PrettyTable, Model_Inference, select_device +from utils.utils import classification_metrice, Metrice_Dataset, visual_predictions, visual_tsne, dict_to_PrettyTable, Model_Inference, select_device, model_fuse torch.backends.cudnn.deterministic = True def set_seed(seed): @@ -56,7 +56,8 @@ def parse_opt(): fps_arr = [] for i in tqdm.tqdm(range(test_time + warm_up)): since = time.time() - model(inputs) + with torch.inference_mode(): + model(inputs) if i > warm_up: fps_arr.append(time.time() - since) fps = np.mean(fps_arr) @@ -85,7 +86,7 @@ def parse_opt(): if __name__ == '__main__': opt, model, test_dataset, DEVICE, CLASS_NUM, label, save_path = parse_opt() y_true, y_pred, y_score, y_feature, img_path = [], [], [], [], [] - with torch.no_grad(): + with torch.inference_mode(): for x, y, path in tqdm.tqdm(test_dataset, desc='Test Stage'): x = (x.half().to(DEVICE) if opt.half else x.to(DEVICE)) if opt.test_tta: diff --git a/model/__init__.py b/model/__init__.py index 0cbc7c7..ca792d1 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -13,4 +13,5 @@ from .repvgg import * from .sequencer import * from .cspnet import * -from .dpn import * \ No newline at end of file +from .dpn import * +from .repghost import * \ No newline at end of file diff --git a/model/__pycache__/__init__.cpython-38.pyc b/model/__pycache__/__init__.cpython-38.pyc index 242e622..177d5b3 100644 Binary files a/model/__pycache__/__init__.cpython-38.pyc and b/model/__pycache__/__init__.cpython-38.pyc differ diff --git a/model/__pycache__/convnext.cpython-38.pyc b/model/__pycache__/convnext.cpython-38.pyc index ffc943e..37ce41d 100644 Binary files a/model/__pycache__/convnext.cpython-38.pyc and b/model/__pycache__/convnext.cpython-38.pyc differ diff --git a/model/__pycache__/cspnet.cpython-38.pyc b/model/__pycache__/cspnet.cpython-38.pyc index 88d3bcc..aa6c3d8 100644 Binary files a/model/__pycache__/cspnet.cpython-38.pyc and b/model/__pycache__/cspnet.cpython-38.pyc differ diff --git a/model/__pycache__/densenet.cpython-38.pyc b/model/__pycache__/densenet.cpython-38.pyc index ee1bd52..448ea6b 100644 Binary files a/model/__pycache__/densenet.cpython-38.pyc and b/model/__pycache__/densenet.cpython-38.pyc differ diff --git a/model/__pycache__/dpn.cpython-38.pyc b/model/__pycache__/dpn.cpython-38.pyc index eb068ec..2930103 100644 Binary files a/model/__pycache__/dpn.cpython-38.pyc and b/model/__pycache__/dpn.cpython-38.pyc differ diff --git a/model/__pycache__/efficientnetv2.cpython-38.pyc b/model/__pycache__/efficientnetv2.cpython-38.pyc index eabe70a..3e9a487 100644 Binary files a/model/__pycache__/efficientnetv2.cpython-38.pyc and b/model/__pycache__/efficientnetv2.cpython-38.pyc differ diff --git a/model/__pycache__/ghostnet.cpython-38.pyc b/model/__pycache__/ghostnet.cpython-38.pyc index 9d5452d..ad16ffd 100644 Binary files a/model/__pycache__/ghostnet.cpython-38.pyc and b/model/__pycache__/ghostnet.cpython-38.pyc differ diff --git a/model/__pycache__/mnasnet.cpython-38.pyc b/model/__pycache__/mnasnet.cpython-38.pyc index 4919f15..04e2329 100644 Binary files a/model/__pycache__/mnasnet.cpython-38.pyc and b/model/__pycache__/mnasnet.cpython-38.pyc differ diff --git a/model/__pycache__/mobilenetv2.cpython-38.pyc b/model/__pycache__/mobilenetv2.cpython-38.pyc index 1ae2b42..ee27d93 100644 Binary files a/model/__pycache__/mobilenetv2.cpython-38.pyc and b/model/__pycache__/mobilenetv2.cpython-38.pyc differ diff --git a/model/__pycache__/mobilenetv3.cpython-38.pyc b/model/__pycache__/mobilenetv3.cpython-38.pyc index 8eb4b5e..d37248a 100644 Binary files a/model/__pycache__/mobilenetv3.cpython-38.pyc and b/model/__pycache__/mobilenetv3.cpython-38.pyc differ diff --git a/model/__pycache__/repghost.cpython-38.pyc b/model/__pycache__/repghost.cpython-38.pyc new file mode 100644 index 0000000..741c13c Binary files /dev/null and b/model/__pycache__/repghost.cpython-38.pyc differ diff --git a/model/__pycache__/repvgg.cpython-38.pyc b/model/__pycache__/repvgg.cpython-38.pyc index 69592bb..92a9b21 100644 Binary files a/model/__pycache__/repvgg.cpython-38.pyc and b/model/__pycache__/repvgg.cpython-38.pyc differ diff --git a/model/__pycache__/resnest.cpython-38.pyc b/model/__pycache__/resnest.cpython-38.pyc index cc468eb..062bb52 100644 Binary files a/model/__pycache__/resnest.cpython-38.pyc and b/model/__pycache__/resnest.cpython-38.pyc differ diff --git a/model/__pycache__/resnet.cpython-38.pyc b/model/__pycache__/resnet.cpython-38.pyc index 20bdaea..72c7996 100644 Binary files a/model/__pycache__/resnet.cpython-38.pyc and b/model/__pycache__/resnet.cpython-38.pyc differ diff --git a/model/__pycache__/sequencer.cpython-38.pyc b/model/__pycache__/sequencer.cpython-38.pyc index 7f37a30..6e0729c 100644 Binary files a/model/__pycache__/sequencer.cpython-38.pyc and b/model/__pycache__/sequencer.cpython-38.pyc differ diff --git a/model/__pycache__/shufflenetv2.cpython-38.pyc b/model/__pycache__/shufflenetv2.cpython-38.pyc index 9033cb9..77d7ce9 100644 Binary files a/model/__pycache__/shufflenetv2.cpython-38.pyc and b/model/__pycache__/shufflenetv2.cpython-38.pyc differ diff --git a/model/__pycache__/vgg.cpython-38.pyc b/model/__pycache__/vgg.cpython-38.pyc index e40147a..be54e01 100644 Binary files a/model/__pycache__/vgg.cpython-38.pyc and b/model/__pycache__/vgg.cpython-38.pyc differ diff --git a/model/__pycache__/vovnet.cpython-38.pyc b/model/__pycache__/vovnet.cpython-38.pyc index e788fe7..61e939e 100644 Binary files a/model/__pycache__/vovnet.cpython-38.pyc and b/model/__pycache__/vovnet.cpython-38.pyc differ diff --git a/model/cspnet.py b/model/cspnet.py index 084ae5f..8ecff9f 100644 --- a/model/cspnet.py +++ b/model/cspnet.py @@ -7,7 +7,7 @@ from timm.models.helpers import named_apply from timm.models.layers import ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible from timm.models.registry import register_model -from utils.utils import load_weights_from_state_dict +from utils.utils import load_weights_from_state_dict, fuse_conv_bn urls_dict = { 'cspresnet50': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnet50_ra-d3e8d487.pth', @@ -343,6 +343,18 @@ def forward(self, x): x = self.act3(x) return x + def switch_to_deploy(self): + self.conv1 = nn.Sequential( + fuse_conv_bn(self.conv1.conv, self.conv1.bn), + self.conv1.bn.act, + ) + self.conv2 = nn.Sequential( + fuse_conv_bn(self.conv2.conv, self.conv2.bn), + self.conv2.bn.act, + ) + self.conv3 = nn.Sequential( + fuse_conv_bn(self.conv3.conv, self.conv3.bn), + ) class DarkBlock(nn.Module): """ DarkNet Block @@ -383,6 +395,15 @@ def forward(self, x): x = self.drop_path(x) + shortcut return x + def switch_to_deploy(self): + self.conv1 = nn.Sequential( + fuse_conv_bn(self.conv1.conv, self.conv1.bn), + self.conv1.bn.act, + ) + self.conv2 = nn.Sequential( + fuse_conv_bn(self.conv2.conv, self.conv2.bn), + self.conv2.bn.act, + ) class EdgeBlock(nn.Module): """ EdgeResidual / Fused-MBConv / MobileNetV1-like 3x3 + 1x1 block (w/ activated output) @@ -422,6 +443,16 @@ def forward(self, x): x = self.conv2(x) x = self.drop_path(x) + shortcut return x + + def switch_to_deploy(self): + self.conv1 = nn.Sequential( + fuse_conv_bn(self.conv1.conv, self.conv1.bn), + self.conv1.bn.act, + ) + self.conv2 = nn.Sequential( + fuse_conv_bn(self.conv2.conv, self.conv2.bn), + self.conv2.bn.act, + ) class CrossStage(nn.Module): @@ -500,6 +531,34 @@ def forward(self, x): out = self.conv_transition(torch.cat([xs, xb], dim=1)) return out + def switch_to_deploy(self): + if type(self.conv_down) is nn.Sequential: + self.conv_down = nn.Sequential( + self.conv_down[0], + fuse_conv_bn(self.conv_down[1].conv, self.conv_down[1].bn), + self.conv_down[1].bn.act, + self.conv_down[1].aa + ) + elif type(self.conv_down) is nn.Identity: + pass + else: + self.conv_down = nn.Sequential( + fuse_conv_bn(self.conv_down.conv, self.conv_down.bn), + self.conv_down.bn.act, + self.conv_down.aa + ) + self.conv_exp = nn.Sequential( + fuse_conv_bn(self.conv_exp.conv, self.conv_exp.bn), + self.conv_exp.bn.act, + ) + self.conv_transition_b = nn.Sequential( + fuse_conv_bn(self.conv_transition_b.conv, self.conv_transition_b.bn), + self.conv_transition_b.bn.act, + ) + self.conv_transition = nn.Sequential( + fuse_conv_bn(self.conv_transition.conv, self.conv_transition.bn), + self.conv_transition.bn.act, + ) class CrossStage3(nn.Module): """Cross Stage 3. @@ -575,6 +634,29 @@ def forward(self, x): out = self.conv_transition(torch.cat([x1, x2], dim=1)) return out + def switch_to_deploy(self): + if self.conv_down is not None: + if type(self.conv_down) is nn.Sequential: + self.conv_down = nn.Sequential( + self.conv_down[0], + fuse_conv_bn(self.conv_down[1].conv, self.conv_down[1].bn), + self.conv_down[1].bn.act, + self.conv_down[1].aa + ) + else: + self.conv_down = nn.Sequential( + fuse_conv_bn(self.conv_down.conv, self.conv_down.bn), + self.conv_down.bn.act, + self.conv_down.aa + ) + self.conv_exp = nn.Sequential( + fuse_conv_bn(self.conv_exp.conv, self.conv_exp.bn), + self.conv_exp.bn.act, + ) + self.conv_transition = nn.Sequential( + fuse_conv_bn(self.conv_transition.conv, self.conv_transition.bn), + self.conv_transition.bn.act, + ) class DarkStage(nn.Module): """DarkNet stage.""" @@ -630,6 +712,20 @@ def forward(self, x): x = self.blocks(x) return x + def switch_to_deploy(self): + if type(self.conv_down) is nn.Sequential: + self.conv_down = nn.Sequential( + self.conv_down[0], + fuse_conv_bn(self.conv_down[1].conv, self.conv_down[1].bn), + self.conv_down[1].bn.act, + self.conv_down[1].aa + ) + else: + self.conv_down = nn.Sequential( + fuse_conv_bn(self.conv_down.conv, self.conv_down.bn), + self.conv_down.bn.act, + self.conv_down.aa + ) def create_csp_stem( in_chans=3, @@ -834,11 +930,11 @@ def forward_features(self, x, need_fea=False): x = layer(x) features.append(x) x = self.avgpool(x) - return features, torch.flatten(x) + return features, torch.flatten(x, start_dim=1, end_dim=3) else: x = self.stages(x) x = self.avgpool(x) - return torch.flatten(x) + return torch.flatten(x, start_dim=1, end_dim=3) def forward_head(self, x): return self.head(x) @@ -992,8 +1088,9 @@ def cs3se_edgenet_x(pretrained=False, **kwargs): return _create_cspnet('cs3se_edgenet_x', pretrained=pretrained, **kwargs) if __name__ == '__main__': - inputs = torch.rand((1, 3, 224, 224)) - model = cs3se_edgenet_x(pretrained=False) + inputs = torch.rand((2, 3, 224, 224)) + model = cspresnet50(pretrained=False) + model.head = nn.Linear(model.head.in_features, 1) model.eval() out = model(inputs) print('out shape:{}'.format(out.size())) diff --git a/model/densenet.py b/model/densenet.py index 1876fd9..d8a9b94 100644 --- a/model/densenet.py +++ b/model/densenet.py @@ -8,7 +8,7 @@ from torchvision._internally_replaced_utils import load_state_dict_from_url from torch import Tensor from typing import Any, List, Tuple - +from utils.utils import load_weights_from_state_dict, fuse_conv_bn __all__ = ['densenet121', 'densenet169', 'densenet201', 'densenet161'] @@ -246,7 +246,10 @@ def forward_features(self, x, need_fea=False): def cam_layer(self): return self.features[-1] - + + def switch_to_deploy(self): + self.features.conv0 = fuse_conv_bn(self.features.conv0, self.features.norm0) + del self.features.norm0 def load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None: # '.'s are no longer allowed in module names, but previous _DenseLayer @@ -262,15 +265,16 @@ def load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None: new_key = res.group(1) + res.group(2) state_dict[new_key] = state_dict[key] del state_dict[key] - model_dict = model.state_dict() - weight_dict = {} - for k, v in state_dict.items(): - if k in model_dict: - if np.shape(model_dict[k]) == np.shape(v): - weight_dict[k] = v - pretrained_dict = weight_dict - model_dict.update(pretrained_dict) - model.load_state_dict(model_dict) + load_weights_from_state_dict(model, state_dict) + # model_dict = model.state_dict() + # weight_dict = {} + # for k, v in state_dict.items(): + # if k in model_dict: + # if np.shape(model_dict[k]) == np.shape(v): + # weight_dict[k] = v + # pretrained_dict = weight_dict + # model_dict.update(pretrained_dict) + # model.load_state_dict(model_dict) def _densenet( arch: str, diff --git a/model/dpn.py b/model/dpn.py index 8b75748..d2f246f 100644 --- a/model/dpn.py +++ b/model/dpn.py @@ -51,7 +51,6 @@ def __init__(self, in_chs, out_chs, kernel_size, stride, groups=1, norm_layer=Ba def forward(self, x): return self.conv(self.bn(x)) - class DualPathBlock(nn.Module): def __init__( self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False): diff --git a/model/efficientnetv2.py b/model/efficientnetv2.py index 6d91d7c..acdbc76 100644 --- a/model/efficientnetv2.py +++ b/model/efficientnetv2.py @@ -15,7 +15,7 @@ from torchvision.models._api import WeightsEnum, Weights from torchvision.models._meta import _IMAGENET_CATEGORIES from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible -from utils.utils import load_weights_from_state_dict +from utils.utils import load_weights_from_state_dict, fuse_conv_bn __all__ = [ "efficientnet_b0", @@ -155,6 +155,17 @@ def forward(self, input: Tensor) -> Tensor: result = self.stochastic_depth(result) result += input return result + + def switch_to_deploy(self): + new_block = [] + for layer in self.block: + if type(layer) is Conv2dNormActivation: + new_block.append(fuse_conv_bn(layer[0], layer[1])) + if len(layer) > 2: + new_block.append(layer[2]) + else: + new_block.append(layer) + self.block = nn.Sequential(*new_block) class FusedMBConv(nn.Module): @@ -216,6 +227,17 @@ def forward(self, input: Tensor) -> Tensor: result = self.stochastic_depth(result) result += input return result + + def switch_to_deploy(self): + new_block = [] + for layer in self.block: + if type(layer) is Conv2dNormActivation: + new_block.append(fuse_conv_bn(layer[0], layer[1])) + if len(layer) > 2: + new_block.append(layer[2]) + else: + new_block.append(layer) + self.block = nn.Sequential(*new_block) class EfficientNet(nn.Module): @@ -362,6 +384,17 @@ def forward_features(self, x, need_fea=False): def cam_layer(self): return self.features[-1] + + def switch_to_deploy(self): + new_block = [] + for layer in self.features: + if type(layer) is Conv2dNormActivation: + new_block.append(fuse_conv_bn(layer[0], layer[1])) + if len(layer) > 2: + new_block.append(layer[2]) + else: + new_block.append(layer) + self.features = nn.Sequential(*new_block) def _efficientnet( inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]], diff --git a/model/ghostnet.py b/model/ghostnet.py index d184c70..fc30863 100644 --- a/model/ghostnet.py +++ b/model/ghostnet.py @@ -3,7 +3,7 @@ import math import numpy as np from torch.hub import load_state_dict_from_url -from utils.utils import load_weights_from_state_dict +from utils.utils import load_weights_from_state_dict, fuse_conv_bn __all__ = ['ghostnet'] @@ -72,6 +72,15 @@ def forward(self, x): out = torch.cat([x1,x2], dim=1) return out[:,:self.oup,:,:] + def switch_to_deploy(self): + self.primary_conv = nn.Sequential( + fuse_conv_bn(self.primary_conv[0], self.primary_conv[1]), + self.primary_conv[2] + ) + self.cheap_operation = nn.Sequential( + fuse_conv_bn(self.cheap_operation[0], self.cheap_operation[1]), + self.cheap_operation[2] + ) class GhostBottleneck(nn.Module): def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se): @@ -101,6 +110,27 @@ def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se): def forward(self, x): return self.conv(x) + self.shortcut(x) + def switch_to_deploy(self): + if len(self.conv[1]) > 0: + self.conv = nn.Sequential( + self.conv[0], + fuse_conv_bn(self.conv[1][0], self.conv[1][1]), + self.conv[1][2], + self.conv[2], + self.conv[3], + ) + else: + self.conv = nn.Sequential( + self.conv[0], + self.conv[2], + self.conv[3], + ) + if len(self.shortcut) != 0: + self.shortcut = nn.Sequential( + fuse_conv_bn(self.shortcut[0][0], self.shortcut[0][1]), + self.shortcut[0][2], + fuse_conv_bn(self.shortcut[1], self.shortcut[2]) + ) class GhostNet(nn.Module): def __init__(self, cfgs, num_classes=1000, width_mult=1.): @@ -183,6 +213,17 @@ def _initialize_weights(self): def cam_layer(self): return self.features[-1] + + def switch_to_deploy(self): + self.features[0] = nn.Sequential( + fuse_conv_bn(self.features[0][0], self.features[0][1]), + self.features[0][2] + ) + self.squeeze = nn.Sequential( + fuse_conv_bn(self.squeeze[0], self.squeeze[1]), + self.squeeze[2], + self.squeeze[3] + ) def ghostnet(pretrained=False, **kwargs): """ diff --git a/model/mnasnet.py b/model/mnasnet.py index 9efa92b..60f8245 100644 --- a/model/mnasnet.py +++ b/model/mnasnet.py @@ -6,9 +6,9 @@ import numpy as np from torchvision._internally_replaced_utils import load_state_dict_from_url from typing import Any, Dict, List -from utils.utils import load_weights_from_state_dict +from utils.utils import load_weights_from_state_dict, fuse_conv_bn -__all__ = ['mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3'] +__all__ = ['mnasnet1_0'] _MODEL_URLS = { "mnasnet0_5": @@ -59,6 +59,15 @@ def forward(self, input: Tensor) -> Tensor: return self.layers(input) + input else: return self.layers(input) + + def switch_to_deploy(self): + self.layers = nn.Sequential( + fuse_conv_bn(self.layers[0], self.layers[1]), + self.layers[2], + fuse_conv_bn(self.layers[3], self.layers[4]), + self.layers[5], + fuse_conv_bn(self.layers[6], self.layers[7]) + ) def _stack(in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, @@ -146,6 +155,18 @@ def __init__( nn.Linear(1280, num_classes)) self._initialize_weights() + def switch_to_deploy(self): + self.layers = nn.Sequential( + fuse_conv_bn(self.layers[0], self.layers[1]), + self.layers[2], + fuse_conv_bn(self.layers[3], self.layers[4]), + self.layers[5], + fuse_conv_bn(self.layers[6], self.layers[7]), + self.layers[8:14], + fuse_conv_bn(self.layers[14], self.layers[15]), + self.layers[16] + ) + def forward(self, x: Tensor, need_fea=False) -> Tensor: if need_fea: features, features_fc = self.forward_features(x, need_fea) diff --git a/model/mobilenetv2.py b/model/mobilenetv2.py index bbff475..63d7954 100644 --- a/model/mobilenetv2.py +++ b/model/mobilenetv2.py @@ -7,7 +7,7 @@ from torchvision._internally_replaced_utils import load_state_dict_from_url from torchvision.models._utils import _make_divisible from typing import Callable, Any, Optional, List -from utils.utils import load_weights_from_state_dict +from utils.utils import load_weights_from_state_dict, fuse_conv_bn __all__ = ['mobilenet_v2'] @@ -75,7 +75,22 @@ def forward(self, x: Tensor) -> Tensor: return x + self.conv(x) else: return self.conv(x) - + + def switch_to_deploy(self): + if len(self.conv) == 4: + self.conv = nn.Sequential( + fuse_conv_bn(self.conv[0][0], self.conv[0][1]), + self.conv[0][2], + fuse_conv_bn(self.conv[1][0], self.conv[1][1]), + self.conv[1][2], + fuse_conv_bn(self.conv[2], self.conv[3]), + ) + else: + self.conv = nn.Sequential( + fuse_conv_bn(self.conv[0][0], self.conv[0][1]), + self.conv[0][2], + fuse_conv_bn(self.conv[1], self.conv[2]), + ) class MobileNetV2(nn.Module): def __init__( @@ -199,6 +214,16 @@ def forward_features(self, x, need_fea=False): def cam_layer(self): return self.features[-1] + + def switch_to_deploy(self): + self.features[0] = nn.Sequential( + fuse_conv_bn(self.features[0][0], self.features[0][1]), + self.features[0][2] + ) + self.features[-1] = nn.Sequential( + fuse_conv_bn(self.features[-1][0], self.features[-1][1]), + self.features[-1][2] + ) def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2: diff --git a/model/mobilenetv3.py b/model/mobilenetv3.py index de05fcd..1a2cd4b 100644 --- a/model/mobilenetv3.py +++ b/model/mobilenetv3.py @@ -7,7 +7,7 @@ from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer from torchvision._internally_replaced_utils import load_state_dict_from_url from torchvision.models._utils import _make_divisible -from utils.utils import load_weights_from_state_dict +from utils.utils import load_weights_from_state_dict, fuse_conv_bn __all__ = ["mobilenetv3_large", "mobilenetv3_small"] @@ -90,6 +90,16 @@ def forward(self, input: Tensor) -> Tensor: result += input return result + def switch_to_deploy(self): + new_layers = [] + for i in range(len(self.block)): + if type(self.block[i]) is Conv2dNormActivation: + new_layers.append(fuse_conv_bn(self.block[i][0], self.block[i][1])) + if len(self.block[i]) == 3: + new_layers.append(self.block[i][2]) + else: + new_layers.append(self.block[i]) + self.block = nn.Sequential(*new_layers) class MobileNetV3(nn.Module): @@ -164,6 +174,17 @@ def __init__( nn.init.normal_(m.weight, 0, 0.01) nn.init.zeros_(m.bias) + def switch_to_deploy(self): + new_layers = [] + for i in range(len(self.features)): + if type(self.features[i]) is Conv2dNormActivation: + new_layers.append(fuse_conv_bn(self.features[i][0], self.features[i][1])) + if len(self.features[i]) == 3: + new_layers.append(self.features[i][2]) + else: + new_layers.append(self.features[i]) + self.features = nn.Sequential(*new_layers) + def _forward_impl(self, x: Tensor, need_fea=False) -> Tensor: if need_fea: features, features_fc = self.forward_features(x, need_fea) diff --git a/model/repghost.py b/model/repghost.py new file mode 100644 index 0000000..3076757 --- /dev/null +++ b/model/repghost.py @@ -0,0 +1,560 @@ +import copy +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.hub import load_state_dict_from_url +from utils.utils import load_weights_from_state_dict, fuse_conv_bn + +__all__ = [ + 'repghostnet_0_5x', + 'repghostnet_repid_0_5x', + 'repghostnet_norep_0_5x', + 'repghostnet_wo_0_5x', + 'repghostnet_0_58x', + 'repghostnet_0_8x', + 'repghostnet_1_0x', + 'repghostnet_1_11x', + 'repghostnet_1_3x', + 'repghostnet_1_5x', + 'repghostnet_2_0x', +] + +weights_dict = { + 'repghostnet_0_5x': 'https://github.com/z1069614715/pretrained-weights/releases/download/repghost_v1.0/repghostnet_0_5x_43M_66.95.pth.tar', + 'repghostnet_0_58x': 'https://github.com/z1069614715/pretrained-weights/releases/download/repghost_v1.0/repghostnet_0_58x_60M_68.94.pth.tar', + 'repghostnet_0_8x': 'https://github.com/z1069614715/pretrained-weights/releases/download/repghost_v1.0/repghostnet_0_8x_96M_72.24.pth.tar', + 'repghostnet_1_0x': 'https://github.com/z1069614715/pretrained-weights/releases/download/repghost_v1.0/repghostnet_1_0x_142M_74.22.pth.tar', + 'repghostnet_1_11x': 'https://github.com/z1069614715/pretrained-weights/releases/download/repghost_v1.0/repghostnet_1_11x_170M_75.07.pth.tar', + 'repghostnet_1_3x': 'https://github.com/z1069614715/pretrained-weights/releases/download/repghost_v1.0/repghostnet_1_3x_231M_76.37.pth.tar', + 'repghostnet_1_5x': 'https://github.com/z1069614715/pretrained-weights/releases/download/repghost_v1.0/repghostnet_1_5x_301M_77.45.pth.tar', + 'repghostnet_2_0x': 'https://github.com/z1069614715/pretrained-weights/releases/download/repghost_v1.0/repghostnet_2_0x_516M_78.81.pth.tar', +} + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.0).clamp_(0.0, 6.0).div_(6.0) + else: + return F.relu6(x + 3.0) / 6.0 + + +class SqueezeExcite(nn.Module): + def __init__( + self, + in_chs, + se_ratio=0.25, + reduced_base_chs=None, + act_layer=nn.ReLU, + gate_fn=hard_sigmoid, + divisor=4, + **_, + ): + super(SqueezeExcite, self).__init__() + self.gate_fn = gate_fn + reduced_chs = _make_divisible( + (reduced_base_chs or in_chs) * se_ratio, divisor, + ) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.act1 = act_layer(inplace=True) + self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) + + def forward(self, x): + x_se = self.avg_pool(x) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + x = x * self.gate_fn(x_se) + return x + + +class ConvBnAct(nn.Module): + def __init__(self, in_chs, out_chs, kernel_size, stride=1, act_layer=nn.ReLU): + super(ConvBnAct, self).__init__() + self.conv = nn.Conv2d( + in_chs, out_chs, kernel_size, stride, kernel_size // 2, bias=False, + ) + self.bn1 = nn.BatchNorm2d(out_chs) + self.act1 = act_layer(inplace=True) + + def forward(self, x): + x = self.conv(x) + if hasattr(self, 'bn1'): + x = self.bn1(x) + x = self.act1(x) + return x + + def switch_to_deploy(self): + self.conv = fuse_conv_bn(self.conv, self.bn1) + del self.bn1 + +class RepGhostModule(nn.Module): + def __init__( + self, inp, oup, kernel_size=1, dw_size=3, stride=1, relu=True, deploy=False, reparam_bn=True, reparam_identity=False + ): + super(RepGhostModule, self).__init__() + init_channels = oup + new_channels = oup + self.deploy = deploy + + self.primary_conv = nn.Sequential( + nn.Conv2d( + inp, init_channels, kernel_size, stride, kernel_size // 2, bias=False, + ), + nn.BatchNorm2d(init_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + fusion_conv = [] + fusion_bn = [] + if not deploy and reparam_bn: + fusion_conv.append(nn.Identity()) + fusion_bn.append(nn.BatchNorm2d(init_channels)) + if not deploy and reparam_identity: + fusion_conv.append(nn.Identity()) + fusion_bn.append(nn.Identity()) + + self.fusion_conv = nn.Sequential(*fusion_conv) + self.fusion_bn = nn.Sequential(*fusion_bn) + + self.cheap_operation = nn.Sequential( + nn.Conv2d( + init_channels, + new_channels, + dw_size, + 1, + dw_size // 2, + groups=init_channels, + bias=deploy, + ), + nn.BatchNorm2d(new_channels) if not deploy else nn.Sequential(), + # nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + if deploy: + self.cheap_operation = self.cheap_operation[0] + if relu: + self.relu = nn.ReLU(inplace=False) + else: + self.relu = nn.Sequential() + + def forward(self, x): + x1 = self.primary_conv(x) + x2 = self.cheap_operation(x1) + for conv, bn in zip(self.fusion_conv, self.fusion_bn): + x2 = x2 + bn(conv(x1)) + return self.relu(x2) + + def get_equivalent_kernel_bias(self): + kernel3x3, bias3x3 = self._fuse_bn_tensor(self.cheap_operation[0], self.cheap_operation[1]) + for conv, bn in zip(self.fusion_conv, self.fusion_bn): + kernel, bias = self._fuse_bn_tensor(conv, bn, kernel3x3.shape[0], kernel3x3.device) + kernel3x3 += self._pad_1x1_to_3x3_tensor(kernel) + bias3x3 += bias + return kernel3x3, bias3x3 + + @staticmethod + def _pad_1x1_to_3x3_tensor(kernel1x1): + if kernel1x1 is None: + return 0 + else: + return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1]) + + @staticmethod + def _fuse_bn_tensor(conv, bn, in_channels=None, device=None): + in_channels = in_channels if in_channels else bn.running_mean.shape[0] + device = device if device else bn.weight.device + if isinstance(conv, nn.Conv2d): + kernel = conv.weight + assert conv.bias is None + else: + assert isinstance(conv, nn.Identity) + kernel_value = np.zeros((in_channels, 1, 1, 1), dtype=np.float32) + for i in range(in_channels): + kernel_value[i, 0, 0, 0] = 1 + kernel = torch.from_numpy(kernel_value).to(device) + + if isinstance(bn, nn.BatchNorm2d): + running_mean = bn.running_mean + running_var = bn.running_var + gamma = bn.weight + beta = bn.bias + eps = bn.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + assert isinstance(bn, nn.Identity) + return kernel, torch.zeros(in_channels).to(kernel.device) + + def switch_to_deploy(self): + if len(self.fusion_conv) == 0 and len(self.fusion_bn) == 0: + return + kernel, bias = self.get_equivalent_kernel_bias() + self.cheap_operation = nn.Conv2d(in_channels=self.cheap_operation[0].in_channels, + out_channels=self.cheap_operation[0].out_channels, + kernel_size=self.cheap_operation[0].kernel_size, + padding=self.cheap_operation[0].padding, + dilation=self.cheap_operation[0].dilation, + groups=self.cheap_operation[0].groups, + bias=True) + self.cheap_operation.weight.data = kernel + self.cheap_operation.bias.data = bias + self.__delattr__('fusion_conv') + self.__delattr__('fusion_bn') + self.fusion_conv = [] + self.fusion_bn = [] + self.deploy = True + + self.primary_conv = nn.Sequential( + fuse_conv_bn(self.primary_conv[0], self.primary_conv[1]), + self.primary_conv[2] + ) + + +class RepGhostBottleneck(nn.Module): + """RepGhost bottleneck w/ optional SE""" + + def __init__( + self, + in_chs, + mid_chs, + out_chs, + dw_kernel_size=3, + stride=1, + se_ratio=0.0, + shortcut=True, + reparam=True, + reparam_bn=True, + reparam_identity=False, + deploy=False, + ): + super(RepGhostBottleneck, self).__init__() + has_se = se_ratio is not None and se_ratio > 0.0 + self.stride = stride + self.enable_shortcut = shortcut + self.in_chs = in_chs + self.out_chs = out_chs + + # Point-wise expansion + self.ghost1 = RepGhostModule( + in_chs, + mid_chs, + relu=True, + reparam_bn=reparam and reparam_bn, + reparam_identity=reparam and reparam_identity, + deploy=deploy, + ) + + # Depth-wise convolution + if self.stride > 1: + self.conv_dw = nn.Conv2d( + mid_chs, + mid_chs, + dw_kernel_size, + stride=stride, + padding=(dw_kernel_size - 1) // 2, + groups=mid_chs, + bias=False, + ) + self.bn_dw = nn.BatchNorm2d(mid_chs) + + # Squeeze-and-excitation + if has_se: + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio) + else: + self.se = None + + # Point-wise linear projection + self.ghost2 = RepGhostModule( + mid_chs, + out_chs, + relu=False, + reparam_bn=reparam and reparam_bn, + reparam_identity=reparam and reparam_identity, + deploy=deploy, + ) + + # shortcut + if in_chs == out_chs and self.stride == 1: + self.shortcut = nn.Sequential() + else: + self.shortcut = nn.Sequential( + nn.Conv2d( + in_chs, + in_chs, + dw_kernel_size, + stride=stride, + padding=(dw_kernel_size - 1) // 2, + groups=in_chs, + bias=False, + ), + nn.BatchNorm2d(in_chs), + nn.Conv2d( + in_chs, out_chs, 1, stride=1, + padding=0, bias=False, + ), + nn.BatchNorm2d(out_chs), + ) + + def forward(self, x): + residual = x + + # 1st repghost bottleneck + x1 = self.ghost1(x) + + # Depth-wise convolution + if self.stride > 1: + x = self.conv_dw(x1) + x = self.bn_dw(x) + else: + x = x1 + + # Squeeze-and-excitation + if self.se is not None: + x = self.se(x) + + # 2nd repghost bottleneck + x = self.ghost2(x) + if not self.enable_shortcut and self.in_chs == self.out_chs and self.stride == 1: + return x + return x + self.shortcut(residual) + + def switch_to_deploy(self): + if len(self.shortcut) != 0: + self.shortcut = nn.Sequential( + fuse_conv_bn(self.shortcut[0], self.shortcut[1]), + fuse_conv_bn(self.shortcut[2], self.shortcut[3]), + ) + + +class RepGhostNet(nn.Module): + def __init__( + self, + cfgs, + num_classes=1000, + width=1.0, + dropout=0.2, + shortcut=True, + reparam=True, + reparam_bn=True, + reparam_identity=False, + deploy=False, + ): + super(RepGhostNet, self).__init__() + # setting of inverted residual blocks + self.cfgs = cfgs + self.dropout = dropout + self.num_classes = num_classes + + # building first layer + output_channel = _make_divisible(16 * width, 4) + self.conv_stem = nn.Conv2d(3, output_channel, 3, 2, 1, bias=False) + self.bn1 = nn.BatchNorm2d(output_channel) + self.act1 = nn.ReLU(inplace=True) + input_channel = output_channel + + # building inverted residual blocks + stages = [] + block = RepGhostBottleneck + for cfg in self.cfgs: + layers = [] + for k, exp_size, c, se_ratio, s in cfg: + output_channel = _make_divisible(c * width, 4) + hidden_channel = _make_divisible(exp_size * width, 4) + layers.append( + block( + input_channel, + hidden_channel, + output_channel, + k, + s, + se_ratio=se_ratio, + shortcut=shortcut, + reparam=reparam, + reparam_bn=reparam_bn, + reparam_identity=reparam_identity, + deploy=deploy + ), + ) + input_channel = output_channel + stages.append(nn.Sequential(*layers)) + + output_channel = _make_divisible(exp_size * width * 2, 4) + stages.append( + nn.Sequential( + ConvBnAct(input_channel, output_channel, 1), + ), + ) + input_channel = output_channel + + self.blocks = nn.Sequential(*stages) + + # building last several layers + output_channel = 1280 + self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) + self.conv_head = nn.Conv2d( + input_channel, output_channel, 1, 1, 0, bias=True, + ) + self.act2 = nn.ReLU(inplace=True) + self.classifier = nn.Linear(output_channel, num_classes) + + def forward(self, x, need_fea=False): + if need_fea: + features, features_fc = self.forward_features(x, need_fea) + x = self.classifier(features_fc) + return features, features_fc, x + else: + x = self.forward_features(x) + x = self.classifier(x) + return x + + def forward_features(self, x, need_fea=False): + input_size = x.size(2) + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + if need_fea: + scale = [4, 8, 16, 32] + features = [None, None, None, None] + for idx, layer in enumerate(self.blocks): + x = layer(x) + if input_size // x.size(2) in scale: + features[scale.index(input_size // x.size(2))] = x + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + return features, x.view(x.size(0), -1) + else: + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + return x.view(x.size(0), -1) + + def convert_to_deploy(self): + repghost_model_convert(self, do_copy=False) + + +def repghost_model_convert(model:torch.nn.Module, save_path=None, do_copy=True): + """ + taken from from https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py + """ + if do_copy: + model = copy.deepcopy(model) + for module in model.modules(): + if hasattr(module, 'switch_to_deploy'): + module.switch_to_deploy() + if save_path is not None: + torch.save(model.state_dict(), save_path) + return model + + +def repghostnet(enable_se=True, pretrained=False, name=None, **kwargs): + """ + Constructs a RepGhostNet model + """ + cfgs = [ + # k, t, c, SE, s + # stage1 + [[3, 8, 16, 0, 1]], + # stage2 + [[3, 24, 24, 0, 2]], + [[3, 36, 24, 0, 1]], + # stage3 + [[5, 36, 40, 0.25 if enable_se else 0, 2]], + [[5, 60, 40, 0.25 if enable_se else 0, 1]], + # stage4 + [[3, 120, 80, 0, 2]], + [ + [3, 100, 80, 0, 1], + [3, 120, 80, 0, 1], + [3, 120, 80, 0, 1], + [3, 240, 112, 0.25 if enable_se else 0, 1], + [3, 336, 112, 0.25 if enable_se else 0, 1], + ], + # stage5 + [[5, 336, 160, 0.25 if enable_se else 0, 2]], + [ + [5, 480, 160, 0, 1], + [5, 480, 160, 0.25 if enable_se else 0, 1], + [5, 480, 160, 0, 1], + [5, 480, 160, 0.25 if enable_se else 0, 1], + ], + ] + model = RepGhostNet(cfgs, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(weights_dict[name], progress=True)['state_dict_ema'] + model = load_weights_from_state_dict(model, state_dict) + return model + + +def repghostnet_0_5x(**kwargs): + return repghostnet(width=0.5, name='repghostnet_0_5x', **kwargs) + + +def repghostnet_repid_0_5x(**kwargs): + return repghostnet(width=0.5, name='repghostnet_0_5x', reparam_bn=False, reparam_identity=True, **kwargs) + + +def repghostnet_norep_0_5x(**kwargs): + return repghostnet(width=0.5, name='repghostnet_0_5x', reparam=False, **kwargs) + + +def repghostnet_wo_0_5x(**kwargs): + return repghostnet(width=0.5, name='repghostnet_0_5x', shortcut=False, **kwargs) + + +def repghostnet_0_58x(**kwargs): + return repghostnet(width=0.58, name='repghostnet_0_58x', **kwargs) + + +def repghostnet_0_8x(**kwargs): + return repghostnet(width=0.8, name='repghostnet_0_8x', **kwargs) + + +def repghostnet_1_0x(**kwargs): + return repghostnet(width=1.0, name='repghostnet_1_0x', **kwargs) + + +def repghostnet_1_11x(**kwargs): + return repghostnet(width=1.11, name='repghostnet_1_11x', **kwargs) + + +def repghostnet_1_3x(**kwargs): + return repghostnet(width=1.3, name='repghostnet_1_3x', **kwargs) + + +def repghostnet_1_5x(**kwargs): + return repghostnet(width=1.5, name='repghostnet_1_5x', **kwargs) + + +def repghostnet_2_0x(**kwargs): + return repghostnet(width=2.0, name='repghostnet_2_0x', **kwargs) + +if __name__ == '__main__': + inputs = torch.rand((1, 3, 224, 224)) + model = repghostnet_0_5x(pretrained=True) + model.eval() + out = model(inputs) + print('out shape:{}'.format(out.size())) + feas, fea_fc, out = model(inputs, True) + for idx, fea in enumerate(feas): + print('feature {} shape:{}'.format(idx + 1, fea.size())) + print('fc shape:{}'.format(fea_fc.size())) + print('out shape:{}'.format(out.size())) + + model.convert_to_deploy() \ No newline at end of file diff --git a/model/resnest.py b/model/resnest.py index 75be5c9..dfdbb9a 100644 --- a/model/resnest.py +++ b/model/resnest.py @@ -4,7 +4,7 @@ from torch.nn.modules.utils import _pair import torch.nn.functional as F import numpy as np -from utils.utils import load_weights_from_state_dict +from utils.utils import load_weights_from_state_dict, fuse_conv_bn __all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269'] _url_format = 'https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/{}-{}.pth' @@ -73,7 +73,7 @@ def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0 def forward(self, x): x = self.conv(x) - if self.use_bn: + if self.use_bn and hasattr(self, 'bn0'): x = self.bn0(x) if self.dropblock_prob > 0.0: x = self.dropblock(x) @@ -108,6 +108,14 @@ def forward(self, x): out = atten * x return out.contiguous() + def switch_to_deploy(self): + if self.use_bn: + try: + self.conv = fuse_conv_bn(self.conv, self.bn0) + del self.bn0 + except: + pass + class Bottleneck(nn.Module): """ResNet Bottleneck """ @@ -177,7 +185,8 @@ def forward(self, x): residual = x out = self.conv1(x) - out = self.bn1(out) + if hasattr(self, 'bn1'): + out = self.bn1(out) if self.dropblock_prob > 0.0: out = self.dropblock1(out) out = self.relu(out) @@ -187,7 +196,8 @@ def forward(self, x): out = self.conv2(out) if self.radix == 0: - out = self.bn2(out) + if hasattr(self, 'bn2'): + out = self.bn2(out) if self.dropblock_prob > 0.0: out = self.dropblock2(out) out = self.relu(out) @@ -196,7 +206,8 @@ def forward(self, x): out = self.avd_layer(out) out = self.conv3(out) - out = self.bn3(out) + if hasattr(self, 'bn3'): + out = self.bn3(out) if self.dropblock_prob > 0.0: out = self.dropblock3(out) @@ -207,6 +218,15 @@ def forward(self, x): out = self.relu(out) return out + + def switch_to_deploy(self): + self.conv1 = fuse_conv_bn(self.conv1, self.bn1) + del self.bn1 + if self.radix == 0: + self.conv2 = fuse_conv_bn(self.conv2, self.bn2) + del self.bn2 + self.conv3 = fuse_conv_bn(self.conv3, self.bn3) + del self.bn3 class ResNet(nn.Module): """ResNet Variants @@ -383,7 +403,8 @@ def forward(self, x, need_fea=False): def forward_features(self, x, need_fea=False): if need_fea: x = self.conv1(x) - x = self.bn1(x) + if hasattr(self, 'bn1'): + x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) @@ -397,7 +418,8 @@ def forward_features(self, x, need_fea=False): return [x1, x2, x3, x4], x else: x = self.conv1(x) - x = self.bn1(x) + if hasattr(self, 'bn1'): + x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) @@ -413,6 +435,19 @@ def forward_features(self, x, need_fea=False): def cam_layer(self): return self.layer4 + def switch_to_deploy(self): + if type(self.conv1) is nn.Conv2d: + self.conv1 = fuse_conv_bn(self.conv1, self.bn1) + else: + self.conv1 = nn.Sequential( + fuse_conv_bn(self.conv1[0], self.conv1[1]), + self.conv1[2], + fuse_conv_bn(self.conv1[3], self.conv1[4]), + self.conv1[5], + fuse_conv_bn(self.conv1[6], self.bn1), + ) + del self.bn1 + def short_hash(name): if name not in _model_sha256: raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) diff --git a/model/resnet.py b/model/resnet.py index e4ae007..d348754 100644 --- a/model/resnet.py +++ b/model/resnet.py @@ -4,7 +4,7 @@ import numpy as np from torchvision._internally_replaced_utils import load_state_dict_from_url from typing import Type, Any, Callable, Union, List, Optional - +from utils.utils import load_weights_from_state_dict, fuse_conv_bn __all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', @@ -69,11 +69,13 @@ def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) - out = self.bn1(out) + if hasattr(self, 'bn1'): + out = self.bn1(out) out = self.relu(out) out = self.conv2(out) - out = self.bn2(out) + if hasattr(self, 'bn2'): + out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) @@ -83,6 +85,11 @@ def forward(self, x: Tensor) -> Tensor: return out + def switch_to_deploy(self): + self.conv1 = fuse_conv_bn(self.conv1, self.bn1) + del self.bn1 + self.conv2 = fuse_conv_bn(self.conv2, self.bn2) + del self.bn2 class Bottleneck(nn.Module): # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) @@ -123,15 +130,18 @@ def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) - out = self.bn1(out) + if hasattr(self, 'bn1'): + out = self.bn1(out) out = self.relu(out) out = self.conv2(out) - out = self.bn2(out) + if hasattr(self, 'bn2'): + out = self.bn2(out) out = self.relu(out) out = self.conv3(out) - out = self.bn3(out) + if hasattr(self, 'bn3'): + out = self.bn3(out) if self.downsample is not None: identity = self.downsample(x) @@ -141,6 +151,13 @@ def forward(self, x: Tensor) -> Tensor: return out + def switch_to_deploy(self): + self.conv1 = fuse_conv_bn(self.conv1, self.bn1) + del self.bn1 + self.conv2 = fuse_conv_bn(self.conv2, self.bn2) + del self.bn2 + self.conv3 = fuse_conv_bn(self.conv3, self.bn3) + del self.bn3 class ResNet(nn.Module): @@ -203,6 +220,10 @@ def __init__( elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + def switch_to_deploy(self): + self.conv1 = fuse_conv_bn(self.conv1, self.bn1) + del self.bn1 + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1, dilate: bool = False) -> nn.Sequential: norm_layer = self._norm_layer @@ -244,7 +265,8 @@ def forward(self, x: Tensor, need_fea=False) -> Tensor: def forward_features(self, x, need_fea=False): x = self.conv1(x) - x = self.bn1(x) + if hasattr(self, 'bn1'): + x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) if need_fea: @@ -281,15 +303,7 @@ def _resnet( if pretrained: state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model_dict = model.state_dict() - weight_dict = {} - for k, v in state_dict.items(): - if k in model_dict: - if np.shape(model_dict[k]) == np.shape(v): - weight_dict[k] = v - pretrained_dict = weight_dict - model_dict.update(pretrained_dict) - model.load_state_dict(model_dict) + load_weights_from_state_dict(model, state_dict) return model diff --git a/model/shufflenetv2.py b/model/shufflenetv2.py index c2e766e..df1986f 100644 --- a/model/shufflenetv2.py +++ b/model/shufflenetv2.py @@ -4,11 +4,10 @@ import torch.nn as nn from torchvision._internally_replaced_utils import load_state_dict_from_url from typing import Callable, Any, List -from utils.utils import load_weights_from_state_dict +from utils.utils import load_weights_from_state_dict, fuse_conv_bn __all__ = [ - 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', - 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0' + 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0' ] model_urls = { @@ -96,6 +95,20 @@ def forward(self, x: Tensor) -> Tensor: return out + def switch_to_deploy(self): + if len(self.branch1) > 0: + self.branch1 = nn.Sequential( + fuse_conv_bn(self.branch1[0], self.branch1[1]), + fuse_conv_bn(self.branch1[2], self.branch1[3]), + self.branch1[4] + ) + self.branch2 = nn.Sequential( + fuse_conv_bn(self.branch2[0], self.branch2[1]), + self.branch2[2], + fuse_conv_bn(self.branch2[3], self.branch2[4]), + fuse_conv_bn(self.branch2[5], self.branch2[6]), + self.branch2[7] + ) class ShuffleNetV2(nn.Module): def __init__( @@ -146,6 +159,16 @@ def __init__( self.fc = nn.Linear(output_channels, num_classes) + def switch_to_deploy(self): + self.conv1 = nn.Sequential( + fuse_conv_bn(self.conv1[0], self.conv1[1]), + self.conv1[2] + ) + self.conv5 = nn.Sequential( + fuse_conv_bn(self.conv5[0], self.conv5[1]), + self.conv5[2] + ) + def _forward_impl(self, x: Tensor, need_fea=False) -> Tensor: if need_fea: features, features_fc = self.forward_features(x, need_fea) diff --git a/model/vgg.py b/model/vgg.py index 02372a6..5270b38 100644 --- a/model/vgg.py +++ b/model/vgg.py @@ -3,7 +3,7 @@ import numpy as np from torchvision._internally_replaced_utils import load_state_dict_from_url from typing import Union, List, Dict, Any, cast -from utils.utils import load_weights_from_state_dict +from utils.utils import load_weights_from_state_dict, fuse_conv_bn __all__ = [ 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', @@ -88,6 +88,15 @@ def forward_features(self, x, need_fea=False): def cam_layer(self): return self.features[-1] + + def switch_to_deploy(self): + new_features = [] + for i in range(len(self.features)): + if type(self.features[i]) is nn.BatchNorm2d: + new_features[-1] = fuse_conv_bn(new_features[-1], self.features[i]) + else: + new_features.append(self.features[i]) + self.features = nn.Sequential(*new_features) def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential: layers: List[nn.Module] = [] diff --git a/model/vovnet.py b/model/vovnet.py index a4f89d0..5714882 100644 --- a/model/vovnet.py +++ b/model/vovnet.py @@ -4,14 +4,14 @@ import numpy as np from torch.hub import load_state_dict_from_url from collections import OrderedDict -from utils.utils import load_weights_from_state_dict +from utils.utils import load_weights_from_state_dict, fuse_conv_bn -__all__ = ['vovnet27_slim', 'vovnet39', 'vovnet57'] +__all__ = ['vovnet39', 'vovnet57'] model_urls = { - 'vovnet39': 'https://dl.dropbox.com/s/1lnzsgnixd8gjra/vovnet39_torchvision.pth?dl=1', - 'vovnet57': 'https://dl.dropbox.com/s/6bfu9gstbwfw31m/vovnet57_torchvision.pth?dl=1' + 'vovnet39': 'https://github.com/z1069614715/pretrained-weights/releases/download/vovnet_v1.0/vovnet39_torchvision.pth', + 'vovnet57': 'https://github.com/z1069614715/pretrained-weights/releases/download/vovnet_v1.0/vovnet57_torchvision.pth' } @@ -90,6 +90,25 @@ def forward(self, x): return xt + def switch_to_deploy(self): + new_features = [] + for i in range(len(self.layers)): + if type(self.layers[i]) is nn.Sequential: + new_features.append(nn.Sequential( + fuse_conv_bn(self.layers[i][0], self.layers[i][1]), + self.layers[i][2] + )) + elif type(self.layers[i]) is nn.BatchNorm2d: + new_features[-1] = fuse_conv_bn(new_features[-1], self.layers[i]) + print(1) + else: + new_features.append(self.layers[i]) + self.layers = nn.Sequential(*new_features) + + self.concat = nn.Sequential( + fuse_conv_bn(self.concat[0], self.concat[1]), + self.concat[2] + ) class _OSA_stage(nn.Sequential): def __init__(self, @@ -163,6 +182,16 @@ def __init__(self, elif isinstance(m, nn.Linear): nn.init.constant_(m.bias, 0) + def switch_to_deploy(self): + self.stem = nn.Sequential( + fuse_conv_bn(self.stem[0], self.stem[1]), + self.stem[2], + fuse_conv_bn(self.stem[3], self.stem[4]), + self.stem[5], + fuse_conv_bn(self.stem[6], self.stem[7]), + self.stem[8], + ) + def forward(self, x, need_fea=False): if need_fea: features, features_fc = self.forward_features(x, need_fea) @@ -205,6 +234,9 @@ def _vovnet(arch, if pretrained: state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + for keys in list(state_dict.keys()): + state_dict[f'{keys.replace("module.", "")}'] = state_dict[keys] + del state_dict[keys] model = load_weights_from_state_dict(model, state_dict) return model diff --git a/predict.py b/predict.py index 920c34f..98ef390 100644 --- a/predict.py +++ b/predict.py @@ -7,7 +7,7 @@ import matplotlib.pyplot as plt import numpy as np from utils import utils_aug -from utils.utils import predict_single_image, cam_visual, dict_to_PrettyTable, select_device +from utils.utils import predict_single_image, cam_visual, dict_to_PrettyTable, select_device, model_fuse def set_seed(seed): random.seed(seed) @@ -24,7 +24,7 @@ def parse_opt(): parser.add_argument('--cam_visual', action="store_true", help='visual cam') parser.add_argument('--cam_type', type=str, choices=['GradCAM', 'HiResCAM', 'ScoreCAM', 'GradCAMPlusPlus', 'AblationCAM', 'XGradCAM', 'EigenCAM', 'FullGrad'], default='FullGrad', help='cam type') parser.add_argument('--half', action="store_true", help='use FP16 half-precision inference') - parser.add_argument('--device', type=str, default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--device', type=str, default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') opt = parser.parse_known_args()[0] if not os.path.exists(os.path.join(opt.save_path, 'best.pt')): @@ -35,7 +35,9 @@ def parse_opt(): raise Exception('half inference only supported GPU.') if opt.half and opt.cam_visual: raise Exception('cam visual only supported FP32.') - model = (ckpt['model'] if opt.half else ckpt['model'].float()) + model = ckpt['model'].float() + model_fuse(model) + model = (model.half() if opt.half else model) model.to(DEVICE) model.eval() train_opt = ckpt['opt'] diff --git a/processing.py b/processing.py index dd454ce..549a6f2 100644 --- a/processing.py +++ b/processing.py @@ -5,7 +5,6 @@ # set random seed np.random.seed(0) - ''' This file help us to split the dataset. It's going to be a training set, a validation set, a test set. diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/__pycache__/__init__.cpython-38.pyc b/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..935f6cd Binary files /dev/null and b/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/utils/__pycache__/utils.cpython-38.pyc b/utils/__pycache__/utils.cpython-38.pyc index 96ced2e..67a8826 100644 Binary files a/utils/__pycache__/utils.cpython-38.pyc and b/utils/__pycache__/utils.cpython-38.pyc differ diff --git a/utils/__pycache__/utils_aug.cpython-38.pyc b/utils/__pycache__/utils_aug.cpython-38.pyc index 49b4b3f..15166b0 100644 Binary files a/utils/__pycache__/utils_aug.cpython-38.pyc and b/utils/__pycache__/utils_aug.cpython-38.pyc differ diff --git a/utils/__pycache__/utils_distill.cpython-38.pyc b/utils/__pycache__/utils_distill.cpython-38.pyc index e48df5b..8cc27bc 100644 Binary files a/utils/__pycache__/utils_distill.cpython-38.pyc and b/utils/__pycache__/utils_distill.cpython-38.pyc differ diff --git a/utils/__pycache__/utils_fit.cpython-38.pyc b/utils/__pycache__/utils_fit.cpython-38.pyc index e811dc8..4d4890d 100644 Binary files a/utils/__pycache__/utils_fit.cpython-38.pyc and b/utils/__pycache__/utils_fit.cpython-38.pyc differ diff --git a/utils/__pycache__/utils_loss.cpython-38.pyc b/utils/__pycache__/utils_loss.cpython-38.pyc index 4184354..040324f 100644 Binary files a/utils/__pycache__/utils_loss.cpython-38.pyc and b/utils/__pycache__/utils_loss.cpython-38.pyc differ diff --git a/utils/__pycache__/utils_model.cpython-38.pyc b/utils/__pycache__/utils_model.cpython-38.pyc index 078909a..59c04d0 100644 Binary files a/utils/__pycache__/utils_model.cpython-38.pyc and b/utils/__pycache__/utils_model.cpython-38.pyc differ diff --git a/utils/utils.py b/utils/utils.py index eb0d71b..d252f28 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -639,11 +639,12 @@ def predict_single_image(path, model, test_transform, DEVICE, half=False): pil_img = Image.open(path) tensor_img = test_transform(pil_img).unsqueeze(0).to(DEVICE) tensor_img = (tensor_img.half() if (half and torch.cuda.is_available()) else tensor_img) - if len(tensor_img.shape) == 5: - tensor_img = tensor_img.reshape((tensor_img.size(0) * tensor_img.size(1), tensor_img.size(2), tensor_img.size(3), tensor_img.size(4))) - output = model(tensor_img).mean(0) - else: - output = model(tensor_img)[0] + with torch.inference_mode(): + if len(tensor_img.shape) == 5: + tensor_img = tensor_img.reshape((tensor_img.size(0) * tensor_img.size(1), tensor_img.size(2), tensor_img.size(3), tensor_img.size(4))) + output = model(tensor_img).mean(0) + else: + output = model(tensor_img)[0] try: pred_result = torch.softmax(output, 0) @@ -766,7 +767,9 @@ def __init__(self, device, opt): if self.opt.model_type == 'torch': ckpt = torch.load(os.path.join(opt.save_path, 'best.pt')) - self.model = (ckpt['model'] if opt.half else ckpt['model'].float()) + self.model = ckpt['model'].float() + model_fuse(self.model) + self.model = (self.model.half() if opt.half else self.model) self.model.to(self.device) self.model.eval() elif self.opt.model_type == 'onnx': @@ -812,7 +815,8 @@ def __init__(self, device, opt): def __call__(self, inputs): if self.opt.model_type == 'torch': - return self.model(inputs) + with torch.inference_mode(): + return self.model(inputs) elif self.opt.model_type == 'onnx': inputs = inputs.cpu().numpy().astype(np.float16 if '16' in self.model.get_inputs()[0].type else np.float32) return self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: inputs})[0] @@ -871,3 +875,76 @@ def select_device(device='', batch_size=0): arg = 'cpu' print(print_str) return torch.device(arg) + +def fuse_conv_bn(conv, bn): + # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ + fusedconv = ( + nn.Conv2d( + conv.in_channels, + conv.out_channels, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + groups=conv.groups, + bias=True, + ) + .requires_grad_(False) + .to(conv.weight.device) + ) + + # prepare filters + w_conv = conv.weight.clone().view(conv.out_channels, -1) + w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) + fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape)) + + # prepare spatial bias + b_conv = ( + torch.zeros(conv.weight.size(0), device=conv.weight.device) + if conv.bias is None + else conv.bias + ) + b_bn = bn.bias - bn.weight.mul(bn.running_mean).div( + torch.sqrt(bn.running_var + bn.eps) + ) + fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) + return fusedconv + +def model_fuse(model): + before_fuse_layers = len(getLayers(model)) + for module in model.modules(): + if hasattr(module, 'switch_to_deploy'): + module.switch_to_deploy() + print(f'model fuse... {before_fuse_layers} layers to {len(getLayers(model))} layers') + +def getLayers(model): + """ + get each layer's name and its module + :param model: + :return: each layer's name and its module + """ + layers = [] + + def unfoldLayer(model): + """ + unfold each layer + :param model: the given model or a single layer + :param root: root name + :return: + """ + + # get all layers of the model + layer_list = list(model.named_children()) + for item in layer_list: + module = item[1] + sublayer = list(module.named_children()) + sublayer_num = len(sublayer) + + # if current layer contains sublayers, add current layer name on its sublayers + if sublayer_num == 0: + layers.append(module) + # if current layer contains sublayers, unfold them + elif isinstance(module, torch.nn.Module): + unfoldLayer(module) + + unfoldLayer(model) + return layers \ No newline at end of file diff --git a/utils/utils_model.py b/utils/utils_model.py index f807d85..d9fc7f6 100644 --- a/utils/utils_model.py +++ b/utils/utils_model.py @@ -109,6 +109,12 @@ def select_model(name, num_classes, input_shape, channels, pretrained=False): nn.Dropout(0.2), nn.Linear(in_features=model.classifier.in_features, out_features=num_classes) ) + elif name.startswith('repghostnet'): + model = eval('models.{}(pretrained={})'.format(name, pretrained)) + model.classifier = nn.Sequential( + nn.Dropout(0.2), + nn.Linear(in_features=model.classifier.in_features, out_features=num_classes) + ) else: raise 'Unsupported Model Name.' diff --git a/v1.3-update_log.md b/v1.3-update_log.md new file mode 100644 index 0000000..ba3c993 --- /dev/null +++ b/v1.3-update_log.md @@ -0,0 +1,6 @@ +# pytorch-classifier v1.3 更新日志 + +1. 增加[repghost](https://arxiv.org/abs/2211.06088)模型. +2. 推理阶段把模型中的conv和bn进行fuse. +3. 发现mnasnet0_5有点问题,暂停使用. +4. torch.no_grad()更换成torch.inference_mode(). \ No newline at end of file