From b6f09eb054d602e50364021e5c9e1c20db23d743 Mon Sep 17 00:00:00 2001
From: z1069614715 <1069614715@qq.com>
Date: Sun, 27 Nov 2022 00:36:36 +0800
Subject: [PATCH] v1.2
---
.idea/.gitignore | 8 -
.idea/encodings.xml | 9 --
.idea/inspectionProfiles/Project_Default.xml | 6 -
.../inspectionProfiles/profiles_settings.xml | 6 -
.idea/misc.xml | 4 -
.idea/modules.xml | 8 -
.idea/pytorch-classifier.iml | 12 --
.idea/vcs.xml | 6 -
README.md | 92 ++++++++++--
config/__pycache__/config.cpython-38.pyc | Bin 1099 -> 1099 bytes
export.py | 119 +++++++++++++++
main.py | 13 +-
metrice.py | 25 ++--
model/__pycache__/__init__.cpython-38.pyc | Bin 466 -> 466 bytes
model/__pycache__/convnext.cpython-38.pyc | Bin 9327 -> 9327 bytes
model/__pycache__/cspnet.cpython-38.pyc | Bin 25129 -> 25129 bytes
model/__pycache__/densenet.cpython-38.pyc | Bin 12329 -> 12329 bytes
model/__pycache__/dpn.cpython-38.pyc | Bin 9915 -> 9915 bytes
.../__pycache__/efficientnetv2.cpython-38.pyc | Bin 30994 -> 30994 bytes
model/__pycache__/ghostnet.cpython-38.pyc | Bin 6895 -> 6895 bytes
model/__pycache__/mnasnet.cpython-38.pyc | Bin 10030 -> 10030 bytes
model/__pycache__/mobilenetv2.cpython-38.pyc | Bin 6937 -> 6937 bytes
model/__pycache__/mobilenetv3.cpython-38.pyc | Bin 10162 -> 10162 bytes
model/__pycache__/repvgg.cpython-38.pyc | Bin 13001 -> 13001 bytes
model/__pycache__/resnest.cpython-38.pyc | Bin 13035 -> 13035 bytes
model/__pycache__/resnet.cpython-38.pyc | Bin 13285 -> 13285 bytes
model/__pycache__/sequencer.cpython-38.pyc | Bin 15312 -> 15312 bytes
model/__pycache__/shufflenetv2.cpython-38.pyc | Bin 8315 -> 8315 bytes
model/__pycache__/vgg.cpython-38.pyc | Bin 9003 -> 9003 bytes
model/__pycache__/vovnet.cpython-38.pyc | Bin 6791 -> 6791 bytes
model/cspnet.py | 2 +-
predict.py | 10 +-
requirements.txt | 16 +-
utils/__pycache__/utils.cpython-38.pyc | Bin 28307 -> 32957 bytes
utils/__pycache__/utils_aug.cpython-38.pyc | Bin 6328 -> 6328 bytes
.../__pycache__/utils_distill.cpython-38.pyc | Bin 5069 -> 5069 bytes
utils/__pycache__/utils_fit.cpython-38.pyc | Bin 3414 -> 3398 bytes
utils/__pycache__/utils_loss.cpython-38.pyc | Bin 3812 -> 3812 bytes
utils/__pycache__/utils_model.cpython-38.pyc | Bin 2988 -> 2988 bytes
utils/utils.py | 137 +++++++++++++++++-
v1.2-update_log.md | 82 +++++++++++
41 files changed, 454 insertions(+), 101 deletions(-)
delete mode 100644 .idea/.gitignore
delete mode 100644 .idea/encodings.xml
delete mode 100644 .idea/inspectionProfiles/Project_Default.xml
delete mode 100644 .idea/inspectionProfiles/profiles_settings.xml
delete mode 100644 .idea/misc.xml
delete mode 100644 .idea/modules.xml
delete mode 100644 .idea/pytorch-classifier.iml
delete mode 100644 .idea/vcs.xml
create mode 100644 export.py
create mode 100644 v1.2-update_log.md
diff --git a/.idea/.gitignore b/.idea/.gitignore
deleted file mode 100644
index 13566b8..0000000
--- a/.idea/.gitignore
+++ /dev/null
@@ -1,8 +0,0 @@
-# Default ignored files
-/shelf/
-/workspace.xml
-# Editor-based HTTP Client requests
-/httpRequests/
-# Datasource local storage ignored files
-/dataSources/
-/dataSources.local.xml
diff --git a/.idea/encodings.xml b/.idea/encodings.xml
deleted file mode 100644
index fa9b8a7..0000000
--- a/.idea/encodings.xml
+++ /dev/null
@@ -1,9 +0,0 @@
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
deleted file mode 100644
index 03d9549..0000000
--- a/.idea/inspectionProfiles/Project_Default.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
deleted file mode 100644
index 105ce2d..0000000
--- a/.idea/inspectionProfiles/profiles_settings.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
deleted file mode 100644
index 4b2f238..0000000
--- a/.idea/misc.xml
+++ /dev/null
@@ -1,4 +0,0 @@
-
-
-
-
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
deleted file mode 100644
index 9c63b12..0000000
--- a/.idea/modules.xml
+++ /dev/null
@@ -1,8 +0,0 @@
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/pytorch-classifier.iml b/.idea/pytorch-classifier.iml
deleted file mode 100644
index ddd9297..0000000
--- a/.idea/pytorch-classifier.iml
+++ /dev/null
@@ -1,12 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
deleted file mode 100644
index 94a25f7..0000000
--- a/.idea/vcs.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/README.md b/README.md
index 87a44e2..537a2e1 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,16 @@
image classifier implemented in pytorch.
+# Directory
+1. **[Introduction](#Introduction)**
+2. **[How to use](#Howtouse)**
+3. **[Argument Explanation](#ArgumentExplanation)**
+4. **[Model Zoo](#ModelZoo)**
+5. **[Some explanation](#Someexplanation)**
+6. **[TODO](#TODO)**
+7. **[Reference](#Reference)**
+
+
## Introduction
Why should you use this code?
@@ -15,7 +25,7 @@ image classifier implemented in pytorch.
7. Overall accuracy visualization. (kappa, precision, recall, f1, accuracy, mpa)
- **Rich model zoo**
- 1. A rich model zoo integrated by the author; essentially all mainstream models are supported, more than 50 in total, and all of them support ImageNet pretrained weights. [See Model Zoo for details. (Transformer series to be updated later)](#3)
+ 1. A rich model zoo integrated by the author; essentially all mainstream models are supported, more than 50 in total, and all of them support ImageNet pretrained weights. [See Model Zoo for details. (Transformer series to be updated later)](#ModelZoo)
2. The currently supported models were integrated by the author from github and torchvision, so they can be modified and improved for experiments, rather than being created by direct library calls.
- **Rich training strategies**
@@ -33,6 +43,9 @@ image classifier implemented in pytorch.
- **Rich learning-rate scheduling strategies**
This program supports learning-rate warmup, and custom learning-rate strategies after warmup. [See point 5 in Some explanation for details](#1)
+- **Export to common inference frameworks**
+ Currently supports exporting torchscript, onnx, and tensorrt inference models.
+
- **Simple installation**
@@ -44,11 +57,15 @@ image classifier implemented in pytorch.
1. Most visualization data (confusion matrix, tsne, per-class metrics) are saved locally in csv or log format, making it easy to polish the figures later.
2. Most program output is formatted with PrettyTable, which greatly improves readability.
+
+
## How to use
1. Install the [environment](#6) required by the program.
2. Prepare the dataset according to [point 3 in Some explanation](#5).
+
+
## Argument Explanation
- **main.py**
@@ -66,6 +83,9 @@ image classifier implemented in pytorch.
- **config**
type: string, default: config/config.py
Path to the config file.
+ - **device**
+ type: string, default: ''
+ Device to use. (cuda device, i.e. 0 or 0,1,2,3 or cpu)
- **train_path**
type: string, default: dataset/train
Path to the training set.
@@ -165,6 +185,9 @@ image classifier implemented in pytorch.
- **rdrop**
default: False
Whether to use R-Drop. (not supported with knowledge distillation)
+ - **ema**
+ default: False
+ Whether to use EMA. (not supported with knowledge distillation)
- **metrice.py**
The main program for computing metrics.
Argument explanation:
@@ -179,7 +202,10 @@ image classifier implemented in pytorch.
Path to the test set.
- **label_path**
type: string, default: dataset/label.txt
- Path to the label file.
+ Path to the label file.
+ - **device**
+ type: string, default: ''
+ Device to use. (cuda device, i.e. 0 or 0,1,2,3 or cpu)
- **task**
type: string, default: test, choices: ['train', 'val', 'test', 'fps']
Task type. Choosing fps computes the fps metric alone; choosing train, val, or test computes the metrics on that split.
@@ -222,7 +248,9 @@ image classifier implemented in pytorch.
- **cam_type**
type: string, default: GradCAMPlusPlus, choices: ['GradCAM', 'HiResCAM', 'ScoreCAM', 'GradCAMPlusPlus', 'AblationCAM', 'XGradCAM', 'EigenCAM', 'FullGrad']
Type of heat-map (CAM) visualization.
-
+ - **device**
+ type: string, default: ''
+ Device to use. (cuda device, i.e. 0 or 0,1,2,3 or cpu)
- **processing.py**
The main program for preprocessing the dataset.
Argument explanation:
@@ -238,7 +266,6 @@ image classifier implemented in pytorch.
- **test_size**
type: float, default: 0.2
Proportion of the test set.
-
- **config/config.py**
A configuration file for some additional parameters.
Argument explanation:
@@ -246,26 +273,55 @@ image classifier implemented in pytorch.
default: None
Example: lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR
Custom learning-rate scheduler.
-
- **lr_scheduler_params**
default: {'T_max': 10,'eta_min': 1e-6}
Example: lr_scheduler_params = {'step_size': 1,'gamma': 0.95} (this assumes lr_scheduler = torch.optim.lr_scheduler.StepLR)
Parameters for the custom learning-rate scheduler; they must match lr_scheduler.
-
- **random_seed**
default: 0
Random seed value.
-
- **plot_train_batch_count**
default: 5
Number of training-batch visualization images to generate.
-
- **custom_augment**
default: transforms.Compose([])
Example: transforms.Compose([transforms.RandomHorizontalFlip(p=0.5),transforms.RandomRotation(degrees=20),])
Custom data augmentation.
+- **export.py**
+ The script for exporting models. Currently supports torchscript, onnx, and tensorrt.
+ Argument explanation:
+ - **save_path**
+ type: string, default: runs/exp
+ Path where the model was saved; the exported result is written there too.
+ - **image_size**
+ type: int, default: 224
+ Input image size for the model.
+ - **image_channel**
+ type: int, default: 3
+ Number of image channels fed to the model. (currently only three channels are supported)
+ - **batch_size**
+ type: int, default: 1
+ Number of samples per test batch.
+ - **dynamic**
+ default: False
+ The dynamic parameter for onnx export.
+ - **simplify**
+ default: False
+ The simplify parameter for onnx export.
+ - **half**
+ default: False
+ Export the model in FP16. (only supported when exporting on GPU)
+ - **verbose**
+ default: False
+ Whether to print the log when exporting tensorrt.
+ - **export**
+ type: string, default: torchscript, choices: ['onnx', 'torchscript', 'tensorrt']
+ Which model format to export.
+ - **device**
+ type: string, default: 0
+ Device to use. (cuda device, i.e. 0 or 0,1,2,3 or cpu)
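+
+ Example (a minimal sketch; it assumes a trained checkpoint best.pt already exists under runs/exp):
+ python export.py --save_path runs/exp --export onnx --simplify --device cpu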
-
+
## Model Zoo
@@ -288,6 +344,8 @@ image classifier implemented in pytorch.
| cspnet | cspresnet50,cspresnext50,cspdarknet53,cs3darknet_m,cs3darknet_l,cs3darknet_x,cs3darknet_focus_m,cs3darknet_focus_l,cs3sedarknet_l,cs3sedarknet_x,cs3edgenet_x,cs3se_edgenet_x |
| dpn | dpn68,dpn68b,dpn92,dpn98,dpn107,dpn131 |
+
+
## Some explanation
1. About the cpu and gpu question.
@@ -526,13 +584,20 @@ image classifier implemented in pytorch.
- 17. About how to use the albumentations data augmentations.
-
+ 17. About how to use the albumentations data augmentations.
We can find the name of the augmentation we need on the [albumentations github](https://github.com/albumentations-team/albumentations) or the [albumentations official website](https://albumentations.ai/docs/api_reference/augmentations/), for example the [RandomGridShuffle](https://github.com/albumentations-team/albumentations#:~:text=%E2%9C%93-,RandomGridShuffle,-%E2%9C%93) method, and then create it in config/config.py:
Create_Albumentations_From_Name('RandomGridShuffle')
Some users may also need to modify its default parameters, which can be found in its api documentation; our function supports overriding parameters as well. For example, RandomGridShuffle has a grid parameter:
Create_Albumentations_From_Name('RandomGridShuffle', grid=(3, 3))
If there is more than one parameter, just append them in the same way, but the parameter names must be specified.
+
+ 18. Some notes on the export file.
+ 1. tensorrt is recommended on ubuntu, and tensorrt only supports export and inference on gpu.
+ 2. FP16 only supports export and inference on gpu.
+ 3. FP16 mode cannot be used together with dynamic mode.
+ 4. For detailed GPU and CPU inference-speed experiments, see the [v1.2 update log](v1.2-update_log.md).
+
+
## TODO
- [x] Knowledge Distillation
@@ -540,13 +605,16 @@ image classifier implemented in pytorch.
- [x] R-Drop
- [ ] SWA
- [ ] DDP Mode
-- [ ] Export Model(onnx, tensorrt, torchscript)
+- [x] Export Model(onnx, torchscript, TensorRT)
- [ ] C++ Inference Code
- [ ] Accumulation Gradient
- [ ] Model Ensembling
- [ ] Freeze Training
+- [ ] Support Fuse Conv and Bn
- [x] Early Stop
+
+
## Reference
https://github.com/BIGBALLON/CIFAR-ZOO
diff --git a/config/__pycache__/config.cpython-38.pyc b/config/__pycache__/config.cpython-38.pyc
index 0f17312ddf438585b273cd79e0c4722352ee45bd..bcc2bcfe4a5c61e94d0f28343dbb6dbd84b2fa84 100644
Binary files a/config/__pycache__/config.cpython-38.pyc and b/config/__pycache__/config.cpython-38.pyc differ
diff --git a/export.py b/export.py
new file mode 100644
index 0000000..38339be
--- /dev/null
+++ b/export.py
@@ -0,0 +1,119 @@
+import os, argparse
+import numpy as np
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+import torch
+import torch.nn as nn
+from utils.utils import select_device
+
+def export_torchscript(opt, model, img, prefix='TorchScript'):
+ print('Starting TorchScript export with pytorch %s...' % torch.__version__)
+ f = os.path.join(opt.save_path, 'best.ts')
+ ts = torch.jit.trace(model, img, strict=False)
+ ts.save(f)
+ print(f'Export TorchScript Model Successfully.\nSave as {f}')
+
+def export_onnx(opt, model, img, prefix='ONNX'):
+ import onnx
+ f = os.path.join(opt.save_path, 'best.onnx')
+ print('Starting ONNX export with onnx %s...' % onnx.__version__)
+ if opt.dynamic:
+ dynamic_axes = {'images': {0: 'batch', 2: 'height', 3: 'width'}, 'output':{0: 'batch'}}
+ else:
+ dynamic_axes = None
+
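+ # for dynamic export the model and dummy input are moved to CPU first (see the conditional .to('cpu') below)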
+ torch.onnx.export(
+ (model.to('cpu') if opt.dynamic else model),
+ (img.to('cpu') if opt.dynamic else img),
+ f, verbose=False, opset_version=13, input_names=['images'], output_names=['output'], dynamic_axes=dynamic_axes)
+
+ onnx_model = onnx.load(f) # load onnx model
+ onnx.checker.check_model(onnx_model) # check onnx model
+
+ if opt.simplify:
+ try:
+ import onnxsim
+ print('\nStarting to simplify ONNX...')
+ onnx_model, check = onnxsim.simplify(onnx_model)
+ assert check, 'assert check failed'
+ except Exception as e:
+ print(f'Simplifier failure: {e}')
+ onnx.save(onnx_model, f)
+
+ print(f'Export Onnx Model Successfully.\nSave as {f}')
+
+def export_engine(opt, model, img, workspace=4, prefix='TensorRT'):
+ export_onnx(opt, model, img)
+ onnx_file = os.path.join(opt.save_path, 'best.onnx')
+ assert img.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
+ import tensorrt as trt
+ print('Starting TensorRT export with TensorRT %s...' % trt.__version__)
+ f = os.path.join(opt.save_path, 'best.engine')
+
+ TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) if opt.verbose else trt.Logger()
+ builder = trt.Builder(TRT_LOGGER)
+ config = builder.create_builder_config()
+ config.max_workspace_size = workspace * 1 << 30
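+ # (workspace * 1) << 30 bytes, i.e. workspace GiB of scratch memory for TensorRT tactic selection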
+
+ flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
+ network = builder.create_network(flag)
+ parser = trt.OnnxParser(network, TRT_LOGGER)
+ if not parser.parse_from_file(str(onnx_file)):
+ raise RuntimeError(f'failed to load ONNX file: {onnx_file}')
+
+ inputs = [network.get_input(i) for i in range(network.num_inputs)]
+ outputs = [network.get_output(i) for i in range(network.num_outputs)]
+ for inp in inputs:
+ print(f'input {inp.name} with shape {inp.shape} and dtype {inp.dtype}')
+ for out in outputs:
+ print(f'output {out.name} with shape {out.shape} and dtype {out.dtype}')
+
+ if opt.dynamic:
+ if img.shape[0] <= 1:
+ print(f"{prefix} WARNING: --dynamic model requires maximum --batch-size argument")
+ profile = builder.create_optimization_profile()
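+ # dynamic shape profile per input: min batch 1, opt batch max(1, batch // 2), max the requested --batch_size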
+ for inp in inputs:
+ profile.set_shape(inp.name, (1, *img.shape[1:]), (max(1, img.shape[0] // 2), *img.shape[1:]), img.shape)
+ config.add_optimization_profile(profile)
+
+ print(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and opt.half else 32} engine in {f}')
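+ # enable FP16 kernels only when --half is requested and the GPU reports fast FP16 support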
+ if builder.platform_has_fast_fp16 and opt.half:
+ config.set_flag(trt.BuilderFlag.FP16)
+ with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
+ t.write(engine.serialize())
+ print(f'Export TensorRT Model Successfully.\nSave as {f}')
+
+def parse_opt():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--save_path', type=str, default=r'runs/exp', help='save path for model and log')
+ parser.add_argument('--image_size', type=int, default=224, help='image size')
+ parser.add_argument('--image_channel', type=int, default=3, help='image channel')
+ parser.add_argument('--batch_size', type=int, default=1, help='batch size')
+ parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX batchsize')
+ parser.add_argument('--simplify', action='store_true', help='simplify onnx model')
+ parser.add_argument('--half', action="store_true", help='FP32 to FP16')
+ parser.add_argument('--verbose', action="store_true", help='TensorRT:verbose export log')
+ parser.add_argument('--export', default='torchscript', type=str, choices=['onnx', 'torchscript', 'tensorrt'], help='export type')
+ parser.add_argument('--device', type=str, default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+
+ opt = parser.parse_known_args()[0]
+ if not os.path.exists(os.path.join(opt.save_path, 'best.pt')):
+ raise Exception('best.pt not found. please check your --save_path folder')
+ DEVICE = select_device(opt.device)
+ if opt.half:
+ assert DEVICE.type != 'cpu', '--half only supported with GPU export'
+ assert not opt.dynamic, '--half not compatible with --dynamic'
+ ckpt = torch.load(os.path.join(opt.save_path, 'best.pt'))
+ model = ckpt['model'].float().to(DEVICE)
+ img = torch.rand((opt.batch_size, opt.image_channel, opt.image_size, opt.image_size)).to(DEVICE)
+
+ return opt, (model.half() if opt.half else model), (img.half() if opt.half else img), DEVICE
+
+if __name__ == '__main__':
+ opt, model, img, DEVICE = parse_opt()
+
+ if opt.export == 'onnx':
+ export_onnx(opt, model, img)
+ elif opt.export == 'torchscript':
+ export_torchscript(opt, model, img)
+ elif opt.export == 'tensorrt':
+ export_engine(opt, model, img)
\ No newline at end of file
diff --git a/main.py b/main.py
index d8bf934..27a01b0 100644
--- a/main.py
+++ b/main.py
@@ -12,7 +12,7 @@
from utils.utils_model import select_model
from utils import utils_aug
from utils.utils import save_model, plot_train_batch, WarmUpLR, show_config, setting_optimizer, check_batch_size, \
- plot_log, update_opt, load_weights, get_channels, dict_to_PrettyTable, ModelEMA
+ plot_log, update_opt, load_weights, get_channels, dict_to_PrettyTable, ModelEMA, select_device
from utils.utils_distill import *
from utils.utils_loss import *
@@ -29,6 +29,7 @@ def parse_opt():
parser.add_argument('--pretrained', action="store_true", help='using pretrain weight')
parser.add_argument('--weight', type=str, default='', help='loading weight path')
parser.add_argument('--config', type=str, default='config/config.py', help='config path')
+ parser.add_argument('--device', type=str, default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--train_path', type=str, default=r'dataset/train', help='train data path')
parser.add_argument('--val_path', type=str, default=r'dataset/val', help='val data path')
@@ -76,7 +77,7 @@ def parse_opt():
# Tricks parameters
parser.add_argument('--rdrop', action="store_true", help='using R-Drop')
- parser.add_argument('--ema', action="store_true", help='using EMA(Exponential Moving Average)')
+ parser.add_argument('--ema', action="store_true", help='using EMA (Exponential Moving Average), reference: YOLOv5')
opt = parser.parse_known_args()[0]
if opt.resume:
@@ -100,7 +101,7 @@ def parse_opt():
show_config(deepcopy(opt))
CLASS_NUM = len(os.listdir(opt.train_path))
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ DEVICE = select_device(opt.device, opt.batch_size)
train_transform, test_transform = utils_aug.get_dataprocessing(torchvision.datasets.ImageFolder(opt.train_path),
opt)
@@ -126,9 +127,7 @@ def parse_opt():
test_dataset = torch.utils.data.DataLoader(test_dataset, max(batch_size // (10 if opt.test_tta else 1), 1),
shuffle=False, num_workers=(0 if opt.test_tta else opt.workers))
scaler = torch.cuda.amp.GradScaler(enabled=(opt.amp if torch.cuda.is_available() else False))
- ema = None
- if opt.ema:
- ema = ModelEMA(model)
+ ema = ModelEMA(model) if opt.ema else None
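+ # ModelEMA maintains an exponential moving average of the model weights (YOLOv5-style)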
optimizer = setting_optimizer(opt, model)
lr_scheduler = WarmUpLR(optimizer, opt)
if opt.resume:
@@ -181,7 +180,7 @@ def parse_opt():
elif opt.kd_method == 'AT':
kd_loss = AT().to(DEVICE)
- print('{} begin train on {}!'.format(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), DEVICE))
+ print('{} begin train!'.format(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
for epoch in range(begin_epoch, opt.epoch):
if epoch > (save_epoch + opt.patience) and opt.patience != 0:
print('No Improve from {} to {}, EarlyStopping.'.format(save_epoch + 1, epoch))
diff --git a/metrice.py b/metrice.py
index 5314325..457091a 100644
--- a/metrice.py
+++ b/metrice.py
@@ -6,7 +6,7 @@
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import numpy as np
from utils import utils_aug
-from utils.utils import classification_metrice, Metrice_Dataset, visual_predictions, visual_tsne, dict_to_PrettyTable
+from utils.utils import classification_metrice, Metrice_Dataset, visual_predictions, visual_tsne, dict_to_PrettyTable, Model_Inference, select_device
torch.backends.cudnn.deterministic = True
def set_seed(seed):
@@ -21,26 +21,28 @@ def parse_opt():
parser.add_argument('--val_path', type=str, default=r'dataset/val', help='val data path')
parser.add_argument('--test_path', type=str, default=r'dataset/test', help='test data path')
parser.add_argument('--label_path', type=str, default=r'dataset/label.txt', help='label path')
- parser.add_argument('--task', type=str, choices=['train', 'val', 'test', 'fps'], default='val', help='train, val, test, fps')
+ parser.add_argument('--device', type=str, default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+ parser.add_argument('--task', type=str, choices=['train', 'val', 'test', 'fps'], default='test', help='train, val, test, fps')
parser.add_argument('--workers', type=int, default=4, help='dataloader workers')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
- parser.add_argument('--save_path', type=str, default=r'runs/mobilenetv2_ST', help='save path for model and log')
+ parser.add_argument('--save_path', type=str, default=r'runs/exp', help='save path for model and log')
parser.add_argument('--test_tta', action="store_true", help='using TTA Tricks')
parser.add_argument('--visual', action="store_true", help='visual dataset identification')
parser.add_argument('--tsne', action="store_true", help='visual tsne')
parser.add_argument('--half', action="store_true", help='use FP16 half-precision inference')
+ parser.add_argument('--model_type', type=str, choices=['torch', 'torchscript', 'onnx', 'tensorrt'], default='torch', help='model type(default: torch)')
opt = parser.parse_known_args()[0]
+ DEVICE = select_device(opt.device, opt.batch_size)
+ if opt.half and DEVICE.type == 'cpu':
+ raise Exception('half inference is only supported on GPU.')
if not os.path.exists(os.path.join(opt.save_path, 'best.pt')):
raise Exception('best.pt not found. please check your --save_path folder')
ckpt = torch.load(os.path.join(opt.save_path, 'best.pt'))
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model = (ckpt['model'] if opt.half else ckpt['model'].float())
- model.to(DEVICE)
- model.eval()
train_opt = ckpt['opt']
set_seed(train_opt.random_seed)
+ model = Model_Inference(DEVICE, opt)
print('found checkpoint from {}, model type:{}\n{}'.format(opt.save_path, ckpt['model'].name, dict_to_PrettyTable(ckpt['best_metrice'], 'Best Metrice')))
@@ -48,7 +50,7 @@ def parse_opt():
if opt.task == 'fps':
inputs = torch.rand((opt.batch_size, train_opt.image_channel, train_opt.image_size, train_opt.image_size)).to(DEVICE)
- if opt.half:
+ if opt.half and torch.cuda.is_available():
inputs = inputs.half()
warm_up, test_time = 100, 300
fps_arr = []
@@ -83,7 +85,6 @@ def parse_opt():
if __name__ == '__main__':
opt, model, test_dataset, DEVICE, CLASS_NUM, label, save_path = parse_opt()
y_true, y_pred, y_score, y_feature, img_path = [], [], [], [], []
- model.eval()
with torch.no_grad():
for x, y, path in tqdm.tqdm(test_dataset, desc='Test Stage'):
x = (x.half().to(DEVICE) if opt.half else x.to(DEVICE))
@@ -100,7 +101,11 @@ def parse_opt():
if opt.tsne:
pred_feature = model.forward_features(x)
- pred = torch.softmax(pred, 1)
+ try:
+ pred = torch.softmax(pred, 1)
+ except:
+ pred = torch.softmax(torch.from_numpy(pred), 1) # non-torch backends return numpy arrays; torch.softmax is faster than a numpy implementation
+
y_true.extend(list(y.cpu().detach().numpy()))
y_pred.extend(list(pred.argmax(-1).cpu().detach().numpy()))
y_score.extend(list(pred.max(-1)[0].cpu().detach().numpy()))
diff --git a/model/__pycache__/__init__.cpython-38.pyc b/model/__pycache__/__init__.cpython-38.pyc
index abd0181198fb108378c7b4abfaeb81fc78222801..242e622859af91c094728471fe3457ce8a7fe50b 100644
Binary files a/model/__pycache__/__init__.cpython-38.pyc and b/model/__pycache__/__init__.cpython-38.pyc differ
diff --git a/model/__pycache__/convnext.cpython-38.pyc b/model/__pycache__/convnext.cpython-38.pyc
Binary files a/model/__pycache__/convnext.cpython-38.pyc and b/model/__pycache__/convnext.cpython-38.pyc differ
diff --git a/model/__pycache__/cspnet.cpython-38.pyc b/model/__pycache__/cspnet.cpython-38.pyc
Binary files a/model/__pycache__/cspnet.cpython-38.pyc and b/model/__pycache__/cspnet.cpython-38.pyc differ
diff --git a/model/__pycache__/densenet.cpython-38.pyc b/model/__pycache__/densenet.cpython-38.pyc
index 908a023befec1e980525a44e18d0a49d63ebd7da..ee1bd52b7656883b9824931f3b13bc395a972a2d 100644
Binary files a/model/__pycache__/densenet.cpython-38.pyc and b/model/__pycache__/densenet.cpython-38.pyc differ
diff --git a/model/__pycache__/dpn.cpython-38.pyc b/model/__pycache__/dpn.cpython-38.pyc
index d8ad239dcde99cae19231d9f4fbfb45d1d7d0f5c..eb068ec8ff68927beb02fabc2f1806fd8b19e2fa 100644
Binary files a/model/__pycache__/dpn.cpython-38.pyc and b/model/__pycache__/dpn.cpython-38.pyc differ
diff --git a/model/__pycache__/efficientnetv2.cpython-38.pyc b/model/__pycache__/efficientnetv2.cpython-38.pyc
index 90a15f1fb6510fc0a17a0371fad94cd6d48b3fca..eabe70a95d25a85c51d65380fd791d6cb6581901 100644
Binary files a/model/__pycache__/efficientnetv2.cpython-38.pyc and b/model/__pycache__/efficientnetv2.cpython-38.pyc differ
diff --git a/model/__pycache__/ghostnet.cpython-38.pyc b/model/__pycache__/ghostnet.cpython-38.pyc
index f6a888e63f7dc0b1bdca91a754898cac2fc0a359..9d5452d4d38b28c5630a798729465f8063b8f8c1 100644
Binary files a/model/__pycache__/ghostnet.cpython-38.pyc and b/model/__pycache__/ghostnet.cpython-38.pyc differ
diff --git a/model/__pycache__/mnasnet.cpython-38.pyc b/model/__pycache__/mnasnet.cpython-38.pyc
index 62bba1b030523979d115def76d9bd28ba5a8d9bb..4919f15538008ed2436498962ad832f585f91f9f 100644
Binary files a/model/__pycache__/mnasnet.cpython-38.pyc and b/model/__pycache__/mnasnet.cpython-38.pyc differ
diff --git a/model/__pycache__/mobilenetv2.cpython-38.pyc b/model/__pycache__/mobilenetv2.cpython-38.pyc
index d01776f64dd52068cf39bb36ba216c8b12b67ff7..1ae2b42f8200be5ee3cb4f4403e37c46ef5b92ab 100644
Binary files a/model/__pycache__/mobilenetv2.cpython-38.pyc and b/model/__pycache__/mobilenetv2.cpython-38.pyc differ
diff --git a/model/__pycache__/mobilenetv3.cpython-38.pyc b/model/__pycache__/mobilenetv3.cpython-38.pyc
Binary files a/model/__pycache__/mobilenetv3.cpython-38.pyc and b/model/__pycache__/mobilenetv3.cpython-38.pyc differ
diff --git a/model/__pycache__/repvgg.cpython-38.pyc b/model/__pycache__/repvgg.cpython-38.pyc
Binary files a/model/__pycache__/repvgg.cpython-38.pyc and b/model/__pycache__/repvgg.cpython-38.pyc differ
diff --git a/model/__pycache__/resnest.cpython-38.pyc b/model/__pycache__/resnest.cpython-38.pyc
index 24bc0a189a2d3d434244ba696a6e80a74544783e..cc468eb2f64012915d41ed76bcee31d32044bf0a 100644
Binary files a/model/__pycache__/resnest.cpython-38.pyc and b/model/__pycache__/resnest.cpython-38.pyc differ
diff --git a/model/__pycache__/resnet.cpython-38.pyc b/model/__pycache__/resnet.cpython-38.pyc
index 786b87cb7ad639b03e37e76141b2032c4c9f4cc0..20bdaea5ff389d67a17b6ce1cbae4230124da7a7 100644
Binary files a/model/__pycache__/resnet.cpython-38.pyc and b/model/__pycache__/resnet.cpython-38.pyc differ
diff --git a/model/__pycache__/sequencer.cpython-38.pyc b/model/__pycache__/sequencer.cpython-38.pyc
index b2b713e9ffecf9ba15b767305252a3480ac1b94f..7f37a308bf1693d2f812c3703adb34428a354e32 100644
Binary files a/model/__pycache__/sequencer.cpython-38.pyc and b/model/__pycache__/sequencer.cpython-38.pyc differ
diff --git a/model/__pycache__/shufflenetv2.cpython-38.pyc b/model/__pycache__/shufflenetv2.cpython-38.pyc
index c9fe96680b15c31a0166bec9cae9617946903341..9033cb934bf376de6bdbe866e00942176edeed26 100644
Binary files a/model/__pycache__/shufflenetv2.cpython-38.pyc and b/model/__pycache__/shufflenetv2.cpython-38.pyc differ
diff --git a/model/__pycache__/vgg.cpython-38.pyc b/model/__pycache__/vgg.cpython-38.pyc
index a9de89af9fa6b3ae02852b6e25216dc0c5d652c1..e40147aaf73c3113443cd45232849bddf0228d33 100644
Binary files a/model/__pycache__/vgg.cpython-38.pyc and b/model/__pycache__/vgg.cpython-38.pyc differ
diff --git a/model/__pycache__/vovnet.cpython-38.pyc b/model/__pycache__/vovnet.cpython-38.pyc
Binary files a/model/__pycache__/vovnet.cpython-38.pyc and b/model/__pycache__/vovnet.cpython-38.pyc differ
diff --git a/model/cspnet.py b/model/cspnet.py
index dcd9709..084ae5f 100644
--- a/model/cspnet.py
+++ b/model/cspnet.py
@@ -847,7 +847,7 @@ def forward(self, x, need_fea=False):
if need_fea:
features, features_fc = self.forward_features(x, need_fea=need_fea)
x = self.forward_head(features_fc)
- return features, features_fc, x
+ return features, features_fc, x
else:
x = self.forward_features(x)
x = self.forward_head(x)
diff --git a/predict.py b/predict.py
index f39120a..920c34f 100644
--- a/predict.py
+++ b/predict.py
@@ -7,7 +7,7 @@
import matplotlib.pyplot as plt
import numpy as np
from utils import utils_aug
-from utils.utils import predict_single_image, cam_visual, dict_to_PrettyTable
+from utils.utils import predict_single_image, cam_visual, dict_to_PrettyTable, select_device
def set_seed(seed):
random.seed(seed)
@@ -24,13 +24,17 @@ def parse_opt():
parser.add_argument('--cam_visual', action="store_true", help='visual cam')
parser.add_argument('--cam_type', type=str, choices=['GradCAM', 'HiResCAM', 'ScoreCAM', 'GradCAMPlusPlus', 'AblationCAM', 'XGradCAM', 'EigenCAM', 'FullGrad'], default='FullGrad', help='cam type')
parser.add_argument('--half', action="store_true", help='use FP16 half-precision inference')
+ parser.add_argument('--device', type=str, default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
opt = parser.parse_known_args()[0]
-
if not os.path.exists(os.path.join(opt.save_path, 'best.pt')):
raise Exception('best.pt not found. please check your --save_path folder')
ckpt = torch.load(os.path.join(opt.save_path, 'best.pt'))
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ DEVICE = select_device(opt.device)
+ if opt.half and DEVICE.type == 'cpu':
+ raise Exception('half inference is only supported on GPU.')
+ if opt.half and opt.cam_visual:
+ raise Exception('cam visual only supports FP32.')
model = (ckpt['model'] if opt.half else ckpt['model'].float())
model.to(DEVICE)
model.eval()
diff --git a/requirements.txt b/requirements.txt
index 53a9071..4b64ad1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,7 @@
+# Pytorch-Classifier requirements
+# Usage: pip install -r requirements.txt
+
+# Base ------------------------------------------------------------------------
opencv-python
grad-cam
timm
@@ -8,4 +12,14 @@ pillow
thop
rfconv
albumentations
-pycm
\ No newline at end of file
+pycm
+
+# Export ----------------------------------------------------------------------
+# onnx # ONNX export
+# onnx-simplifier # ONNX simplifier
+# nvidia-pyindex # TensorRT export
+# nvidia-tensorrt # TensorRT export
+
+# Export Inference ----------------------------------------------------------------
+# onnxruntime # ONNX CPU Inference
+# onnxruntime-gpu # ONNX GPU Inference
\ No newline at end of file
diff --git a/utils/__pycache__/utils.cpython-38.pyc b/utils/__pycache__/utils.cpython-38.pyc
index d6793a359eafa1d3e0c3710e0f3982bf5320917a..96ced2ee72c6ec59e93b12ac7e05bcddfd1bd73b 100644
Binary files a/utils/__pycache__/utils.cpython-38.pyc and b/utils/__pycache__/utils.cpython-38.pyc differ
diff --git a/utils/__pycache__/utils_aug.cpython-38.pyc b/utils/__pycache__/utils_aug.cpython-38.pyc
Binary files a/utils/__pycache__/utils_aug.cpython-38.pyc and b/utils/__pycache__/utils_aug.cpython-38.pyc differ
diff --git a/utils/__pycache__/utils_distill.cpython-38.pyc b/utils/__pycache__/utils_distill.cpython-38.pyc
Binary files a/utils/__pycache__/utils_distill.cpython-38.pyc and b/utils/__pycache__/utils_distill.cpython-38.pyc differ
diff --git a/utils/__pycache__/utils_fit.cpython-38.pyc b/utils/__pycache__/utils_fit.cpython-38.pyc
Binary files a/utils/__pycache__/utils_fit.cpython-38.pyc and b/utils/__pycache__/utils_fit.cpython-38.pyc differ
diff --git a/utils/__pycache__/utils_loss.cpython-38.pyc b/utils/__pycache__/utils_loss.cpython-38.pyc
Binary files a/utils/__pycache__/utils_loss.cpython-38.pyc and b/utils/__pycache__/utils_loss.cpython-38.pyc differ
diff --git a/utils/__pycache__/utils_model.cpython-38.pyc b/utils/__pycache__/utils_model.cpython-38.pyc
Binary files a/utils/__pycache__/utils_model.cpython-38.pyc and b/utils/__pycache__/utils_model.cpython-38.pyc differ
diff --git a/utils/utils.py b/utils/utils.py
--- a/utils/utils.py
+++ b/utils/utils.py
+ assert inputs.shape == s, f"input size {inputs.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
+ self.binding_addrs['images'] = int(inputs.data_ptr())
+ self.context.execute_v2(list(self.binding_addrs.values()))
+ y = self.bindings['output'].data
+ return y
+
+ def forward_features(self, inputs):
+ try:
+ return self.model.forward_features(inputs)
+ except:
+ raise Exception('this model is not a torch model.')
+
+ def cam_layer(self):
+ try:
+ return self.model.cam_layer()
+ except:
+ raise Exception('this model is not a torch model.')
+
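+# resolve a --device string ('', 'cpu', '0', '0,1,2,3') to a torch.device and print environment info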
+def select_device(device='', batch_size=0):
+ device = str(device).strip().lower().replace('cuda:', '').replace('none', '') # to string, 'cuda:0' to '0'
+ cpu = device == 'cpu'
+ if cpu:
+ os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
+ elif device:
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
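+ # CUDA_VISIBLE_DEVICES must be set before torch initializes CUDA for the restriction to take effect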
+ assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
+ f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
+
+ print_str = f'Image-Classifier Python-{platform.python_version()} Torch-{torch.__version__} '
+ if not cpu and torch.cuda.is_available():
+ devices = device.split(',') if device else '0'
+ n = len(devices) # device count
+ if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
+ assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
+ space = ' ' * len(print_str)
+ for i, d in enumerate(devices):
+ p = torch.cuda.get_device_properties(i)
+ print_str += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"
+ arg = 'cuda:0'
+ else:
+ print_str += 'CPU'
+ arg = 'cpu'
+ print(print_str)
+ return torch.device(arg)
diff --git a/v1.2-update_log.md b/v1.2-update_log.md
new file mode 100644
index 0000000..2375652
--- /dev/null
+++ b/v1.2-update_log.md
@@ -0,0 +1,82 @@
+# pytorch-classifier v1.2 update log
+
+1. Added export.py, which supports exporting (onnx, torchscript, tensorrt) models.
+2. metrice.py now supports onnx, torchscript, and tensorrt inference.
+
+ predict.py does not support onnx, torchscript, or tensorrt inference for now, because the heat-map visualization in predict.py cannot be implemented on top of onnx, torchscript, or tensorrt; a separate piece of inference code will be written later.
+ In metrice.py, onnx, torchscript, and tensorrt inference likewise does not support tsne visualization; the purpose of adding onnx, torchscript, and tensorrt inference to metrice.py is to test fps and accuracy.
+ So, in short, it is best to use the torch model directly with metrice.py; a standalone inference script for torchscript, onnx, and tensorrt models will be written later.
+3. Added a --device argument to main.py, metrice.py, predict.py, and export.py to specify the device.
+4. Optimized the program and fixed some bugs.
+
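+A minimal sketch of standalone ONNX inference with onnxruntime (the input name 'images' and output name 'output' match export.py; the model path, image size, and the onnxruntime package, listed as an optional dependency in requirements.txt, are assumptions):
+
+    import numpy as np
+    import onnxruntime as ort
+    session = ort.InferenceSession('runs/exp/best.onnx', providers=['CPUExecutionProvider'])
+    x = np.random.rand(1, 3, 224, 224).astype(np.float32)  # dummy NCHW batch
+    pred = session.run(['output'], {'images': x})[0]       # class logits, shape (batch, num_classes)
+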
+---
+#### Training command:
+ python main.py --model_name efficientnet_v2_s --config config/config.py --batch_size 128 --Augment AutoAugment --save_path runs/efficientnet_v2_s --device 0 \
+ --pretrained --amp --warmup --ema --imagenet_meanstd
+
+#### GPU inference speed test sh script:
+ batch_size=1 # 1 2 4 8 16 32 64
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --batch_size $batch_size
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --half --batch_size $batch_size
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type torchscript --batch_size $batch_size
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --half --model_type torchscript --batch_size $batch_size
+ python export.py --save_path runs/efficientnet_v2_s --export onnx --simplify --batch_size $batch_size
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type onnx --batch_size $batch_size
+ python export.py --save_path runs/efficientnet_v2_s --export onnx --simplify --half --batch_size $batch_size
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type onnx --batch_size $batch_size
+ python export.py --save_path runs/efficientnet_v2_s --export tensorrt --simplify --batch_size $batch_size
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type tensorrt --batch_size $batch_size
+ python export.py --save_path runs/efficientnet_v2_s --export tensorrt --simplify --half --batch_size $batch_size
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type tensorrt --half --batch_size $batch_size
+
+#### CPU inference speed test sh script:
+ python export.py --save_path runs/efficientnet_v2_s --export onnx --simplify --dynamic --device cpu
+ batch_size=1
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
+ batch_size=2
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
+ batch_size=4
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
+ batch_size=8
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
+ batch_size=16
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
+ python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
+
+### FPS experiments for each exported model on cpu and gpu:
+
+Experiment environment:
+
+| System | CPU | GPU | RAM | Model |
+| :----: | :----: | :----: | :----: | :----: |
+| Ubuntu20.04 | i7-12700KF | RTX-3090 | 32G DDR5 6400 | efficientnet_v2_s |
+
+
+#### GPU
+| batch size | Torch FP32 FPS | Torch FP16 FPS | TorchScript FP32 FPS | TorchScript FP16 FPS | ONNX FP32 FPS | ONNX FP16 FPS | TensorRT FP32 FPS | TensorRT FP16 FPS |
+| :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: |
+| batch-size 1 | 93.77 | 105.65 | 233.21 | 260.07 | 177.41 | 308.52 | 311.60 | 789.19 |
+| batch-size 2 | 94.32 | 108.35 | 208.53 | 253.83 | 166.23 | 258.98 | 275.93 | 713.71 |
+| batch-size 4 | 95.98 | 108.31 | 171.99 | 255.05 | 130.43 | 190.03 | 212.75 | 573.88 |
+| batch-size 8 | 94.03 | 85.76 | 118.79 | 210.58 | 87.65 | 122.31 | 147.36 | 416.71 |
+| batch-size 16 | 61.93 | 76.25 | 75.45 | 125.05 | 50.33 | 69.01 | 87.25 | 260.94 |
+| batch-size 32 | 34.56 | 58.11 | 41.93 | 72.29 | 26.91 | 34.46 | 48.54 | 151.35 |
+| batch-size 64 | 18.64 | 31.57 | 23.15 | 38.90 | 12.67 | 15.90 | 26.19 | 85.47 |
+
+#### CPU
+| batch size | Torch FP32 FPS | Torch FP16 FPS | TorchScript FP32 FPS | TorchScript FP16 FPS | ONNX FP32 FPS | ONNX FP16 FPS | TensorRT FP32 FPS | TensorRT FP16 FPS |
+| :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: |
+| batch-size 1 | 27.91 | Not Support | 46.10 | Not Support | 79.27 | Not Support | Not Support | Not Support |
+| batch-size 2 | 25.26 | Not Support | 24.98 | Not Support | 45.62 | Not Support | Not Support | Not Support |
+| batch-size 4 | 14.02 | Not Support | 13.84 | Not Support | 23.90 | Not Support | Not Support | Not Support |
+| batch-size 8 | 7.53 | Not Support | 7.35 | Not Support | 12.01 | Not Support | Not Support | Not Support |
+| batch-size 16 | 3.07 | Not Support | 3.64 | Not Support | 5.72 | Not Support | Not Support | Not Support |
\ No newline at end of file