Skip to content

Commit

Permalink
v1.3
Browse files Browse the repository at this point in the history
  • Loading branch information
z1069614715 committed Dec 28, 2022
1 parent f60dc7f commit 1f99f39
Show file tree
Hide file tree
Showing 51 changed files with 1,102 additions and 75 deletions.
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ image classifier implement in pytoch.
type: string, default: acc, choices:['loss', 'acc', 'mean_acc']
根据metrice选择的指标来进行保存best.pt.
- **patience**
type: int, default:30
type: int, default:30
早停法中的patience.(设置为0即为不使用早停法)
- **imagenet_meanstd**
default:False
Expand Down Expand Up @@ -288,7 +288,7 @@ image classifier implement in pytoch.
Example: transforms.Compose([transforms.RandomHorizontalFlip(p=0.5),transforms.RandomRotation(degrees=20),])
自定义的数据增强.
- **export.py**
导出模型的文件.目前支持torchscript,onnx.
导出模型的文件.目前支持torchscript,onnx,tensorrt.
参数解释:
- **save_path**
type: string, default: runs/exp
Expand Down Expand Up @@ -334,7 +334,7 @@ image classifier implement in pytoch.
| densenet | densenet121,densenet161,densenet169,densenet201 |
| vgg | vgg11,vgg11_bn,vgg13,vgg13_bn,vgg16,vgg16_bn,vgg19,vgg19_bn |
| efficientnet | efficientnet_b0,efficientnet_b1,efficientnet_b2,efficientnet_b3,efficientnet_b4,efficientnet_b5,efficientnet_b6,efficientnet_b7<br>efficientnet_v2_s,efficientnet_v2_m,efficientnet_v2_l |
| nasnet | mnasnet0_5,mnasnet1_0 |
| nasnet | mnasnet1_0 |
| vovnet | vovnet39,vovnet59 |
| convnext | convnext_tiny,convnext_small,convnext_base,convnext_large,convnext_xlarge |
| ghostnet | ghostnet |
Expand All @@ -343,6 +343,7 @@ image classifier implement in pytoch.
| darknet | darknet53,darknetaa53 |
| cspnet | cspresnet50,cspresnext50,cspdarknet53,cs3darknet_m,cs3darknet_l,cs3darknet_x,cs3darknet_focus_m,cs3darknet_focus_l<br>cs3sedarknet_l,cs3sedarknet_x,cs3edgenet_x,cs3se_edgenet_x |
| dpn | dpn68,dpn68b,dpn92,dpn98,dpn107,dpn131 |
| repghost | repghostnet_0_5x,repghostnet_0_58x,repghostnet_0_8x,repghostnet_1_0x,repghostnet_1_11x<br>repghostnet_1_3x,repghostnet_1_5x,repghostnet_2_0x |

<a id="Someexplanation"></a>

Expand Down Expand Up @@ -610,7 +611,7 @@ image classifier implement in pytoch.
- [ ] Accumulation Gradient
- [ ] Model Ensembling
- [ ] Freeze Training
- [ ] Support Fuse Conv and Bn
- [x] Support Fuse Conv and Bn
- [x] Early Stop

<a id="Reference"></a>
Expand Down
19 changes: 19 additions & 0 deletions commad.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
python main.py --model_name efficientnet_v2_l --config config/config.py --batch_size 32 --Augment AutoAugment --save_path runs/efficientnet_v2_l --device 0 \
--pretrained --amp --warmup --ema --imagenet_meanstd

python main.py --model_name resnext50 --config config/config.py --batch_size 128 --Augment AutoAugment --save_path runs/resnext50 --device 1 \
--pretrained --amp --warmup --ema --imagenet_meanstd

python metrice.py --task fps --save_path runs/efficientnet_v2_l --batch_size 1 --device 0
python metrice.py --task fps --save_path runs/efficientnet_v2_l --batch_size 1 --device 0 --half

python metrice.py --task fps --save_path runs/resnext50 --batch_size 32 --device 0
python metrice.py --task fps --save_path runs/resnext50 --batch_size 32 --device 0 --half

python export.py --save_path runs/efficientnet_v2_l --export onnx --simplify --batch_size 1
python metrice.py --task fps --save_path runs/efficientnet_v2_l --batch_size 1 --device 0 --model_type onnx

python export.py --save_path runs/resnext50 --export onnx --simplify --batch_size 1
python metrice.py --task fps --save_path runs/resnext50 --batch_size 1 --device 0 --model_type onnx

python predict.py --source dataset/test/0000 --save_path runs/resnext50 --half --device 0
Binary file modified config/__pycache__/config.cpython-38.pyc
Binary file not shown.
3 changes: 2 additions & 1 deletion export.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import torch
import torch.nn as nn
from utils.utils import select_device
from utils.utils import select_device, model_fuse

def export_torchscript(opt, model, img, prefix='TorchScript'):
print('Starting TorchScript export with pytorch %s...' % torch.__version__)
Expand Down Expand Up @@ -104,6 +104,7 @@ def parse_opt():
assert not opt.dynamic, '--half not compatible with --dynamic'
ckpt = torch.load(os.path.join(opt.save_path, 'best.pt'))
model = ckpt['model'].float().to(DEVICE)
model_fuse(model)
img = torch.rand((opt.batch_size, opt.image_channel, opt.image_size, opt.image_size)).to(DEVICE)

return opt, (model.half() if opt.half else model), (img.half() if opt.half else img), DEVICE
Expand Down
7 changes: 4 additions & 3 deletions metrice.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import numpy as np
from utils import utils_aug
from utils.utils import classification_metrice, Metrice_Dataset, visual_predictions, visual_tsne, dict_to_PrettyTable, Model_Inference, select_device
from utils.utils import classification_metrice, Metrice_Dataset, visual_predictions, visual_tsne, dict_to_PrettyTable, Model_Inference, select_device, model_fuse

torch.backends.cudnn.deterministic = True
def set_seed(seed):
Expand Down Expand Up @@ -56,7 +56,8 @@ def parse_opt():
fps_arr = []
for i in tqdm.tqdm(range(test_time + warm_up)):
since = time.time()
model(inputs)
with torch.inference_mode():
model(inputs)
if i > warm_up:
fps_arr.append(time.time() - since)
fps = np.mean(fps_arr)
Expand Down Expand Up @@ -85,7 +86,7 @@ def parse_opt():
if __name__ == '__main__':
opt, model, test_dataset, DEVICE, CLASS_NUM, label, save_path = parse_opt()
y_true, y_pred, y_score, y_feature, img_path = [], [], [], [], []
with torch.no_grad():
with torch.inference_mode():
for x, y, path in tqdm.tqdm(test_dataset, desc='Test Stage'):
x = (x.half().to(DEVICE) if opt.half else x.to(DEVICE))
if opt.test_tta:
Expand Down
3 changes: 2 additions & 1 deletion model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
from .repvgg import *
from .sequencer import *
from .cspnet import *
from .dpn import *
from .dpn import *
from .repghost import *
Binary file modified model/__pycache__/__init__.cpython-38.pyc
Binary file not shown.
Binary file modified model/__pycache__/convnext.cpython-38.pyc
Binary file not shown.
Binary file modified model/__pycache__/cspnet.cpython-38.pyc
Binary file not shown.
Binary file modified model/__pycache__/densenet.cpython-38.pyc
Binary file not shown.
Binary file modified model/__pycache__/dpn.cpython-38.pyc
Binary file not shown.
Binary file modified model/__pycache__/efficientnetv2.cpython-38.pyc
Binary file not shown.
Binary file modified model/__pycache__/ghostnet.cpython-38.pyc
Binary file not shown.
Binary file modified model/__pycache__/mnasnet.cpython-38.pyc
Binary file not shown.
Binary file modified model/__pycache__/mobilenetv2.cpython-38.pyc
Binary file not shown.
Binary file modified model/__pycache__/mobilenetv3.cpython-38.pyc
Binary file not shown.
Binary file added model/__pycache__/repghost.cpython-38.pyc
Binary file not shown.
Binary file modified model/__pycache__/repvgg.cpython-38.pyc
Binary file not shown.
Binary file modified model/__pycache__/resnest.cpython-38.pyc
Binary file not shown.
Binary file modified model/__pycache__/resnet.cpython-38.pyc
Binary file not shown.
Binary file modified model/__pycache__/sequencer.cpython-38.pyc
Binary file not shown.
Binary file modified model/__pycache__/shufflenetv2.cpython-38.pyc
Binary file not shown.
Binary file modified model/__pycache__/vgg.cpython-38.pyc
Binary file not shown.
Binary file modified model/__pycache__/vovnet.cpython-38.pyc
Binary file not shown.
107 changes: 102 additions & 5 deletions model/cspnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from timm.models.helpers import named_apply
from timm.models.layers import ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible
from timm.models.registry import register_model
from utils.utils import load_weights_from_state_dict
from utils.utils import load_weights_from_state_dict, fuse_conv_bn

urls_dict = {
'cspresnet50': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnet50_ra-d3e8d487.pth',
Expand Down Expand Up @@ -343,6 +343,18 @@ def forward(self, x):
x = self.act3(x)
return x

def switch_to_deploy(self):
self.conv1 = nn.Sequential(
fuse_conv_bn(self.conv1.conv, self.conv1.bn),
self.conv1.bn.act,
)
self.conv2 = nn.Sequential(
fuse_conv_bn(self.conv2.conv, self.conv2.bn),
self.conv2.bn.act,
)
self.conv3 = nn.Sequential(
fuse_conv_bn(self.conv3.conv, self.conv3.bn),
)

class DarkBlock(nn.Module):
""" DarkNet Block
Expand Down Expand Up @@ -383,6 +395,15 @@ def forward(self, x):
x = self.drop_path(x) + shortcut
return x

def switch_to_deploy(self):
self.conv1 = nn.Sequential(
fuse_conv_bn(self.conv1.conv, self.conv1.bn),
self.conv1.bn.act,
)
self.conv2 = nn.Sequential(
fuse_conv_bn(self.conv2.conv, self.conv2.bn),
self.conv2.bn.act,
)

class EdgeBlock(nn.Module):
""" EdgeResidual / Fused-MBConv / MobileNetV1-like 3x3 + 1x1 block (w/ activated output)
Expand Down Expand Up @@ -422,6 +443,16 @@ def forward(self, x):
x = self.conv2(x)
x = self.drop_path(x) + shortcut
return x

def switch_to_deploy(self):
self.conv1 = nn.Sequential(
fuse_conv_bn(self.conv1.conv, self.conv1.bn),
self.conv1.bn.act,
)
self.conv2 = nn.Sequential(
fuse_conv_bn(self.conv2.conv, self.conv2.bn),
self.conv2.bn.act,
)


class CrossStage(nn.Module):
Expand Down Expand Up @@ -500,6 +531,34 @@ def forward(self, x):
out = self.conv_transition(torch.cat([xs, xb], dim=1))
return out

def switch_to_deploy(self):
if type(self.conv_down) is nn.Sequential:
self.conv_down = nn.Sequential(
self.conv_down[0],
fuse_conv_bn(self.conv_down[1].conv, self.conv_down[1].bn),
self.conv_down[1].bn.act,
self.conv_down[1].aa
)
elif type(self.conv_down) is nn.Identity:
pass
else:
self.conv_down = nn.Sequential(
fuse_conv_bn(self.conv_down.conv, self.conv_down.bn),
self.conv_down.bn.act,
self.conv_down.aa
)
self.conv_exp = nn.Sequential(
fuse_conv_bn(self.conv_exp.conv, self.conv_exp.bn),
self.conv_exp.bn.act,
)
self.conv_transition_b = nn.Sequential(
fuse_conv_bn(self.conv_transition_b.conv, self.conv_transition_b.bn),
self.conv_transition_b.bn.act,
)
self.conv_transition = nn.Sequential(
fuse_conv_bn(self.conv_transition.conv, self.conv_transition.bn),
self.conv_transition.bn.act,
)

class CrossStage3(nn.Module):
"""Cross Stage 3.
Expand Down Expand Up @@ -575,6 +634,29 @@ def forward(self, x):
out = self.conv_transition(torch.cat([x1, x2], dim=1))
return out

def switch_to_deploy(self):
if self.conv_down is not None:
if type(self.conv_down) is nn.Sequential:
self.conv_down = nn.Sequential(
self.conv_down[0],
fuse_conv_bn(self.conv_down[1].conv, self.conv_down[1].bn),
self.conv_down[1].bn.act,
self.conv_down[1].aa
)
else:
self.conv_down = nn.Sequential(
fuse_conv_bn(self.conv_down.conv, self.conv_down.bn),
self.conv_down.bn.act,
self.conv_down.aa
)
self.conv_exp = nn.Sequential(
fuse_conv_bn(self.conv_exp.conv, self.conv_exp.bn),
self.conv_exp.bn.act,
)
self.conv_transition = nn.Sequential(
fuse_conv_bn(self.conv_transition.conv, self.conv_transition.bn),
self.conv_transition.bn.act,
)

class DarkStage(nn.Module):
"""DarkNet stage."""
Expand Down Expand Up @@ -630,6 +712,20 @@ def forward(self, x):
x = self.blocks(x)
return x

def switch_to_deploy(self):
if type(self.conv_down) is nn.Sequential:
self.conv_down = nn.Sequential(
self.conv_down[0],
fuse_conv_bn(self.conv_down[1].conv, self.conv_down[1].bn),
self.conv_down[1].bn.act,
self.conv_down[1].aa
)
else:
self.conv_down = nn.Sequential(
fuse_conv_bn(self.conv_down.conv, self.conv_down.bn),
self.conv_down.bn.act,
self.conv_down.aa
)

def create_csp_stem(
in_chans=3,
Expand Down Expand Up @@ -834,11 +930,11 @@ def forward_features(self, x, need_fea=False):
x = layer(x)
features.append(x)
x = self.avgpool(x)
return features, torch.flatten(x)
return features, torch.flatten(x, start_dim=1, end_dim=3)
else:
x = self.stages(x)
x = self.avgpool(x)
return torch.flatten(x)
return torch.flatten(x, start_dim=1, end_dim=3)

def forward_head(self, x):
return self.head(x)
Expand Down Expand Up @@ -992,8 +1088,9 @@ def cs3se_edgenet_x(pretrained=False, **kwargs):
return _create_cspnet('cs3se_edgenet_x', pretrained=pretrained, **kwargs)

if __name__ == '__main__':
inputs = torch.rand((1, 3, 224, 224))
model = cs3se_edgenet_x(pretrained=False)
inputs = torch.rand((2, 3, 224, 224))
model = cspresnet50(pretrained=False)
model.head = nn.Linear(model.head.in_features, 1)
model.eval()
out = model(inputs)
print('out shape:{}'.format(out.size()))
Expand Down
26 changes: 15 additions & 11 deletions model/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torchvision._internally_replaced_utils import load_state_dict_from_url
from torch import Tensor
from typing import Any, List, Tuple

from utils.utils import load_weights_from_state_dict, fuse_conv_bn

__all__ = ['densenet121', 'densenet169', 'densenet201', 'densenet161']

Expand Down Expand Up @@ -246,7 +246,10 @@ def forward_features(self, x, need_fea=False):

def cam_layer(self):
return self.features[-1]


def switch_to_deploy(self):
self.features.conv0 = fuse_conv_bn(self.features.conv0, self.features.norm0)
del self.features.norm0

def load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None:
# '.'s are no longer allowed in module names, but previous _DenseLayer
Expand All @@ -262,15 +265,16 @@ def load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model_dict = model.state_dict()
weight_dict = {}
for k, v in state_dict.items():
if k in model_dict:
if np.shape(model_dict[k]) == np.shape(v):
weight_dict[k] = v
pretrained_dict = weight_dict
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
load_weights_from_state_dict(model, state_dict)
# model_dict = model.state_dict()
# weight_dict = {}
# for k, v in state_dict.items():
# if k in model_dict:
# if np.shape(model_dict[k]) == np.shape(v):
# weight_dict[k] = v
# pretrained_dict = weight_dict
# model_dict.update(pretrained_dict)
# model.load_state_dict(model_dict)

def _densenet(
arch: str,
Expand Down
1 change: 0 additions & 1 deletion model/dpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def __init__(self, in_chs, out_chs, kernel_size, stride, groups=1, norm_layer=Ba
def forward(self, x):
return self.conv(self.bn(x))


class DualPathBlock(nn.Module):
def __init__(
self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False):
Expand Down
Loading

0 comments on commit 1f99f39

Please sign in to comment.