Skip to content

Commit

Permalink
split convert_test and inference_test
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire committed Mar 14, 2021
1 parent aac0505 commit 7d49d2a
Showing 1 changed file with 54 additions and 25 deletions.
79 changes: 54 additions & 25 deletions tests/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,14 @@
logger = logging.getLogger("mmdet2trt")


def model_test(test_folder,
cfg_path,
checkpoint,
save_folder,
opt_shape_param=None,
max_workspace_size=1 << 25,
device="cuda:0",
score_thr=0.3,
fp16=True,
enable_mask=False):

if not osp.exists(save_folder):
os.mkdir(save_folder)
trt_model_path = osp.join(save_folder, 'trt_model.pth')

def convert_test(cfg_path,
checkpoint,
trt_model_path,
opt_shape_param=None,
max_workspace_size=1 << 25,
device="cuda:0",
fp16=True,
enable_mask=False):
logger.info("creating {} trt model.".format(cfg_path))
trt_model = mmdet2trt(cfg_path,
checkpoint,
Expand All @@ -39,9 +32,15 @@ def model_test(test_folder,
enable_mask=enable_mask)
logger.info("finish, save trt_model in {}".format(trt_model_path))
torch.save(trt_model.state_dict(), trt_model_path)
return trt_model

trt_model = init_detector(trt_model_path)

def inference_test(trt_model,
cfg_path,
device,
test_folder,
save_folder,
score_thr=0.3):
file_list = os.listdir(test_folder)

for file_name in tqdm.tqdm(file_list):
Expand Down Expand Up @@ -74,6 +73,9 @@ def model_test(test_folder,
cv2.imwrite(osp.join(save_folder, file_name), image)


TEST_MODE_DICT = {'convert': 1, 'inference': 1 << 1, 'all': 0b11}


def main():
parser = ArgumentParser()
parser.add_argument('test_folder', help='folder contain test images')
Expand All @@ -82,6 +84,10 @@ def main():
parser.add_argument(
'save_folder',
help='tensorrt model and test images results save folder')
parser.add_argument('--trt_model_path',
default='',
help='save and inference model. '
'default [save_folder]/trt_model.pth')
parser.add_argument(
'--opt_shape_param',
default='[ [ [1,3,800,800], [1,3,800,1344], [1,3,1344,1344] ] ]',
Expand All @@ -102,17 +108,40 @@ def main():
parser.add_argument('--enable_mask',
action='store_true',
help="enable mask output")
parser.add_argument('--test-mode',
default='all',
help='what to do in the test',
choices=['convert', 'inference', 'all'])
args = parser.parse_args()

model_test(args.test_folder,
args.config,
args.checkpoint,
args.save_folder,
opt_shape_param=eval(args.opt_shape_param),
max_workspace_size=args.max_workspace_size,
device=args.device,
score_thr=args.score_thr,
fp16=args.fp16)
trt_model_path = args.trt_model_path
if len(trt_model_path) == 0:
trt_model_path = osp.join(args.save_folder, 'test_model.pth')

if not osp.exists(args.save_folder):
os.mkdir(args.save_folder)

test_mode = TEST_MODE_DICT[args.test_mode]

if test_mode & TEST_MODE_DICT['convert'] > 0:
convert_test(args.config,
args.checkpoint,
trt_model_path,
opt_shape_param=eval(args.opt_shape_param),
max_workspace_size=args.max_workspace_size,
device=args.device,
fp16=args.fp16)
trt_model = init_detector(trt_model_path)
else:
trt_model = init_detector(trt_model_path)

if test_mode & TEST_MODE_DICT['inference'] > 0:
inference_test(trt_model,
args.config,
args.device,
args.test_folder,
args.save_folder,
score_thr=args.score_thr)


if __name__ == '__main__':
Expand Down

0 comments on commit 7d49d2a

Please sign in to comment.