Skip to content

Commit

Permalink
update access
Browse files Browse the repository at this point in the history
  • Loading branch information
zhreshold committed Jun 26, 2017
1 parent b550a8f commit d8a359b
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 37 deletions.
57 changes: 38 additions & 19 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,10 @@
import importlib
import sys
from detect.detector import Detector
from symbol.symbol_factory import get_symbol

CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor')

def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx,
nms_thresh=0.5, force_nms=True):
def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx, num_class,
nms_thresh=0.5, force_nms=True, nms_topk=400):
"""
wrapper for initialize a detector
Expand All @@ -31,31 +26,34 @@ def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx,
mean pixel values (R, G, B)
ctx : mx.ctx
running context, mx.cpu() or mx.gpu(?)
num_class : int
number of classes
nms_thresh : float
non-maximum suppression threshold
force_nms : bool
force suppress different categories
"""
sys.path.append(os.path.join(os.getcwd(), 'symbol'))
if net is not None:
net = importlib.import_module("symbol_" + net) \
.get_symbol(len(CLASSES), nms_thresh, force_nms)
detector = Detector(net, prefix + "_" + str(data_shape), epoch, \
data_shape, mean_pixels, ctx=ctx)
net = get_symbol(net, data_shape, num_classes=num_class, nms_thresh=nms_thresh,
force_nms=force_nms, nms_topk=nms_topk)
detector = Detector(net, prefix, epoch, data_shape, mean_pixels, ctx=ctx)
return detector

def parse_args():
parser = argparse.ArgumentParser(description='Single-shot detection network demo')
parser.add_argument('--network', dest='network', type=str, default='vgg16_ssd_300',
choices=['vgg16_ssd_300', 'vgg16_ssd_512'], help='which network to use')
parser.add_argument('--network', dest='network', type=str, default='vgg16_reduced',
help='which network to use')
parser.add_argument('--images', dest='images', type=str, default='./data/demo/dog.jpg',
help='run demo with images, use comma(without space) to seperate multiple images')
help='run demo with images, use comma to seperate multiple images')
parser.add_argument('--dir', dest='dir', nargs='?',
help='demo image directory, optional', type=str)
parser.add_argument('--ext', dest='extension', help='image extension, optional',
type=str, nargs='?')
parser.add_argument('--epoch', dest='epoch', help='epoch of trained model',
default=0, type=int)
parser.add_argument('--prefix', dest='prefix', help='trained model prefix',
default=os.path.join(os.getcwd(), 'model', 'ssd'), type=str)
default=os.path.join(os.getcwd(), 'model', 'ssd_vgg16_reduced_300'),
type=str)
parser.add_argument('--cpu', dest='cpu', help='(override GPU) use CPU to detect',
action='store_true', default=False)
parser.add_argument('--gpu', dest='gpu_id', type=int, default=0,
Expand All @@ -78,9 +76,29 @@ def parse_args():
help='show detection time')
parser.add_argument('--deploy', dest='deploy_net', action='store_true', default=False,
help='Load network from json file, rather than from symbol')
parser.add_argument('--class-names', dest='class_names', type=str,
default='aeroplane, bicycle, bird, boat, bottle, bus, \
car, cat, chair, cow, diningtable, dog, horse, motorbike, \
person, pottedplant, sheep, sofa, train, tvmonitor',
help='string of comma separated names, or text filename')
args = parser.parse_args()
return args

def parse_class_names(class_names):
""" parse # classes and class_names if applicable """
if len(class_names) > 0:
if os.path.isfile(class_names):
# try to open it to read class names
with open(class_names, 'r') as f:
class_names = [l.strip() for l in f.readlines()]
else:
class_names = [c.strip() for c in class_names.split(',')]
for name in class_names:
assert len(name) > 0
else:
raise RuntimeError("No valid class_name provided...")
return class_names

if __name__ == '__main__':
args = parse_args()
if args.cpu:
Expand All @@ -93,10 +111,11 @@ def parse_args():
assert len(image_list) > 0, "No valid image specified to detect"

network = None if args.deploy_net else args.network
class_names = parse_class_names(args.class_names)
detector = get_detector(network, args.prefix, args.epoch,
args.data_shape,
(args.mean_r, args.mean_g, args.mean_b),
ctx, args.nms_thresh, args.force_nms)
ctx, len(class_names), args.nms_thresh, args.force_nms)
# run detection
detector.detect_and_visualize(image_list, args.dir, args.extension,
CLASSES, args.thresh, args.show_timer)
class_names, args.thresh, args.show_timer)
11 changes: 8 additions & 3 deletions deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import importlib
import sys
from symbol.symbol_factory import get_symbol

def parse_args():
parser = argparse.ArgumentParser(description='Convert a trained model to deploy model')
Expand All @@ -14,20 +15,24 @@ def parse_args():
default=0, type=int)
parser.add_argument('--prefix', dest='prefix', help='trained model prefix',
default=os.path.join(os.getcwd(), 'model', 'ssd_300'), type=str)
parser.add_argument('--data-shape', dest='data_shape', type=int, default=300,
help='data shape')
parser.add_argument('--num-class', dest='num_classes', help='number of classes',
default=20, type=int)
parser.add_argument('--nms', dest='nms_thresh', type=float, default=0.5,
help='non-maximum suppression threshold, default 0.5')
parser.add_argument('--force', dest='force_nms', type=bool, default=True,
help='force non-maximum suppression on different class')
parser.add_argument('--topk', dest='nms_topk', type=int, default=400,
help='apply nms only to top k detections based on scores.')
args = parser.parse_args()
return args

if __name__ == '__main__':
args = parse_args()
sys.path.append(os.path.join(os.getcwd(), 'symbol'))
net = importlib.import_module("symbol_" + args.network) \
.get_symbol(args.num_classes, args.nms_thresh, args.force_nms)
net = get_symbol(args.network).get_symbol(args.network, args.data_shape,
num_classes=args.num_classes, nms_thresh=args.nms_thresh,
force_suppress=args.force_nms, nms_topk=args.nms_topk)
_, arg_params, aux_params = mx.model.load_checkpoint(args.prefix, args.epoch)
# new name
tmp = args.prefix.rsplit('/', 1)
Expand Down
11 changes: 4 additions & 7 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,6 @@
import sys
from evaluate.evaluate_net import evaluate_net

CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor')

def parse_args():
parser = argparse.ArgumentParser(description='Evaluate a network')
parser.add_argument('--rec-path', dest='rec_path', help='which record file to use',
Expand All @@ -23,7 +17,10 @@ def parse_args():
help='evaluation batch size')
parser.add_argument('--num-class', dest='num_class', type=int, default=20,
help='number of classes')
parser.add_argument('--class-names', dest='class_names', type=str, default=",".join(CLASSES),
parser.add_argument('--class-names', dest='class_names', type=str,
default='aeroplane, bicycle, bird, boat, bottle, bus, \
car, cat, chair, cow, diningtable, dog, horse, motorbike, \
person, pottedplant, sheep, sofa, train, tvmonitor',
help='string of comma separated names, or text filename')
parser.add_argument('--epoch', dest='epoch', help='epoch of pretrained model',
default=0, type=int)
Expand Down
2 changes: 1 addition & 1 deletion symbol/symbol_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def get_symbol(network, data_shape, **kwargs):
kwargs : dict
see symbol_builder.get_symbol for more details
"""
if network.stargswith('legacy'):
if network.startswith('legacy'):
return symbol_builder.import_module(network).get_symbol(**kwargs)
config = get_config(network, data_shape, **kwargs).copy()
config.update(kwargs)
Expand Down
14 changes: 7 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def parse_args():
default=os.path.join(os.getcwd(), 'data', 'val.rec'), type=str)
parser.add_argument('--val-list', dest='val_list', help='validation list to use',
default="", type=str)
parser.add_argument('--network', dest='network', type=str, default='vgg16_ssd_300',
choices=['vgg16_ssd_300', 'vgg16_ssd_512'], help='which network to use')
parser.add_argument('--network', dest='network', type=str, default='vgg16_reduced',
help='which network to use')
parser.add_argument('--batch-size', dest='batch_size', type=int, default=32,
help='training batch size')
parser.add_argument('--resume', dest='resume', type=int, default=-1,
Expand All @@ -41,7 +41,7 @@ def parse_args():
help='set image shape')
parser.add_argument('--label-width', dest='label_width', type=int, default=350,
help='force padding label width to sync across train and validation')
parser.add_argument('--lr', dest='learning_rate', type=float, default=0.004,
parser.add_argument('--lr', dest='learning_rate', type=float, default=0.002,
help='learning rate')
parser.add_argument('--momentum', dest='momentum', type=float, default=0.9,
help='momentum')
Expand All @@ -53,7 +53,7 @@ def parse_args():
help='green mean value')
parser.add_argument('--mean-b', dest='mean_b', type=float, default=104,
help='blue mean value')
parser.add_argument('--lr-steps', dest='lr_refactor_step', type=str, default='150, 200',
parser.add_argument('--lr-steps', dest='lr_refactor_step', type=str, default='80, 160',
help='refactor learning rate at specified epochs')
parser.add_argument('--lr-factor', dest='lr_refactor_ratio', type=str, default=0.1,
help='ratio to refactor learning rate')
Expand Down Expand Up @@ -92,9 +92,9 @@ def parse_class_names(args):
num_class = args.num_class
if len(args.class_names) > 0:
if os.path.isfile(args.class_names):
# try to open it to read class names
with open(args.class_names, 'r') as f:
class_names = [l.strip() for l in f.readlines()]
# try to open it to read class names
with open(args.class_names, 'r') as f:
class_names = [l.strip() for l in f.readlines()]
else:
class_names = [c.strip() for c in args.class_names.split(',')]
assert len(class_names) == num_class, str(len(class_names))
Expand Down

0 comments on commit d8a359b

Please sign in to comment.