diff --git a/demo.py b/demo.py index d7ccf87..696ca58 100644 --- a/demo.py +++ b/demo.py @@ -2,7 +2,6 @@ import tools.find_mxnet import mxnet as mx import os -import importlib import sys from detect.detector import Detector from symbol.symbol_factory import get_symbol @@ -52,7 +51,7 @@ def parse_args(): parser.add_argument('--epoch', dest='epoch', help='epoch of trained model', default=0, type=int) parser.add_argument('--prefix', dest='prefix', help='trained model prefix', - default=os.path.join(os.getcwd(), 'model', 'ssd_vgg16_reduced_300'), + default=os.path.join(os.getcwd(), 'model', 'ssd_'), type=str) parser.add_argument('--cpu', dest='cpu', help='(override GPU) use CPU to detect', action='store_true', default=False) @@ -112,7 +111,8 @@ def parse_class_names(class_names): network = None if args.deploy_net else args.network class_names = parse_class_names(args.class_names) - detector = get_detector(network, args.prefix, args.epoch, + prefix = args.prefix + args.network + '_' + str(args.data_shape) + detector = get_detector(network, prefix, args.epoch, args.data_shape, (args.mean_r, args.mean_g, args.mean_b), ctx, len(class_names), args.nms_thresh, args.force_nms) diff --git a/deploy.py b/deploy.py index 5f6b8b0..ff909c8 100644 --- a/deploy.py +++ b/deploy.py @@ -9,12 +9,12 @@ def parse_args(): parser = argparse.ArgumentParser(description='Convert a trained model to deploy model') - parser.add_argument('--network', dest='network', type=str, default='vgg16_ssd_300', - choices=['vgg16_ssd_300', 'vgg16_ssd_512'], help='which network to use') + parser.add_argument('--network', dest='network', type=str, default='vgg16_reduced', + help='which network to use') parser.add_argument('--epoch', dest='epoch', help='epoch of trained model', default=0, type=int) parser.add_argument('--prefix', dest='prefix', help='trained model prefix', - default=os.path.join(os.getcwd(), 'model', 'ssd_300'), type=str) + default=os.path.join(os.getcwd(), 'model', 'ssd_'), type=str) parser.add_argument('--data-shape', dest='data_shape', type=int, default=300, help='data shape') parser.add_argument('--num-class', dest='num_classes', help='number of classes', @@ -33,7 +33,8 @@ def parse_args(): net = get_symbol(args.network).get_symbol(args.network, args.data_shape, num_classes=args.num_classes, nms_thresh=args.nms_thresh, force_suppress=args.force_nms, nms_topk=args.nms_topk) - _, arg_params, aux_params = mx.model.load_checkpoint(args.prefix, args.epoch) + prefix = args.prefix + args.network + '_' + str(args.data_shape) + _, arg_params, aux_params = mx.model.load_checkpoint(prefix, args.epoch) # new name tmp = args.prefix.rsplit('/', 1) save_prefix = '/deploy_'.join(tmp) diff --git a/evaluate.py b/evaluate.py index 3c9ff44..dda3be4 100644 --- a/evaluate.py +++ b/evaluate.py @@ -11,8 +11,8 @@ def parse_args(): default=os.path.join(os.getcwd(), 'data', 'val.rec'), type=str) parser.add_argument('--list-path', dest='list_path', help='which list file to use', default="", type=str) - parser.add_argument('--network', dest='network', type=str, default='vgg16_ssd_300', - choices=['vgg16_ssd_300', 'vgg16_ssd_512'], help='which network to use') + parser.add_argument('--network', dest='network', type=str, default='vgg16_reduced', + help='which network to use') parser.add_argument('--batch-size', dest='batch_size', type=int, default=32, help='evaluation batch size') parser.add_argument('--num-class', dest='num_class', type=int, default=20, @@ -25,7 +25,7 @@ def parse_args(): parser.add_argument('--epoch', dest='epoch', help='epoch of pretrained model', default=0, type=int) parser.add_argument('--prefix', dest='prefix', help='load model prefix', - default=os.path.join(os.getcwd(), 'model', 'ssd'), type=str) + default=os.path.join(os.getcwd(), 'model', 'ssd_'), type=str) parser.add_argument('--gpus', dest='gpu_id', help='GPU devices to evaluate with', default='0', type=str) parser.add_argument('--cpu', dest='cpu', help='use cpu to evaluate, this can be slow', @@ -78,7 +78,7 @@ def parse_args(): network = None if args.deploy_net else args.network evaluate_net(network, args.rec_path, num_class, (args.mean_r, args.mean_g, args.mean_b), args.data_shape, - args.prefix, args.epoch, ctx, batch_size=args.batch_size, + args.prefix + args.network, args.epoch, ctx, batch_size=args.batch_size, path_imglist=args.list_path, nms_thresh=args.nms_thresh, force_nms=args.force_nms, ovp_thresh=args.overlap_thresh, use_difficult=args.use_difficult, class_names=class_names,