Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
liaogulou committed Oct 25, 2023
1 parent f571062 commit 3d37af8
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 32 deletions.
19 changes: 5 additions & 14 deletions configs/config_templates/yolox_itag.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,31 +91,22 @@
train_path = 'data/coco/train2017.manifest'
val_path = 'data/coco/val2017.manifest'

train_dataset=dict(
train_dataset = dict(
type='DetImagesMixDataset',
data_source=dict(
type='DetSourcePAI',
path=train_path,
classes=CLASSES),
data_source=dict(type='DetSourcePAI', path=train_path, classes=CLASSES),
pipeline=train_pipeline,
dynamic_scale=tuple(img_scale))

val_dataset=dict(
val_dataset = dict(
type='DetImagesMixDataset',
imgs_per_gpu=2,
data_source=dict(
type='DetSourcePAI',
path=val_path,
classes=CLASSES),
data_source=dict(type='DetSourcePAI', path=val_path, classes=CLASSES),
pipeline=test_pipeline,
dynamic_scale=None,
label_padding=False)

data = dict(
imgs_per_gpu=16,
workers_per_gpu=4,
train=train_dataset,
val=val_dataset)
imgs_per_gpu=16, workers_per_gpu=4, train=train_dataset, val=val_dataset)

# additional hooks
interval = 10
Expand Down
31 changes: 18 additions & 13 deletions easycv/apis/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,27 +354,32 @@ def _export_yolox(model, cfg, filename):
classes=cfg.CLASSES)

json.dump(config, ofile)

if export_type == 'onnx':

with io.open(filename+'.config.json' if filename.endswith('onnx') else filename + '.onnx.config.json', 'w') as ofile:

with io.open(
filename + '.config.json' if filename.endswith('onnx')
else filename + '.onnx.config.json', 'w') as ofile:
config = dict(
model=cfg.model,
export=cfg.export,
test_pipeline=cfg.test_pipeline,
classes=cfg.CLASSES)

json.dump(config, ofile)

torch.onnx.export(model, # 模型的名称
input.to(device), # 一组实例化输入
filename if filename.endswith('onnx') else filename + '.onnx', # 文件保存路径/名称
export_params=True, # 如果指定为True或默认, 参数也会被导出. 如果你要导出一个没训练过的就设为 False.
opset_version=12, # ONNX 算子集的版本,当前已更新到15
do_constant_folding=True, # 是否执行常量折叠优化
input_names = ['input'], # 输入模型的张量的名称
output_names = ['output'], # 输出模型的张量的名称
)

torch.onnx.export(
model, # 模型的名称
input.to(device), # 一组实例化输入
filename if filename.endswith('onnx') else filename +
'.onnx', # 文件保存路径/名称
export_params=
True, # 如果指定为True或默认, 参数也会被导出. 如果你要导出一个没训练过的就设为 False.
opset_version=12, # ONNX 算子集的版本,当前已更新到15
do_constant_folding=True, # 是否执行常量折叠优化
input_names=['input'], # 输入模型的张量的名称
output_names=['output'], # 输出模型的张量的名称
)

if export_type == 'jit':
with io.open(filename + '.jit', 'wb') as ofile:
Expand Down
16 changes: 12 additions & 4 deletions easycv/predictors/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from glob import glob

import numpy as np
import torch, onnxruntime
import onnxruntime
import torch

from easycv.core.visualization import imshow_bboxes
from easycv.datasets.utils import replace_ImageToTensor
Expand All @@ -22,9 +23,11 @@
except Exception:
from .interface import PredictorInterface


# 将张量转化为ndarray格式
def onnx_to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
return tensor.detach().cpu().numpy(
) if tensor.requires_grad else tensor.cpu().numpy()


class DetInputProcessor(InputProcessor):
Expand Down Expand Up @@ -392,7 +395,8 @@ def _build_model(self):
model = torch.jit.load(infile, self.device)
else:
if onnxruntime.get_device() == 'GPU':
model = onnxruntime.InferenceSession(self.model_path, providers=['CUDAExecutionProvider'])
model = onnxruntime.InferenceSession(
self.model_path, providers=['CUDAExecutionProvider'])
else:
model = onnxruntime.InferenceSession(self.model_path)
else:
Expand Down Expand Up @@ -422,7 +426,11 @@ def model_forward(self, inputs):
if self.model_type != 'onnx':
outputs = self.model(inputs['img'])
else:
outputs = self.model.run(None, {self.model.get_inputs()[0].name : onnx_to_numpy(inputs['img'])})[0]
outputs = self.model.run(
None, {
self.model.get_inputs()[0].name:
onnx_to_numpy(inputs['img'])
})[0]
outputs = torch.from_numpy(outputs)
outputs = {'results': outputs} # convert to dict format
else:
Expand Down
2 changes: 1 addition & 1 deletion requirements/runtime.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ lmdb
numba
numpy
nuscenes-devkit
onnxruntime-gpu
opencv-python
oss2
packaging
Expand All @@ -33,4 +34,3 @@ transformers
wget
xtcocotools
yacs
onnxruntime-gpu

0 comments on commit 3d37af8

Please sign in to comment.