From 7d49d2aa35b966f67a7e21b4c59dc8924c825548 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 15 Mar 2021 00:18:23 +0800 Subject: [PATCH] split convert_test and inference_test --- tests/model_test.py | 79 +++++++++++++++++++++++++++++++-------------- 1 file changed, 54 insertions(+), 25 deletions(-) diff --git a/tests/model_test.py b/tests/model_test.py index c207625..377a832 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -14,21 +14,14 @@ logger = logging.getLogger("mmdet2trt") -def model_test(test_folder, - cfg_path, - checkpoint, - save_folder, - opt_shape_param=None, - max_workspace_size=1 << 25, - device="cuda:0", - score_thr=0.3, - fp16=True, - enable_mask=False): - - if not osp.exists(save_folder): - os.mkdir(save_folder) - trt_model_path = osp.join(save_folder, 'trt_model.pth') - +def convert_test(cfg_path, + checkpoint, + trt_model_path, + opt_shape_param=None, + max_workspace_size=1 << 25, + device="cuda:0", + fp16=True, + enable_mask=False): logger.info("creating {} trt model.".format(cfg_path)) trt_model = mmdet2trt(cfg_path, checkpoint, @@ -39,9 +32,15 @@ def model_test(test_folder, enable_mask=enable_mask) logger.info("finish, save trt_model in {}".format(trt_model_path)) torch.save(trt_model.state_dict(), trt_model_path) + return trt_model - trt_model = init_detector(trt_model_path) +def inference_test(trt_model, + cfg_path, + device, + test_folder, + save_folder, + score_thr=0.3): file_list = os.listdir(test_folder) for file_name in tqdm.tqdm(file_list): @@ -74,6 +73,9 @@ def model_test(test_folder, cv2.imwrite(osp.join(save_folder, file_name), image) +TEST_MODE_DICT = {'convert': 1, 'inference': 1 << 1, 'all': 0b11} + + def main(): parser = ArgumentParser() parser.add_argument('test_folder', help='folder contain test images') @@ -82,6 +84,10 @@ def main(): parser.add_argument( 'save_folder', help='tensorrt model and test images results save folder') + parser.add_argument('--trt_model_path', + default='', + help='save and inference model. ' + 'default [save_folder]/trt_model.pth') parser.add_argument( '--opt_shape_param', default='[ [ [1,3,800,800], [1,3,800,1344], [1,3,1344,1344] ] ]', @@ -102,17 +108,40 @@ def main(): parser.add_argument('--enable_mask', action='store_true', help="enable mask output") + parser.add_argument('--test-mode', + default='all', + help='what to do in the test', + choices=['convert', 'inference', 'all']) args = parser.parse_args() - model_test(args.test_folder, - args.config, - args.checkpoint, - args.save_folder, - opt_shape_param=eval(args.opt_shape_param), - max_workspace_size=args.max_workspace_size, - device=args.device, - score_thr=args.score_thr, - fp16=args.fp16) + trt_model_path = args.trt_model_path + if len(trt_model_path) == 0: + trt_model_path = osp.join(args.save_folder, 'test_model.pth') + + if not osp.exists(args.save_folder): + os.mkdir(args.save_folder) + + test_mode = TEST_MODE_DICT[args.test_mode] + + if test_mode & TEST_MODE_DICT['convert'] > 0: + convert_test(args.config, + args.checkpoint, + trt_model_path, + opt_shape_param=eval(args.opt_shape_param), + max_workspace_size=args.max_workspace_size, + device=args.device, + fp16=args.fp16) + trt_model = init_detector(trt_model_path) + else: + trt_model = init_detector(trt_model_path) + + if test_mode & TEST_MODE_DICT['inference'] > 0: + inference_test(trt_model, + args.config, + args.device, + args.test_folder, + args.save_folder, + score_thr=args.score_thr) if __name__ == '__main__':