From 6f517aa1818c0ac37c1ddb8ed1858b6d9e49551a Mon Sep 17 00:00:00 2001 From: Joshua Zhang Date: Tue, 28 Mar 2017 17:12:58 -0500 Subject: [PATCH] clean up and ready --- README.md | 41 ++- dataset/iterator.py | 49 ++- dataset/pascal_voc.py | 7 +- dataset/yolo_format.py | 2 +- demo.py | 4 +- deploy.py | 9 +- detect/detector.py | 7 +- evaluate.py | 11 +- evaluate/eval_voc.py | 15 +- evaluate/evaluate_net.py | 43 ++- mxnet | 2 +- operator/multibox_detection-inl.h | 186 ----------- operator/multibox_detection.cc | 182 ----------- operator/multibox_detection.cu | 227 ------------- operator/multibox_prior-inl.h | 203 ------------ operator/multibox_prior.cc | 84 ----- operator/multibox_prior.cu | 98 ------ operator/multibox_target-inl.h | 261 --------------- operator/multibox_target.cc | 292 ----------------- operator/multibox_target.cu | 411 ------------------------ symbol/common.py | 2 +- symbol/symbol_vgg16_ssd_300.py | 6 +- symbol/symbol_vgg16_ssd_512.py | 8 +- tools/caffe_converter/README.md | 5 +- tools/caffe_converter/convert_model.py | 2 +- tools/caffe_converter/convert_symbol.py | 4 +- tools/visualize_net.py | 3 +- train.py | 2 +- train/train_net.py | Bin 9886 -> 9887 bytes 29 files changed, 139 insertions(+), 2027 deletions(-) delete mode 100644 operator/multibox_detection-inl.h delete mode 100644 operator/multibox_detection.cc delete mode 100644 operator/multibox_detection.cu delete mode 100644 operator/multibox_prior-inl.h delete mode 100644 operator/multibox_prior.cc delete mode 100644 operator/multibox_prior.cu delete mode 100644 operator/multibox_target-inl.h delete mode 100644 operator/multibox_target.cc delete mode 100644 operator/multibox_target.cu diff --git a/README.md b/README.md index 96198da..41cf431 100644 --- a/README.md +++ b/README.md @@ -12,9 +12,16 @@ The arXiv paper is available [here](http://arxiv.org/abs/1512.02325). This example is intended for reproducing the nice detector while fully utilize the remarkable traits of MXNet. * The model is fully compatible with caffe version. -* Model converter from caffe is available, I'll release it once I can convert any symbol other than VGG16. +* Model [converter](#convert-caffemodel) from caffe is available now! * The result is almost identical to the original version. However, due to different implementation details, the results might differ slightly. +### What's new +* Update to the latest version according to caffe version, with 5% mAP increase. +* Use C++ record iterator based on back-end multi-thread engine to achieve huge speed up on multi-gpu environments. +* Add symbol for 512x512 input. +* More network symbols under development and test. +* Extra operators are now in `mxnet/src/operator/contrib`, symbols are modified. Please use [Release-v0.2-beta](https://github.com/zhreshold/mxnet-ssd/releases/tag/v0.2-beta) for old models. + ### Demo results ![demo1](https://cloud.githubusercontent.com/assets/3307514/19171057/8e1a0cc4-8be0-11e6-9d8f-088c25353b40.png) ![demo2](https://cloud.githubusercontent.com/assets/3307514/19171063/91ec2792-8be0-11e6-983c-773bd6868fa8.png) @@ -23,7 +30,8 @@ remarkable traits of MXNet. ### mAP | Model | Training data | Test data | mAP | |:-----------------:|:----------------:|:---------:|:----:| -| VGG16_reduced 300x300 | VOC07+12 trainval| VOC07 test| 71.57| +| VGG16_reduced 300x300 | VOC07+12 trainval| VOC07 test| 77.4| +| VGG16_reduced 512x512 | VOC07+12 trainval | VOC07 test| 79.9| ### Speed | Model | GPU | CUDNN | Batch-size | FPS* | @@ -36,11 +44,11 @@ remarkable traits of MXNet. - *Forward time only, data loading and drawing excluded.* ### Getting started -* You will need python modules: `easydict`, `cv2`, `matplotlib` and `numpy`. +* You will need python modules: `cv2`, `matplotlib` and `numpy`. +If you use mxnet-python api, you probably have already got them. You can install them via pip or package manegers, such as `apt-get`: ``` sudo apt-get install python-opencv python-matplotlib python-numpy -sudo pip install easydict ``` * Clone this repo: ``` @@ -54,23 +62,24 @@ git clone --recursive https://github.com/zhreshold/mxnet-ssd.git # git submodule update --recursive --init cd mxnet-ssd/mxnet ``` -* Build MXNet: `cd $REPO_ROOT/mxnet`. Follow the official instructions [here](http://mxnet.io/get_started/setup.html). +* Build MXNet: `cd /path/to/mxnet-ssd/mxnet`. Follow the official instructions [here](http://mxnet.io/get_started/setup.html). ``` # for Ubuntu/Debian cp make/config.mk ./config.mk # modify it if necessary ``` Remember to enable CUDA if you want to be able to train, since CPU training is -insanely slow. Using CUDNN is optional, it's not fully tested but should be fine. +insanely slow. Using CUDNN is optional, but highly recommanded. ### Try the demo -* Download the pretrained model: [`ssd_300_voc_0712.zip`](https://dl.dropboxusercontent.com/u/39265872/ssd_300_voc0712.zip), and extract to `model/` directory. (This model is converted from VGG_VOC0712_SSD_300x300_iter_60000.caffemodel provided by paper author). +* Download the pretrained model: [`ssd_300_voc_0712.zip`](https://dl.dropboxusercontent.com/u/39265872/ssd_300_voc0712.zip), and extract to `model/` directory. * Run ``` # cd /path/to/mxnet-ssd python demo.py # play with examples: python demo.py --epoch 0 --images ./data/demo/dog.jpg --thresh 0.5 +# wait for library to load for the first time ``` * Check `python demo.py --help` for more options. @@ -99,18 +108,26 @@ in the same `VOCdevkit` folder. ln -s /path/to/VOCdevkit /path/to/this_example/data/VOCdevkit ``` Use hard link instead of copy could save us a bit disk space. +* Create packed binary file for faster training: +``` +# cd /path/to/mxnet-ssd +bash tools/prepare_pascal.sh +# or if you are using windows +python tools/prepare_dataset.py --dataset pascal --year 2007,2012 --set trainval --target ./data/train.lst +python $tools/prepare_dataset.py --dataset pascal --year 2007 --set test --target ./data/val.lst --shuffle False +``` * Start training: ``` python train.py ``` -* By default, this example will use `batch-size=32` and `learning_rate=0.002`. +* By default, this example will use `batch-size=32` and `learning_rate=0.004`. You might need to change the parameters a bit if you have different configurations. Check `python train.py --help` for more training options. For example, if you have 4 GPUs, use: ``` # note that a perfect training parameter set is yet to be discovered for multi-gpu -python train.py --gpus 0,1,2,3 --batch-size 128 --lr 0.0005 +python train.py --gpus 0,1,2,3 --batch-size 128 --lr 0.001 ``` -* Memory usage: MXNet is very memory efficient, training on `VGG16_reduced` model with `batch-size` 32 takes around 4684MB without CUDNN. +* Memory usage: MXNet is very memory efficient, training on `VGG16_reduced` model with `batch-size` 32 takes around 4684MB without CUDNN(conv1_x and conv2_x fixed). ### Evalute trained model Again, currently we only support evaluation on PASCAL VOC @@ -125,9 +142,11 @@ Useful when loading python symbol is not available. ``` # cd /path/to/mxnet-ssd python deploy.py --num-class 20 +# then you can run demo with new model without loading python symbol +python demo.py --prefix model/ssd_300_deploy --epoch 0 --deploy ``` -### Convert model from caffe +### Convert caffemodel Converter from caffe is available at `/path/to/mxnet-ssd/tools/caffe_converter` This is specifically modified to handle custom layer in caffe-ssd. Usage: diff --git a/dataset/iterator.py b/dataset/iterator.py index 5df15ca..5cefece 100644 --- a/dataset/iterator.py +++ b/dataset/iterator.py @@ -1,7 +1,6 @@ import mxnet as mx import numpy as np import cv2 -from tools.image_processing import resize, transform from tools.rand_sampler import RandSampler class DetRecordIter(mx.io.DataIter): @@ -39,7 +38,7 @@ class DetRecordIter(mx.io.DataIter): Returns: ---------- - + """ def __init__(self, path_imgrec, batch_size, data_shape, path_imglist="", label_width=-1, label_pad_width=-1, label_pad_value=-1, @@ -149,7 +148,7 @@ def __init__(self, imdb, batch_size, data_shape, \ if isinstance(data_shape, int): data_shape = (data_shape, data_shape) self._data_shape = data_shape - self._mean_pixels = mean_pixels + self._mean_pixels = mx.nd.array(mean_pixels).reshape((3,1,1)) if not rand_samplers: self._rand_samplers = [] else: @@ -203,7 +202,7 @@ def next(self): raise StopIteration def getindex(self): - return self._current / self.batch_size + return self._current // self.batch_size def getpad(self): pad = self._current + self.batch_size - self._size @@ -213,30 +212,28 @@ def _get_batch(self): """ Load data/label from dataset """ - batch_data = [] + batch_data = mx.nd.zeros((self.batch_size, 3, self._data_shape[0], self._data_shape[1])) batch_label = [] for i in range(self.batch_size): if (self._current + i) >= self._size: if not self.is_train: continue # use padding from middle in each epoch - idx = (self._current + i + self._size / 2) % self._size + idx = (self._current + i + self._size // 2) % self._size index = self._index[idx] else: index = self._index[self._current + i] # index = self.debug_index im_path = self._imdb.image_path_from_index(index) - img = cv2.imread(im_path) + with open(im_path, 'rb') as fp: + img_content = fp.read() + img = mx.img.imdecode(img_content) gt = self._imdb.label_from_index(index).copy() if self.is_train else None data, label = self._data_augmentation(img, gt) - batch_data.append(data) + batch_data[i] = data if self.is_train: batch_label.append(label) - # pad data if not fully occupied - for i in range(self.batch_size - len(batch_data)): - assert len(batch_data) > 0 - batch_data.append(batch_data[0] * 0) - self._data = {'data': mx.nd.array(np.array(batch_data))} + self._data = {'data': batch_data} if self.is_train: self._label = {'label': mx.nd.array(np.array(batch_label))} else: @@ -262,7 +259,7 @@ def _data_augmentation(self, data, label): xmax = int(crop[2] * width) ymax = int(crop[3] * height) if xmin >= 0 and ymin >= 0 and xmax <= width and ymax <= height: - data = data[ymin:ymax, xmin:xmax, :] + data = mx.img.fixed_crop(data, xmin, ymin, xmax-xmin, ymax-ymin) else: # padding mode new_width = xmax - xmin @@ -270,24 +267,24 @@ def _data_augmentation(self, data, label): offset_x = 0 - xmin offset_y = 0 - ymin data_bak = data - data = np.full((new_height, new_width, 3), 128.) + data = mx.nd.full((new_height, new_width, 3), 128, dtype='uint8') data[offset_y:offset_y+height, offset_x:offset_x + width, :] = data_bak label = rand_crops[index][1] - - if self.is_train and self._rand_mirror: - if np.random.uniform(0, 1) > 0.5: - data = cv2.flip(data, 1) - valid_mask = np.where(label[:, 0] > -1)[0] - tmp = 1.0 - label[valid_mask, 1] - label[valid_mask, 1] = 1.0 - label[valid_mask, 3] - label[valid_mask, 3] = tmp - if self.is_train: interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, \ cv2.INTER_NEAREST, cv2.INTER_LANCZOS4] else: interp_methods = [cv2.INTER_LINEAR] interp_method = interp_methods[int(np.random.uniform(0, 1) * len(interp_methods))] - data = resize(data, self._data_shape, interp_method) - data = transform(data, self._mean_pixels) + data = mx.img.imresize(data, self._data_shape[1], self._data_shape[0], interp_method) + if self.is_train and self._rand_mirror: + if np.random.uniform(0, 1) > 0.5: + data = mx.nd.flip(data, axis=1) + valid_mask = np.where(label[:, 0] > -1)[0] + tmp = 1.0 - label[valid_mask, 1] + label[valid_mask, 1] = 1.0 - label[valid_mask, 3] + label[valid_mask, 3] = tmp + data = mx.nd.transpose(data, (2,0,1)) + data = data.astype('float32') + data = data - self._mean_pixels return data, label diff --git a/dataset/pascal_voc.py b/dataset/pascal_voc.py index 56faa11..2c61be7 100644 --- a/dataset/pascal_voc.py +++ b/dataset/pascal_voc.py @@ -1,3 +1,4 @@ +from __future__ import print_function import os import numpy as np from imdb import Imdb @@ -128,7 +129,7 @@ def _label_path_from_index(self, index): full path of annotation file """ label_file = os.path.join(self.data_path, 'Annotations', index + '.xml') - assert os.path.exists(label_file), 'Path does not exist: {}'.format(image_file) + assert os.path.exists(label_file), 'Path does not exist: {}'.format(label_file) return label_file def _load_image_labels(self): @@ -220,7 +221,7 @@ def write_pascal_results(self, all_boxes): None """ for cls_ind, cls in enumerate(self.classes): - print 'Writing {} VOC results file'.format(cls) + print('Writing {} VOC results file'.format(cls)) filename = self.get_result_file_template().format(cls) with open(filename, 'wt') as f: for im_ind, index in enumerate(self.image_set_index): @@ -250,7 +251,7 @@ def do_python_eval(self): aps = [] # The PASCAL VOC metric changed in 2010 use_07_metric = True if int(self.year) < 2010 else False - print 'VOC07 metric? ' + ('Y' if use_07_metric else 'No') + print('VOC07 metric? ' + ('Y' if use_07_metric else 'No')) for cls_ind, cls in enumerate(self.classes): filename = self.get_result_file_template().format(cls) rec, prec, ap = voc_eval(filename, annopath, imageset_file, cls, cache_dir, diff --git a/dataset/yolo_format.py b/dataset/yolo_format.py index 2ac7fbb..e82e5ca 100644 --- a/dataset/yolo_format.py +++ b/dataset/yolo_format.py @@ -36,7 +36,7 @@ def __init__(self, name, classes, list_file, image_dir, label_dir, \ classes = [l.strip() for l in f.readlines()] num_classes = len(classes) else: - raise ValueError, "classes should be list/tuple or text file" + raise ValueError("classes should be list/tuple or text file") assert num_classes > 0, "number of classes must > 0" super(YoloFormat, self).__init__(name + '_' + str(num_classes)) self.classes = classes diff --git a/demo.py b/demo.py index 64d0a3e..ededbdb 100644 --- a/demo.py +++ b/demo.py @@ -44,8 +44,8 @@ def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx, def parse_args(): parser = argparse.ArgumentParser(description='Single-shot detection network demo') - parser.add_argument('--network', dest='network', type=str, default='ssd_300', - choices=['ssd_300'], help='which network to use') + parser.add_argument('--network', dest='network', type=str, default='vgg16_ssd_300', + choices=['vgg16_ssd_300', 'vgg16_ssd_512'], help='which network to use') parser.add_argument('--images', dest='images', type=str, default='./data/demo/dog.jpg', help='run demo with images, use comma(without space) to seperate multiple images') parser.add_argument('--dir', dest='dir', nargs='?', diff --git a/deploy.py b/deploy.py index 28f73fe..264314a 100644 --- a/deploy.py +++ b/deploy.py @@ -1,3 +1,4 @@ +from __future__ import print_function import argparse import tools.find_mxnet import mxnet as mx @@ -7,8 +8,8 @@ def parse_args(): parser = argparse.ArgumentParser(description='Convert a trained model to deploy model') - parser.add_argument('--network', dest='network', type=str, default='vgg16_reduced', - choices=['vgg16_reduced'], help='which network to use') + parser.add_argument('--network', dest='network', type=str, default='vgg16_ssd_300', + choices=['vgg16_ssd_300', 'vgg16_ssd_512'], help='which network to use') parser.add_argument('--epoch', dest='epoch', help='epoch of trained model', default=0, type=int) parser.add_argument('--prefix', dest='prefix', help='trained model prefix', @@ -32,5 +33,5 @@ def parse_args(): tmp = args.prefix.rsplit('/', 1) save_prefix = '/deploy_'.join(tmp) mx.model.save_checkpoint(save_prefix, args.epoch, net, arg_params, aux_params) - print "Saved model: {}-{:04d}.param".format(save_prefix, args.epoch) - print "Saved symbol: {}-symbol.json".format(save_prefix) + print("Saved model: {}-{:04d}.param".format(save_prefix, args.epoch)) + print("Saved symbol: {}-symbol.json".format(save_prefix)) diff --git a/detect/detector.py b/detect/detector.py index 1ea6fa7..19b78f6 100644 --- a/detect/detector.py +++ b/detect/detector.py @@ -1,3 +1,4 @@ +from __future__ import print_function import mxnet as mx import numpy as np from timeit import default_timer as timer @@ -33,7 +34,7 @@ def __init__(self, symbol, model_prefix, epoch, data_shape, mean_pixels, \ load_symbol, args, auxs = mx.model.load_checkpoint(model_prefix, epoch) if symbol is None: symbol = load_symbol - self.mod = mx.mod.Module(symbol, context=ctx) + self.mod = mx.mod.Module(symbol, label_names=None, context=ctx) self.data_shape = data_shape self.mod.bind(data_shapes=[('data', (batch_size, 3, data_shape, data_shape))]) self.mod.set_params(args, auxs) @@ -62,8 +63,8 @@ def detect(self, det_iter, show_timer=False): detections = self.mod.predict(det_iter).asnumpy() time_elapsed = timer() - start if show_timer: - print "Detection time for {} images: {:.4f} sec".format( - num_images, time_elapsed) + print("Detection time for {} images: {:.4f} sec".format( + num_images, time_elapsed)) result = [] for i in range(detections.shape[0]): det = detections[i, :, :] diff --git a/evaluate.py b/evaluate.py index 93258c3..a38a7f6 100644 --- a/evaluate.py +++ b/evaluate.py @@ -17,8 +17,8 @@ def parse_args(): default=os.path.join(os.getcwd(), 'data', 'val.rec'), type=str) parser.add_argument('--list-path', dest='list_path', help='which list file to use', default="", type=str) - parser.add_argument('--network', dest='network', type=str, default='vgg16_reduced', - choices=['vgg16_reduced', 'ssd_300'], help='which network to use') + parser.add_argument('--network', dest='network', type=str, default='vgg16_ssd_300', + choices=['vgg16_ssd_300', 'vgg16_ssd_512'], help='which network to use') parser.add_argument('--batch-size', dest='batch_size', type=int, default=32, help='evaluation batch size') parser.add_argument('--num-class', dest='num_class', type=int, default=20, @@ -31,7 +31,7 @@ def parse_args(): default=os.path.join(os.getcwd(), 'model', 'ssd'), type=str) parser.add_argument('--gpus', dest='gpu_id', help='GPU devices to evaluate with', default='0', type=str) - parser.add_argument('--cpu', dest='cpu', help='use cpu to evaluate', + parser.add_argument('--cpu', dest='cpu', help='use cpu to evaluate, this can be slow', action='store_true') parser.add_argument('--data-shape', dest='data_shape', type=int, default=300, help='set image shape') @@ -78,11 +78,6 @@ def parse_args(): else: class_names = None - # evaluate_net(args.network, args.dataset, args.devkit_path, - # (args.mean_r, args.mean_g, args.mean_b), args.data_shape, - # args.prefix, args.epoch, ctx, year=args.year, - # sets=args.eval_set, batch_size=args.batch_size, - # nms_thresh=args.nms_thresh, force_nms=args.force_nms) network = None if args.deploy_net else args.network evaluate_net(network, args.rec_path, num_class, (args.mean_r, args.mean_g, args.mean_b), args.data_shape, diff --git a/evaluate/eval_voc.py b/evaluate/eval_voc.py index 8975b61..f8f92e1 100644 --- a/evaluate/eval_voc.py +++ b/evaluate/eval_voc.py @@ -1,10 +1,13 @@ """ given a pascal voc imdb, compute mAP """ - +from __future__ import print_function import numpy as np import os -import cPickle +try: + import cPickle as pickle +except ImportError: + import pickle def parse_voc_rec(filename): @@ -88,13 +91,13 @@ def voc_eval(detpath, annopath, imageset_file, classname, cache_dir, ovthresh=0. for ind, image_filename in enumerate(image_filenames): recs[image_filename] = parse_voc_rec(annopath.format(image_filename)) if ind % 100 == 0: - print 'reading annotations for {:d}/{:d}'.format(ind + 1, len(image_filenames)) - print 'saving annotations cache to {:s}'.format(cache_file) + print('reading annotations for {:d}/{:d}'.format(ind + 1, len(image_filenames))) + print('saving annotations cache to {:s}'.format(cache_file)) with open(cache_file, 'w') as f: - cPickle.dump(recs, f) + pickle.dump(recs, f) else: with open(cache_file, 'r') as f: - recs = cPickle.load(f) + recs = pickle.load(f) # extract objects in :param classname: class_recs = {} diff --git a/evaluate/evaluate_net.py b/evaluate/evaluate_net.py index 1872883..8d86f8e 100644 --- a/evaluate/evaluate_net.py +++ b/evaluate/evaluate_net.py @@ -14,7 +14,42 @@ def evaluate_net(net, path_imgrec, num_classes, mean_pixels, data_shape, ovp_thresh=0.5, use_difficult=False, class_names=None, voc07_metric=False): """ + evalute network given validation record file + Parameters: + ---------- + net : str or None + Network name or use None to load from json without modifying + path_imgrec : str + path to the record validation file + path_imglist : str + path to the list file to replace labels in record file, optional + num_classes : int + number of classes, not including background + mean_pixels : tuple + (mean_r, mean_g, mean_b) + data_shape : tuple or int + (3, height, width) or height/width + model_prefix : str + model prefix of saved checkpoint + epoch : int + load model epoch + ctx : mx.ctx + mx.gpu() or mx.cpu() + batch_size : int + validation batch size + nms_thresh : float + non-maximum suppression threshold + force_nms : boolean + whether suppress different class objects + ovp_thresh : float + AP overlap threshold for true/false postives + use_difficult : boolean + whether to use difficult objects in evaluation if applicable + class_names : comma separated str + class names in string, must correspond to num_classes if set + voc07_metric : boolean + whether to use 11-point evluation as in VOC07 competition """ # set up logger logging.basicConfig() @@ -39,11 +74,13 @@ def evaluate_net(net, path_imgrec, num_classes, mean_pixels, data_shape, sys.path.append(os.path.join(cfg.ROOT_DIR, 'symbol')) net = importlib.import_module("symbol_" + net) \ .get_symbol(num_classes, nms_thresh, force_nms) - label = mx.sym.Variable(name='label') - net = mx.sym.Group([net, label]) + if not 'label' in net.list_arguments(): + label = mx.sym.Variable(name='label') + net = mx.sym.Group([net, label]) # init module - mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx) + mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx, + fixed_param_names=net.list_arguments()) mod.bind(data_shapes=eval_iter.provide_data, label_shapes=eval_iter.provide_label) mod.set_params(args, auxs, allow_missing=False, force_init=True) diff --git a/mxnet b/mxnet index 707b120..23d960d 160000 --- a/mxnet +++ b/mxnet @@ -1 +1 @@ -Subproject commit 707b1203f834469f32209ef9d2d18c92f6c3fa92 +Subproject commit 23d960d50621f7d3abcd04d9a3247822cbd8e9c4 diff --git a/operator/multibox_detection-inl.h b/operator/multibox_detection-inl.h deleted file mode 100644 index fe8d17a..0000000 --- a/operator/multibox_detection-inl.h +++ /dev/null @@ -1,186 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file multibox_detection-inl.h - * \brief post-process multibox detection predictions - * \author Joshua Zhang -*/ -#ifndef MXNET_OPERATOR_MULTIBOX_DETECTION_INL_H_ -#define MXNET_OPERATOR_MULTIBOX_DETECTION_INL_H_ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "./operator_common.h" - -namespace mxnet { -namespace op { -namespace mboxdet_enum { -enum MultiBoxDetectionOpInputs {kClsProb, kLocPred, kAnchor}; -enum MultiBoxDetectionOpOutputs {kOut}; -enum MultiBoxDetectionOpResource {kTempSpace}; -} // namespace mboxdet_enum - -struct MultiBoxDetectionParam : public dmlc::Parameter { - bool clip; - float threshold; - int background_id; - float nms_threshold; - bool force_suppress; - int keep_topk; - int nms_topk; - nnvm::Tuple variances; - DMLC_DECLARE_PARAMETER(MultiBoxDetectionParam) { - DMLC_DECLARE_FIELD(clip).set_default(true) - .describe("Clip out-of-boundary boxes."); - DMLC_DECLARE_FIELD(threshold).set_default(0.01f) - .describe("Threshold to be a positive prediction."); - DMLC_DECLARE_FIELD(background_id).set_default(0) - .describe("Background id."); - DMLC_DECLARE_FIELD(nms_threshold).set_default(0.5f) - .describe("Non-maximum suppression threshold."); - DMLC_DECLARE_FIELD(force_suppress).set_default(false) - .describe("Suppress all detections regardless of class_id."); - DMLC_DECLARE_FIELD(variances).set_default({0.1f, 0.1f, 0.2f, 0.2f}) - .describe("Variances to be decoded from box regression output."); - DMLC_DECLARE_FIELD(nms_topk).set_default(-1) - .describe("Keep maximum top k detections before nms, -1 for no limit."); - } -}; // struct MultiBoxDetectionParam - -template -class MultiBoxDetectionOp : public Operator { - public: - explicit MultiBoxDetectionOp(MultiBoxDetectionParam param) { - this->param_ = param; - } - - virtual void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { - using namespace mshadow; - using namespace mshadow::expr; - CHECK_EQ(in_data.size(), 3) << "Input: [cls_prob, loc_pred, anchor]"; - TShape ashape = in_data[mboxdet_enum::kAnchor].shape_; - CHECK_EQ(out_data.size(), 1); - - Stream *s = ctx.get_stream(); - Tensor cls_prob = in_data[mboxdet_enum::kClsProb] - .get(s); - Tensor loc_pred = in_data[mboxdet_enum::kLocPred] - .get(s); - Tensor anchors = in_data[mboxdet_enum::kAnchor] - .get_with_shape(Shape2(ashape[1], 4), s); - Tensor out = out_data[mboxdet_enum::kOut] - .get(s); - Tensor temp_space = ctx.requested[mboxdet_enum::kTempSpace] - .get_space_typed(out.shape_, s); - out = -1.f; - MultiBoxDetectionForward(out, cls_prob, loc_pred, anchors, temp_space, - param_.threshold, param_.clip, param_.variances, param_.nms_threshold, - param_.force_suppress, param_.nms_topk); - } - - virtual void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_states) { - using namespace mshadow; - using namespace mshadow::expr; - Stream *s = ctx.get_stream(); - Tensor gradc = in_grad[mboxdet_enum::kClsProb].FlatTo2D(s); - Tensor gradl = in_grad[mboxdet_enum::kLocPred].FlatTo2D(s); - Tensor grada = in_grad[mboxdet_enum::kAnchor].FlatTo2D(s); - gradc = 0.f; - gradl = 0.f; - grada = 0.f; -} - - private: - MultiBoxDetectionParam param_; -}; // class MultiBoxDetectionOp - -template -Operator *CreateOp(MultiBoxDetectionParam, int dtype); - -#if DMLC_USE_CXX11 -class MultiBoxDetectionProp : public OperatorProperty { - public: - void Init(const std::vector >& kwargs) override { - param_.Init(kwargs); - } - - std::map GetParams() const override { - return param_.__DICT__(); - } - - std::vector ListArguments() const override { - return {"cls_prob", "loc_pred", "anchor"}; - } - - bool InferShape(std::vector *in_shape, - std::vector *out_shape, - std::vector *aux_shape) const override { - using namespace mshadow; - CHECK_EQ(in_shape->size(), 3) << "Inputs: [cls_prob, loc_pred, anchor]"; - TShape cshape = in_shape->at(mboxdet_enum::kClsProb); - TShape lshape = in_shape->at(mboxdet_enum::kLocPred); - TShape ashape = in_shape->at(mboxdet_enum::kAnchor); - CHECK_EQ(cshape.ndim(), 3) << "Provided: " << cshape; - CHECK_EQ(lshape.ndim(), 2) << "Provided: " << lshape; - CHECK_EQ(ashape.ndim(), 3) << "Provided: " << ashape; - CHECK_EQ(cshape[2], ashape[1]) << "Number of anchors mismatch"; - CHECK_EQ(cshape[2] * 4, lshape[1]) << "# anchors mismatch with # loc"; - CHECK_GT(ashape[1], 0) << "Number of anchors must > 0"; - CHECK_EQ(ashape[2], 4); - TShape oshape = TShape(3); - oshape[0] = cshape[0]; - oshape[1] = ashape[1]; - oshape[2] = 6; // [id, prob, xmin, ymin, xmax, ymax] - out_shape->clear(); - out_shape->push_back(oshape); - return true; - } - - OperatorProperty* Copy() const override { - auto ptr = new MultiBoxDetectionProp(); - ptr->param_ = param_; - return ptr; - } - - std::string TypeString() const override { - return "MultiBoxDetection"; - } - - std::vector ForwardResource( - const std::vector &in_shape) const override { - return {ResourceRequest::kTempSpace}; - } - - Operator* CreateOperator(Context ctx) const override { - LOG(FATAL) << "Not implemented"; - return NULL; - } - - Operator* CreateOperatorEx(Context ctx, std::vector *in_shape, - std::vector *in_type) const override; - - private: - MultiBoxDetectionParam param_; -}; // class MultiBoxDetectionProp -#endif // DMLC_USE_CXX11 - -} // namespace op -} // namespace mxnet - -#endif // MXNET_OPERATOR_MULTIBOX_DETECTION_INL_H_ diff --git a/operator/multibox_detection.cc b/operator/multibox_detection.cc deleted file mode 100644 index 1cddfcf..0000000 --- a/operator/multibox_detection.cc +++ /dev/null @@ -1,182 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file multibox_detection.cc - * \brief MultiBoxDetection op - * \author Joshua Zhang -*/ -#include "./multibox_detection-inl.h" -#include - -namespace mshadow { -template -struct SortElemDescend { - DType value; - int index; - - SortElemDescend(DType v, int i) { - value = v; - index = i; - } - - bool operator<(const SortElemDescend &other) const { - return value > other.value; - } -}; - -template -inline void TransformLocations(DType *out, const DType *anchors, - const DType *loc_pred, const bool clip, - const float vx, const float vy, - const float vw, const float vh) { - // transform predictions to detection results - DType al = anchors[0]; - DType at = anchors[1]; - DType ar = anchors[2]; - DType ab = anchors[3]; - DType aw = ar - al; - DType ah = ab - at; - DType ax = (al + ar) / 2.f; - DType ay = (at + ab) / 2.f; - DType px = loc_pred[0]; - DType py = loc_pred[1]; - DType pw = loc_pred[2]; - DType ph = loc_pred[3]; - DType ox = px * vx * aw + ax; - DType oy = py * vy * ah + ay; - DType ow = exp(pw * vw) * aw / 2; - DType oh = exp(ph * vh) * ah / 2; - out[0] = clip ? std::max(DType(0), std::min(DType(1), ox - ow)) : (ox - ow); - out[1] = clip ? std::max(DType(0), std::min(DType(1), oy - oh)) : (oy - oh); - out[2] = clip ? std::max(DType(0), std::min(DType(1), ox + ow)) : (ox + ow); - out[3] = clip ? std::max(DType(0), std::min(DType(1), oy + oh)) : (oy + oh); -} - -template -inline DType CalculateOverlap(const DType *a, const DType *b) { - DType w = std::max(DType(0), std::min(a[2], b[2]) - std::max(a[0], b[0])); - DType h = std::max(DType(0), std::min(a[3], b[3]) - std::max(a[1], b[1])); - DType i = w * h; - DType u = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - i; - return u <= 0.f ? static_cast(0) : static_cast(i / u); -} - -template -inline void MultiBoxDetectionForward(const Tensor &out, - const Tensor &cls_prob, - const Tensor &loc_pred, - const Tensor &anchors, - const Tensor &temp_space, - const float threshold, - const bool clip, - const nnvm::Tuple &variances, - const float nms_threshold, - const bool force_suppress, - const int nms_topk) { - CHECK_EQ(variances.ndim(), 4) << "Variance size must be 4"; - const int num_classes = cls_prob.size(1); - const int num_anchors = cls_prob.size(2); - const int num_batches = cls_prob.size(0); - const DType *p_anchor = anchors.dptr_; - for (int nbatch = 0; nbatch < num_batches; ++nbatch) { - const DType *p_cls_prob = cls_prob.dptr_ + nbatch * num_classes * num_anchors; - const DType *p_loc_pred = loc_pred.dptr_ + nbatch * num_anchors * 4; - DType *p_out = out.dptr_ + nbatch * num_anchors * 6; - int valid_count = 0; - for (int i = 0; i < num_anchors; ++i) { - // find the predicted class id and probability - DType score = -1; - int id = 0; - for (int j = 1; j < num_classes; ++j) { - DType temp = p_cls_prob[j * num_anchors + i]; - if (temp > score) { - score = temp; - id = j; - } - } - if (id > 0 && score < threshold) { - id = 0; - } - if (id > 0) { - // [id, prob, xmin, ymin, xmax, ymax] - p_out[valid_count * 6] = id - 1; // remove background, restore original id - p_out[valid_count * 6 + 1] = (id == 0 ? DType(-1) : score); - int offset = i * 4; - TransformLocations(p_out + valid_count * 6 + 2, p_anchor + offset, - p_loc_pred + offset, clip, variances[0], variances[1], - variances[2], variances[3]); - ++valid_count; - } - } // end iter num_anchors - - if (valid_count < 1 || nms_threshold <= 0 || nms_threshold > 1) continue; - - // sort and apply NMS - Copy(temp_space[nbatch], out[nbatch], out.stream_); - // sort confidence in descend order - std::vector> sorter; - sorter.reserve(valid_count); - for (int i = 0; i < valid_count; ++i) { - sorter.push_back(SortElemDescend(p_out[i * 6 + 1], i)); - } - std::stable_sort(sorter.begin(), sorter.end()); - // re-order output - DType *ptemp = temp_space.dptr_ + nbatch * num_anchors * 6; - int nkeep = static_cast(sorter.size()); - if (nms_topk > 0 && nms_topk < nkeep) { - nkeep = nms_topk; - } - for (int i = 0; i < nkeep; ++i) { - for (int j = 0; j < 6; ++j) { - p_out[i * 6 + j] = ptemp[sorter[i].index * 6 + j]; - } - } - // apply nms - for (int i = 0; i < valid_count; ++i) { - int offset_i = i * 6; - if (p_out[offset_i] < 0) continue; // skip eliminated - for (int j = i + 1; j < valid_count; ++j) { - int offset_j = j * 6; - if (p_out[offset_j] < 0) continue; // skip eliminated - if (force_suppress || (p_out[offset_i] == p_out[offset_j])) { - // when foce_suppress == true or class_id equals - DType iou = CalculateOverlap(p_out + offset_i + 2, p_out + offset_j + 2); - if (iou >= nms_threshold) { - p_out[offset_j] = -1; - } - } - } - } - } // end iter batch -} -} // namespace mshadow - -namespace mxnet { -namespace op { -template<> -Operator *CreateOp(MultiBoxDetectionParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new MultiBoxDetectionOp(param); - }); - return op; -} - -Operator* MultiBoxDetectionProp::CreateOperatorEx(Context ctx, - std::vector *in_shape, - std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); - CHECK(InferType(in_type, &out_type, &aux_type)); - DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0)); -} - -DMLC_REGISTER_PARAMETER(MultiBoxDetectionParam); -MXNET_REGISTER_OP_PROPERTY(MultiBoxDetection, MultiBoxDetectionProp) -.describe("Convert multibox detection predictions.") -.add_argument("cls_prob", "Symbol", "Class probabilities.") -.add_argument("loc_pred", "Symbol", "Location regression predictions.") -.add_argument("anchors", "Symbol", "Multibox prior anchor boxes") -.add_arguments(MultiBoxDetectionParam::__FIELDS__()); -} // namespace op -} // namespace mxnet diff --git a/operator/multibox_detection.cu b/operator/multibox_detection.cu deleted file mode 100644 index dab11ff..0000000 --- a/operator/multibox_detection.cu +++ /dev/null @@ -1,227 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file multibox_detection.cu - * \brief MultiBoxDetection op - * \author Joshua Zhang -*/ -#include "./multibox_detection-inl.h" -#include - -#define MULTIBOX_DETECTION_CUDA_CHECK(condition) \ - /* Code block avoids redefinition of cudaError_t error */ \ - do { \ - cudaError_t error = condition; \ - CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \ - } while (0) - -namespace mshadow { -namespace cuda { -template -__device__ void Clip(DType *value, const DType lower, const DType upper) { - if ((*value) < lower) *value = lower; - if ((*value) > upper) *value = upper; -} - -template -__device__ void CalculateOverlap(const DType *a, const DType *b, DType *iou) { - DType w = max(DType(0), min(a[2], b[2]) - max(a[0], b[0])); - DType h = max(DType(0), min(a[3], b[3]) - max(a[1], b[1])); - DType i = w * h; - DType u = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - i; - (*iou) = u <= 0.f ? static_cast(0) : static_cast(i / u); -} - -template -__global__ void DetectionForwardKernel(DType *out, const DType *cls_prob, - const DType *loc_pred, const DType *anchors, - DType *temp_space, const int num_classes, - const int num_anchors, const float threshold, - const bool clip, const float vx, - const float vy, const float vw, - const float vh, const float nms_threshold, - const bool force_suppress, const int nms_topk) { - const int nbatch = blockIdx.x; // each block for each batch - int index = threadIdx.x; - __shared__ int valid_count; - out += nbatch * num_anchors * 6; - cls_prob += nbatch * num_anchors * num_classes; - loc_pred += nbatch * num_anchors * 4; - - if (index == 0) { - valid_count = 0; - } - __syncthreads(); - - // apply prediction to anchors - for (int i = index; i < num_anchors; i += blockDim.x) { - DType score = -1; - int id = 0; - for (int j = 1; j < num_classes; ++j) { - DType temp = cls_prob[j * num_anchors + i]; - if (temp > score) { - score = temp; - id = j; - } - } - if (id > 0 && score < threshold) { - id = 0; - } - - if (id > 0) { - // valid class - int pos = atomicAdd(&valid_count, 1); - out[pos * 6] = id - 1; // restore original class id - out[pos * 6 + 1] = (id == 0 ? DType(-1) : score); - int offset = i * 4; - DType al = anchors[offset]; - DType at = anchors[offset + 1]; - DType ar = anchors[offset + 2]; - DType ab = anchors[offset + 3]; - DType aw = ar - al; - DType ah = ab - at; - DType ax = (al + ar) / 2.f; - DType ay = (at + ab) / 2.f; - DType ox = loc_pred[offset] * vx * aw + ax; - DType oy = loc_pred[offset + 1] * vy * ah + ay; - DType ow = exp(loc_pred[offset + 2] * vw) * aw / 2; - DType oh = exp(loc_pred[offset + 3] * vh) * ah / 2; - DType xmin = ox - ow; - DType ymin = oy - oh; - DType xmax = ox + ow; - DType ymax = oy + oh; - if (clip) { - Clip(&xmin, DType(0), DType(1)); - Clip(&ymin, DType(0), DType(1)); - Clip(&xmax, DType(0), DType(1)); - Clip(&ymax, DType(0), DType(1)); - } - out[pos * 6 + 2] = xmin; - out[pos * 6 + 3] = ymin; - out[pos * 6 + 4] = xmax; - out[pos * 6 + 5] = ymax; - } - } - __syncthreads(); - - if (valid_count < 1 || nms_threshold <= 0 || nms_threshold > 1) return; - // if (index == 0) printf("%d\n", valid_count); - - // descent sort according to scores - const int size = valid_count; - temp_space += nbatch * num_anchors * 6; - DType *src = out; - DType *dst = temp_space; - for (int width = 2; width < (size << 1); width <<= 1) { - int slices = (size - 1) / (blockDim.x * width) + 1; - int start = width * index * slices; - for (int slice = 0; slice < slices; ++slice) { - if (start >= size) break; - int middle = start + (width >> 1); - if (middle > size) middle = size; - int end = start + width; - if (end > size) end = size; - int i = start; - int j = middle; - for (int k = start; k < end; ++k) { - DType score_i = i < size ? src[i * 6 + 1] : DType(-1); - DType score_j = j < size ? src[j * 6 + 1] : DType(-1); - if (i < middle && (j >= end || score_i > score_j)) { - for (int n = 0; n < 6; ++n) { - dst[k * 6 + n] = src[i * 6 + n]; - } - ++i; - } else { - for (int n = 0; n < 6; ++n) { - dst[k * 6 + n] = src[j * 6 + n]; - } - ++j; - } - } - start += width; - } - __syncthreads(); - src = src == out? temp_space : out; - dst = dst == out? temp_space : out; - } - __syncthreads(); - - if (src == temp_space) { - // copy from temp to out - for (int i = index; i < size * 6; i += blockDim.x) { - out[i] = temp_space[i]; - } - __syncthreads(); - } - - // keep top k detections - int ntop = size; - if (nms_topk > 0 && nms_topk < ntop) { - ntop = nms_topk; - for (int i = ntop + index; i < size; i += blockDim.x) { - out[i * 6] = -1; - } - __syncthreads(); - } - - // apply NMS - for (int compare_pos = 0; compare_pos < ntop; ++compare_pos) { - DType compare_id = out[compare_pos * 6]; - if (compare_id < 0) continue; // not a valid positive detection, skip - DType *compare_loc_ptr = out + compare_pos * 6 + 2; - for (int i = compare_pos + index + 1; i < ntop; i += blockDim.x) { - DType class_id = out[i * 6]; - if (class_id < 0) continue; - if (force_suppress || (class_id == compare_id)) { - DType iou; - CalculateOverlap(compare_loc_ptr, out + i * 6 + 2, &iou); - if (iou >= nms_threshold) { - out[i * 6] = -1; - } - } - } - __syncthreads(); - } -} -} // namespace cuda - -template -inline void MultiBoxDetectionForward(const Tensor &out, - const Tensor &cls_prob, - const Tensor &loc_pred, - const Tensor &anchors, - const Tensor &temp_space, - const float threshold, - const bool clip, - const nnvm::Tuple &variances, - const float nms_threshold, - const bool force_suppress, - const int nms_topk) { - CHECK_EQ(variances.ndim(), 4) << "Variance size must be 4"; - const int num_classes = cls_prob.size(1); - const int num_anchors = cls_prob.size(2); - const int num_batches = cls_prob.size(0); - const int num_threads = cuda::kMaxThreadsPerBlock; - int num_blocks = num_batches; - cuda::CheckLaunchParam(num_blocks, num_threads, "MultiBoxDetection Forward"); - cudaStream_t stream = Stream::GetStream(out.stream_); - cuda::DetectionForwardKernel<<>>(out.dptr_, - cls_prob.dptr_, loc_pred.dptr_, anchors.dptr_, temp_space.dptr_, - num_classes, num_anchors, threshold, clip, - variances[0], variances[1], variances[2], variances[3], - nms_threshold, force_suppress, nms_topk); - MULTIBOX_DETECTION_CUDA_CHECK(cudaPeekAtLastError()); -} -} // namespace mshadow - -namespace mxnet { -namespace op { -template<> -Operator *CreateOp(MultiBoxDetectionParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new MultiBoxDetectionOp(param); - }); - return op; -} -} // namespace op -} // namespace mxnet diff --git a/operator/multibox_prior-inl.h b/operator/multibox_prior-inl.h deleted file mode 100644 index 72a30b6..0000000 --- a/operator/multibox_prior-inl.h +++ /dev/null @@ -1,203 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file multibox_prior-inl.h - * \brief generate multibox prior boxes - * \author Joshua Zhang -*/ -#ifndef MXNET_OPERATOR_MULTIBOX_PRIOR_INL_H_ -#define MXNET_OPERATOR_MULTIBOX_PRIOR_INL_H_ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "./operator_common.h" - - -namespace mxnet { -namespace op { - -namespace mshadow_op { -struct clip_zero_one { - template - MSHADOW_XINLINE static DType Map(DType a) { - if (a < 0.f) return DType(0.f); - if (a > 1.f) return DType(1.f); - return DType(a); - } -}; // struct clip_zero_one -} // namespace mshadow_op - -namespace mboxprior_enum { -enum MultiBoxPriorOpInputs {kData}; -enum MultiBoxPriorOpOutputs {kOut}; -} // namespace mboxprior_enum - -struct MultiBoxPriorParam : public dmlc::Parameter { - nnvm::Tuple sizes; - nnvm::Tuple ratios; - bool clip; - nnvm::Tuple steps; - nnvm::Tuple offsets; - DMLC_DECLARE_PARAMETER(MultiBoxPriorParam) { - DMLC_DECLARE_FIELD(sizes).set_default({1.0f}) - .describe("List of sizes of generated MultiBoxPriores."); - DMLC_DECLARE_FIELD(ratios).set_default({1.0f}) - .describe("List of aspect ratios of generated MultiBoxPriores."); - DMLC_DECLARE_FIELD(clip).set_default(false) - .describe("Whether to clip out-of-boundary boxes."); - DMLC_DECLARE_FIELD(steps).set_default({-1.f, -1.f}) - .describe("Priorbox step across y and x, -1 for auto calculation."); - DMLC_DECLARE_FIELD(offsets).set_default({0.5f, 0.5f}) - .describe("Priorbox center offsets, y and x respectively"); - } -}; // struct MultiBoxPriorParam - -template -class MultiBoxPriorOp : public Operator { - public: - explicit MultiBoxPriorOp(MultiBoxPriorParam param) - : clip_(param.clip), sizes_(param.sizes.begin(), param.sizes.end()), - ratios_(param.ratios.begin(), param.ratios.end()), - steps_(param.steps.begin(), param.steps.end()), - offsets_(param.offsets.begin(), param.offsets.end()) { - CHECK_GT(sizes_.size(), 0); - CHECK_GT(ratios_.size(), 0); - CHECK_EQ(steps_.size(), 2); - CHECK_EQ(offsets_.size(), 2); - CHECK_GE(offsets_[0], 0.f); - CHECK_LE(offsets_[0], 1.f); - CHECK_GE(offsets_[1], 0.f); - CHECK_LE(offsets_[1], 1.f); - } - - virtual void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { - using namespace mshadow; - using namespace mshadow::expr; - CHECK_EQ(static_cast(in_data.size()), 1); - CHECK_EQ(out_data.size(), 1); - Stream *s = ctx.get_stream(); - Tensor out; - // TODO(zhreshold): this implementation is to be compliant to original ssd in caffe - // The prior boxes could be implemented in more versatile ways - // since input sizes are same in each batch, we could share MultiBoxPrior - const int num_sizes = static_cast(sizes_.size()); - const int num_ratios = static_cast(ratios_.size()); - const int num_anchors = num_sizes - 1 + num_ratios; // anchors per location - int in_height = in_data[mboxprior_enum::kData].size(2); - int in_width = in_data[mboxprior_enum::kData].size(3); - Shape<2> oshape = Shape2(num_anchors * in_width * in_height, 4); - out = out_data[mboxprior_enum::kOut].get_with_shape(oshape, s); - CHECK_GE(steps_[0] * steps_[1], 0) << "Must specify both step_y and step_x"; - if (steps_[0] <= 0 || steps_[1] <= 0) { - // estimate using layer shape - steps_[0] = 1.f / in_height; - steps_[1] = 1.f / in_width; - } - MultiBoxPriorForward(out, sizes_, ratios_, in_width, in_height, steps_, offsets_); - - if (clip_) { - Assign(out, req[mboxprior_enum::kOut], F(out)); - } - } - - virtual void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_states) { - using namespace mshadow; - using namespace mshadow::expr; - Stream *s = ctx.get_stream(); - Tensor grad = in_grad[mboxprior_enum::kData].FlatTo2D(s); - grad = 0.f; - } - - private: - bool clip_; - std::vector sizes_; - std::vector ratios_; - std::vector steps_; - std::vector offsets_; -}; // class MultiBoxPriorOp - -template -Operator *CreateOp(MultiBoxPriorParam, int dtype); - -#if DMLC_USE_CXX11 -class MultiBoxPriorProp: public OperatorProperty { - public: - void Init(const std::vector >& kwargs) override { - param_.Init(kwargs); - } - - std::map GetParams() const override { - return param_.__DICT__(); - } - - std::vector ListArguments() const override { - return {"data"}; - } - - bool InferShape(std::vector *in_shape, - std::vector *out_shape, - std::vector *aux_shape) const override { - using namespace mshadow; - CHECK_EQ(in_shape->size(), 1) << "Inputs: [data]" << in_shape->size(); - TShape dshape = in_shape->at(mboxprior_enum::kData); - CHECK_GE(dshape.ndim(), 4) << "Input data should be 4D: batch-channel-y-x"; - int in_height = dshape[2]; - CHECK_GT(in_height, 0) << "Input height should > 0"; - int in_width = dshape[3]; - CHECK_GT(in_width, 0) << "Input width should > 0"; - // since input sizes are same in each batch, we could share MultiBoxPrior - TShape oshape = TShape(3); - int num_sizes = param_.sizes.ndim(); - int num_ratios = param_.ratios.ndim(); - oshape[0] = 1; - oshape[1] = in_height * in_width * (num_sizes + num_ratios - 1); - oshape[2] = 4; - out_shape->clear(); - out_shape->push_back(oshape); - CHECK_EQ(param_.steps.ndim(), 2) << "Step ndim must be 2: (step_y, step_x)"; - return true; - } - - OperatorProperty* Copy() const override { - auto ptr = new MultiBoxPriorProp(); - ptr->param_ = param_; - return ptr; - } - - std::string TypeString() const override { - return "MultiBoxPrior"; - } - - Operator* CreateOperator(Context ctx) const override { - LOG(FATAL) << "Not implemented"; - return NULL; - } - - Operator* CreateOperatorEx(Context ctx, std::vector *in_shape, - std::vector *in_type) const override; - - private: - MultiBoxPriorParam param_; -}; // class MultiBoxPriorProp -#endif // DMLC_USE_CXX11 - -} // namespace op -} // namespace mxnet - -#endif // MXNET_OPERATOR_MULTIBOX_PRIOR_INL_H_ diff --git a/operator/multibox_prior.cc b/operator/multibox_prior.cc deleted file mode 100644 index 13986f8..0000000 --- a/operator/multibox_prior.cc +++ /dev/null @@ -1,84 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file multibox_prior.cc - * \brief generate multibox prior boxes cpu implementation - * \author Joshua Zhang -*/ - -#include "./multibox_prior-inl.h" - -namespace mshadow { -template -inline void MultiBoxPriorForward(const Tensor &out, - const std::vector &sizes, - const std::vector &ratios, - const int in_width, const int in_height, - const std::vector &steps, - const std::vector &offsets) { - const float step_x = steps[1]; - const float step_y = steps[0]; - const int num_sizes = static_cast(sizes.size()); - const int num_ratios = static_cast(ratios.size()); - int count = 0; - - for (int r = 0; r < in_height; ++r) { - float center_y = (r + offsets[0]) * step_y; - for (int c = 0; c < in_width; ++c) { - float center_x = (c + offsets[1]) * step_x; - // ratio = 1, various sizes - for (int i = 0; i < num_sizes; ++i) { - float size = sizes[i]; - float w = size / 2; - float h = size / 2; - out[count][0] = center_x - w; // xmin - out[count][1] = center_y - h; // ymin - out[count][2] = center_x + w; // xmax - out[count][3] = center_y + h; // ymax - ++count; - } - // various ratios, size = min_size = size[0] - float size = sizes[0]; - for (int j = 1; j < num_ratios; ++j) { - float ratio = sqrtf(ratios[j]); - float w = size * ratio / 2; - float h = size / ratio / 2; - out[count][0] = center_x - w; // xmin - out[count][1] = center_y - h; // ymin - out[count][2] = center_x + w; // xmax - out[count][3] = center_y + h; // ymax - ++count; - } - } - } -} -} // namespace mshadow - -namespace mxnet { -namespace op { -template<> -Operator* CreateOp(MultiBoxPriorParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new MultiBoxPriorOp(param); - }); - return op; -} - -Operator* MultiBoxPriorProp::CreateOperatorEx(Context ctx, std::vector *in_shape, - std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); - CHECK(InferType(in_type, &out_type, &aux_type)); - DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0)); -} - -DMLC_REGISTER_PARAMETER(MultiBoxPriorParam); - -MXNET_REGISTER_OP_PROPERTY(MultiBoxPrior, MultiBoxPriorProp) -.add_argument("data", "Symbol", "Input data.") -.add_arguments(MultiBoxPriorParam::__FIELDS__()) -.describe("Generate prior(anchor) boxes from data, sizes and ratios."); - -} // namespace op -} // namespace mxnet diff --git a/operator/multibox_prior.cu b/operator/multibox_prior.cu deleted file mode 100644 index a3f2cc2..0000000 --- a/operator/multibox_prior.cu +++ /dev/null @@ -1,98 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file multibox_prior.cu - * \brief generate multibox prior boxes cuda kernels - * \author Joshua Zhang -*/ - -#include "./multibox_prior-inl.h" -#include - -#define MULTIBOXPRIOR_CUDA_CHECK(condition) \ - /* Code block avoids redefinition of cudaError_t error */ \ - do { \ - cudaError_t error = condition; \ - CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \ - } while (0) - -namespace mshadow { -namespace cuda { -template -__global__ void AssignPriors(DType *out, const float size, - const float sqrt_ratio, const int in_width, - const int in_height, const float step_x, - const float step_y, const float center_offy, - const float center_offx, const int stride, - const int offset) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index >= in_width * in_height) return; - int r = index / in_width; - int c = index % in_width; - float center_x = (c + center_offx) * step_x; - float center_y = (r + center_offy) * step_y; - float w = size * sqrt_ratio / 2; // half width - float h = size / sqrt_ratio / 2; // half height - DType *ptr = out + index * stride + 4 * offset; - *(ptr++) = center_x - w; // xmin - *(ptr++) = center_y - h; // ymin - *(ptr++) = center_x + w; // xmax - *(ptr++) = center_y + h; // ymax -} -} // namespace cuda - -template -inline void MultiBoxPriorForward(const Tensor &out, - const std::vector &sizes, - const std::vector &ratios, - const int in_width, const int in_height, - const std::vector &steps, - const std::vector &offsets) { - CHECK_EQ(out.CheckContiguous(), true); - cudaStream_t stream = Stream::GetStream(out.stream_); - DType *out_ptr = out.dptr_; - const float step_x = steps[1]; - const float step_y = steps[0]; - const float offset_x = offsets[1]; - const float offset_y = offsets[0]; - const int num_sizes = static_cast(sizes.size()); - const int num_ratios = static_cast(ratios.size()); - - const int num_thread = cuda::kMaxThreadsPerBlock; - dim3 dimBlock(num_thread); - dim3 dimGrid((in_width * in_height - 1) / num_thread + 1); - cuda::CheckLaunchParam(dimGrid, dimBlock, "MultiBoxPrior Forward"); - - const int stride = 4 * (num_sizes + num_ratios - 1); - int offset = 0; - // ratio = 1, various sizes - for (int i = 0; i < num_sizes; ++i) { - cuda::AssignPriors<<>>(out_ptr, - sizes[i], 1.f, in_width, in_height, step_x, step_y, offset_y, offset_x, stride, offset); - ++offset; - } - MULTIBOXPRIOR_CUDA_CHECK(cudaPeekAtLastError()); - - // size = sizes[0], various ratios - for (int j = 1; j < num_ratios; ++j) { - cuda::AssignPriors<<>>(out_ptr, - sizes[0], sqrtf(ratios[j]), in_width, in_height, step_x, step_y, - offset_y, offset_x, stride, offset); - ++offset; - } - MULTIBOXPRIOR_CUDA_CHECK(cudaPeekAtLastError()); -} -} // namespace mshadow - -namespace mxnet { -namespace op { -template<> -Operator* CreateOp(MultiBoxPriorParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new MultiBoxPriorOp(param); - }); - return op; -} - -} // namespace op -} // namespace mxnet diff --git a/operator/multibox_target-inl.h b/operator/multibox_target-inl.h deleted file mode 100644 index 134dc68..0000000 --- a/operator/multibox_target-inl.h +++ /dev/null @@ -1,261 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file multibox_target-inl.h - * \brief - * \author Joshua Zhang -*/ -#ifndef MXNET_OPERATOR_MULTIBOX_TARGET_INL_H_ -#define MXNET_OPERATOR_MULTIBOX_TARGET_INL_H_ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "./operator_common.h" -#include "./mshadow_op.h" - -namespace mxnet { -namespace op { - -namespace mshadow_op { -struct safe_divide { - template - MSHADOW_XINLINE static DType Map(DType a, DType b) { - if (b == DType(0.0f)) return DType(0.0f); - return DType(a / b); - } -}; // struct safe_divide -} // namespace mshadow_op - -namespace mboxtarget_enum { -enum MultiBoxTargetOpInputs {kAnchor, kLabel, kClsPred}; -enum MultiBoxTargetOpOutputs {kLoc, kLocMask, kCls}; -enum MultiBoxTargetOpResource {kTempSpace}; -} // namespace mboxtarget_enum - -struct MultiBoxTargetParam : public dmlc::Parameter { - float overlap_threshold; - float ignore_label; - float negative_mining_ratio; - float negative_mining_thresh; - int minimum_negative_samples; - nnvm::Tuple variances; - DMLC_DECLARE_PARAMETER(MultiBoxTargetParam) { - DMLC_DECLARE_FIELD(overlap_threshold).set_default(0.5f) - .describe("Anchor-GT overlap threshold to be regarded as a possitive match."); - DMLC_DECLARE_FIELD(ignore_label).set_default(-1.0f) - .describe("Label for ignored anchors."); - DMLC_DECLARE_FIELD(negative_mining_ratio).set_default(-1.0f) - .describe("Max negative to positive samples ratio, use -1 to disable mining"); - DMLC_DECLARE_FIELD(negative_mining_thresh).set_default(0.5f) - .describe("Threshold used for negative mining."); - DMLC_DECLARE_FIELD(minimum_negative_samples).set_default(0) - .describe("Minimum number of negative samples."); - DMLC_DECLARE_FIELD(variances).set_default({0.1f, 0.1f, 0.2f, 0.2f}) - .describe("Variances to be encoded in box regression target."); - } -}; // struct MultiBoxTargetParam - -template -class MultiBoxTargetOp : public Operator { - public: - explicit MultiBoxTargetOp(MultiBoxTargetParam param) { - this->param_ = param; - } - - virtual void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { - using namespace mshadow; - using namespace mshadow_op; - using namespace mshadow::expr; - CHECK_EQ(in_data.size(), 3); - CHECK_EQ(out_data.size(), 3); - Stream *s = ctx.get_stream(); - Tensor anchors = in_data[mboxtarget_enum::kAnchor] - .get_with_shape( - Shape2(in_data[mboxtarget_enum::kAnchor].size(1), 4), s); - Tensor labels = in_data[mboxtarget_enum::kLabel] - .get(s); - Tensor cls_preds = in_data[mboxtarget_enum::kClsPred] - .get(s); - Tensor loc_target = out_data[mboxtarget_enum::kLoc] - .get(s); - Tensor loc_mask = out_data[mboxtarget_enum::kLocMask] - .get(s); - Tensor cls_target = out_data[mboxtarget_enum::kCls] - .get(s); - - index_t num_batches = labels.size(0); - index_t num_anchors = anchors.size(0); - index_t num_labels = labels.size(1); - // TODO(zhreshold): use maximum valid ground-truth in batch rather than # in dataset - Shape<4> temp_shape = Shape4(11, num_batches, num_anchors, num_labels); - Tensor temp_space = ctx.requested[mboxtarget_enum::kTempSpace] - .get_space_typed(temp_shape, s); - loc_target = 0.f; - loc_mask = 0.0f; - cls_target = param_.ignore_label; - temp_space = -1.0f; - CHECK_EQ(anchors.CheckContiguous(), true); - CHECK_EQ(labels.CheckContiguous(), true); - CHECK_EQ(cls_preds.CheckContiguous(), true); - CHECK_EQ(loc_target.CheckContiguous(), true); - CHECK_EQ(loc_mask.CheckContiguous(), true); - CHECK_EQ(cls_target.CheckContiguous(), true); - CHECK_EQ(temp_space.CheckContiguous(), true); - - // compute overlaps - // TODO(zhreshold): squeeze temporary memory space - // temp_space, 0:out, 1:l1, 2:t1, 3:r1, 4:b1, 5:l2, 6:t2, 7:r2, 8:b2 - // 9: intersection, 10:union - temp_space[1] = broadcast_keepdim(broadcast_with_axis(slice<1>(anchors, 0, 1), -1, - num_batches), 2, num_labels); - temp_space[2] = broadcast_keepdim(broadcast_with_axis(slice<1>(anchors, 1, 2), -1, - num_batches), 2, num_labels); - temp_space[3] = broadcast_keepdim(broadcast_with_axis(slice<1>(anchors, 2, 3), -1, - num_batches), 2, num_labels); - temp_space[4] = broadcast_keepdim(broadcast_with_axis(slice<1>(anchors, 3, 4), -1, - num_batches), 2, num_labels); - Shape<3> temp_reshape = Shape3(num_batches, 1, num_labels); - temp_space[5] = broadcast_keepdim(reshape(slice<2>(labels, 1, 2), temp_reshape), 1, - num_anchors); - temp_space[6] = broadcast_keepdim(reshape(slice<2>(labels, 2, 3), temp_reshape), 1, - num_anchors); - temp_space[7] = broadcast_keepdim(reshape(slice<2>(labels, 3, 4), temp_reshape), 1, - num_anchors); - temp_space[8] = broadcast_keepdim(reshape(slice<2>(labels, 4, 5), temp_reshape), 1, - num_anchors); - temp_space[9] = F(ScalarExp(0.0f), - F(temp_space[3], temp_space[7]) - F(temp_space[1], temp_space[5])) - * F(ScalarExp(0.0f), - F(temp_space[4], temp_space[8]) - F(temp_space[2], temp_space[6])); - temp_space[10] = (temp_space[3] - temp_space[1]) * (temp_space[4] - temp_space[2]) - + (temp_space[7] - temp_space[5]) * (temp_space[8] - temp_space[6]) - - temp_space[9]; - temp_space[0] = F(temp_space[9], temp_space[10]); - - MultiBoxTargetForward(loc_target, loc_mask, cls_target, - anchors, labels, cls_preds, temp_space, - param_.overlap_threshold, - param_.ignore_label, - param_.negative_mining_ratio, - param_.negative_mining_thresh, - param_.minimum_negative_samples, - param_.variances); - } - - virtual void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_args) { - using namespace mshadow; - using namespace mshadow::expr; - Stream *s = ctx.get_stream(); - Tensor grad = in_grad[mboxtarget_enum::kClsPred].FlatTo2D(s); - grad = 0.f; -} - - private: - MultiBoxTargetParam param_; -}; // class MultiBoxTargetOp - -template -Operator* CreateOp(MultiBoxTargetParam param, int dtype); - -#if DMLC_USE_CXX11 -class MultiBoxTargetProp : public OperatorProperty { - public: - std::vector ListArguments() const override { - return {"anchor", "label", "cls_pred"}; - } - - std::vector ListOutputs() const override { - return {"loc_target", "loc_mask", "cls_target"}; - } - - void Init(const std::vector >& kwargs) override { - param_.Init(kwargs); - } - - std::map GetParams() const override { - return param_.__DICT__(); - } - - bool InferShape(std::vector *in_shape, - std::vector *out_shape, - std::vector *aux_shape) const override { - using namespace mshadow; - CHECK_EQ(in_shape->size(), 3) << "Input: [anchor, label, clsPred]"; - TShape ashape = in_shape->at(mboxtarget_enum::kAnchor); - CHECK_EQ(ashape.ndim(), 3) << "Anchor should be batch shared N*4 tensor"; - CHECK_EQ(ashape[0], 1) << "Anchors are shared across batches, first dim=1"; - CHECK_GT(ashape[1], 0) << "Number boxes should > 0"; - CHECK_EQ(ashape[2], 4) << "Box dimension should be 4: [xmin-ymin-xmax-ymax]"; - TShape lshape = in_shape->at(mboxtarget_enum::kLabel); - CHECK_EQ(lshape.ndim(), 3) << "Label should be [batch-num_labels-(>=5)] tensor"; - CHECK_GT(lshape[1], 0) << "Padded label should > 0"; - CHECK_GE(lshape[2], 5) << "Label width must >=5"; - TShape pshape = in_shape->at(mboxtarget_enum::kClsPred); - CHECK_EQ(pshape.ndim(), 3) << "Prediction: [nbatch-num_classes-num_anchors]"; - CHECK_EQ(pshape[2], ashape[1]) << "Number of anchors mismatch"; - TShape loc_shape = Shape2(lshape[0], ashape.Size()); // batch - (num_box * 4) - TShape lm_shape = loc_shape; - TShape label_shape = Shape2(lshape[0], ashape[1]); // batch - num_box - out_shape->clear(); - out_shape->push_back(loc_shape); - out_shape->push_back(lm_shape); - out_shape->push_back(label_shape); - return true; - } - - OperatorProperty* Copy() const override { - MultiBoxTargetProp* MultiBoxTarget_sym = new MultiBoxTargetProp(); - MultiBoxTarget_sym->param_ = this->param_; - return MultiBoxTarget_sym; - } - - std::string TypeString() const override { - return "MultiBoxTarget"; - } - - // decalre dependency and inplace optimization options - std::vector DeclareBackwardDependency( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data) const override { - return {}; - } - - std::vector ForwardResource( - const std::vector &in_shape) const override { - return {ResourceRequest::kTempSpace}; - } - - Operator* CreateOperator(Context ctx) const override { - LOG(FATAL) << "Not implemented"; - return NULL; - } - - Operator* CreateOperatorEx(Context ctx, std::vector *in_shape, - std::vector *in_type) const override; - - private: - MultiBoxTargetParam param_; -}; // class MultiBoxTargetProp -#endif // DMLC_USE_CXX11 - -} // namespace op -} // namespace mxnet - -#endif // MXNET_OPERATOR_MULTIBOX_TARGET_INL_H_ diff --git a/operator/multibox_target.cc b/operator/multibox_target.cc deleted file mode 100644 index ab717b9..0000000 --- a/operator/multibox_target.cc +++ /dev/null @@ -1,292 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file multibox_target.cc - * \brief MultiBoxTarget op - * \author Joshua Zhang -*/ -#include "./multibox_target-inl.h" -#include "./mshadow_op.h" -#include - -namespace mshadow { -template -inline void AssignLocTargets(const DType *anchor, const DType *l, DType *dst, - const float vx, const float vy, - const float vw, const float vh) { - float al = *(anchor); - float at = *(anchor+1); - float ar = *(anchor+2); - float ab = *(anchor+3); - float aw = ar - al; - float ah = ab - at; - float ax = (al + ar) * 0.5; - float ay = (at + ab) * 0.5; - float gl = *(l); - float gt = *(l+1); - float gr = *(l+2); - float gb = *(l+3); - float gw = gr - gl; - float gh = gb - gt; - float gx = (gl + gr) * 0.5; - float gy = (gt + gb) * 0.5; - *(dst) = DType((gx - ax) / aw / vx); - *(dst+1) = DType((gy - ay) / ah / vy); - *(dst+2) = DType(std::log(gw / aw) / vw); - *(dst+3) = DType(std::log(gh / ah) / vh); -} - -struct SortElemDescend { - float value; - int index; - - SortElemDescend(float v, int i) { - value = v; - index = i; - } - - bool operator<(const SortElemDescend &other) const { - return value > other.value; - } -}; - -template -inline void MultiBoxTargetForward(const Tensor &loc_target, - const Tensor &loc_mask, - const Tensor &cls_target, - const Tensor &anchors, - const Tensor &labels, - const Tensor &cls_preds, - const Tensor &temp_space, - const float overlap_threshold, - const float background_label, - const float negative_mining_ratio, - const float negative_mining_thresh, - const int minimum_negative_samples, - const nnvm::Tuple &variances) { - const DType *p_anchor = anchors.dptr_; - const int num_batches = labels.size(0); - const int num_labels = labels.size(1); - const int label_width = labels.size(2); - const int num_anchors = anchors.size(0); - CHECK_EQ(variances.ndim(), 4); - for (int nbatch = 0; nbatch < num_batches; ++nbatch) { - const DType *p_label = labels.dptr_ + nbatch * num_labels * label_width; - const DType *p_overlaps = temp_space.dptr_ + nbatch * num_anchors * num_labels; - int num_valid_gt = 0; - for (int i = 0; i < num_labels; ++i) { - if (static_cast(*(p_label + i * label_width)) == -1.0f) { - CHECK_EQ(static_cast(*(p_label + i * label_width + 1)), -1.0f); - CHECK_EQ(static_cast(*(p_label + i * label_width + 2)), -1.0f); - CHECK_EQ(static_cast(*(p_label + i * label_width + 3)), -1.0f); - CHECK_EQ(static_cast(*(p_label + i * label_width + 4)), -1.0f); - break; - } - ++num_valid_gt; - } // end iterate labels - - if (num_valid_gt > 0) { - std::vector gt_flags(num_valid_gt, false); - std::vector> max_matches(num_anchors, - std::pair(-1.0f, -1)); - std::vector anchor_flags(num_anchors, -1); // -1 means don't care - int num_positive = 0; - while (std::find(gt_flags.begin(), gt_flags.end(), false) != gt_flags.end()) { - // ground-truths not fully matched - int best_anchor = -1; - int best_gt = -1; - float max_overlap = 1e-6; // start with a very small positive overlap - for (int j = 0; j < num_anchors; ++j) { - if (anchor_flags[j] == 1) { - continue; // already matched this anchor - } - const DType *pp_overlaps = p_overlaps + j * num_labels; - for (int k = 0; k < num_valid_gt; ++k) { - if (gt_flags[k]) { - continue; // already matched this gt - } - float iou = static_cast(*(pp_overlaps + k)); - if (iou > max_overlap) { - best_anchor = j; - best_gt = k; - max_overlap = iou; - } - } - } - - if (best_anchor == -1) { - CHECK_EQ(best_gt, -1); - break; // no more good match - } else { - CHECK_EQ(max_matches[best_anchor].first, -1.0f); - CHECK_EQ(max_matches[best_anchor].second, -1); - max_matches[best_anchor].first = max_overlap; - max_matches[best_anchor].second = best_gt; - num_positive += 1; - // mark as visited - gt_flags[best_gt] = true; - anchor_flags[best_anchor] = 1; - } - } // end while - - if (overlap_threshold > 0) { - // find positive matches based on overlaps - for (int j = 0; j < num_anchors; ++j) { - if (anchor_flags[j] == 1) { - continue; // already matched this anchor - } - const DType *pp_overlaps = p_overlaps + j * num_labels; - int best_gt = -1; - float max_iou = -1.0f; - for (int k = 0; k < num_valid_gt; ++k) { - float iou = static_cast(*(pp_overlaps + k)); - if (iou > max_iou) { - best_gt = k; - max_iou = iou; - } - } - if (best_gt != -1) { - CHECK_EQ(max_matches[j].first, -1.0f); - CHECK_EQ(max_matches[j].second, -1); - max_matches[j].first = max_iou; - max_matches[j].second = best_gt; - if (max_iou > overlap_threshold) { - num_positive += 1; - // mark as visited - gt_flags[best_gt] = true; - anchor_flags[j] = 1; - } - } - } // end iterate anchors - } - - if (negative_mining_ratio > 0) { - const int num_classes = cls_preds.size(1); - DType *p_cls_preds = cls_preds.dptr_ + nbatch * num_classes * num_anchors; - CHECK_GT(negative_mining_thresh, 0); - int num_negative = num_positive * negative_mining_ratio; - if (num_negative > (num_anchors - num_positive)) { - num_negative = num_anchors - num_positive; - } - if (num_negative > 0) { - // use negative mining, pick up "best" negative samples - std::vector temp; - temp.reserve(num_anchors - num_positive); - for (int j = 0; j < num_anchors; ++j) { - if (anchor_flags[j] == 1) { - continue; // already matched this anchor - } - if (max_matches[j].first < 0) { - // not yet calculated - const DType *pp_overlaps = p_overlaps + j * num_labels; - int best_gt = -1; - float max_iou = -1.0f; - for (int k = 0; k < num_valid_gt; ++k) { - float iou = static_cast(*(pp_overlaps + k)); - if (iou > max_iou) { - best_gt = k; - max_iou = iou; - } - } - if (best_gt != -1) { - CHECK_EQ(max_matches[j].first, -1.0f); - CHECK_EQ(max_matches[j].second, -1); - max_matches[j].first = max_iou; - max_matches[j].second = best_gt; - } - } - if (max_matches[j].first < negative_mining_thresh && - anchor_flags[j] == -1) { - // calcuate class predictions - DType max_val = p_cls_preds[j]; - for (int k = 1; k < num_classes; ++k) { - DType tmp = p_cls_preds[j + num_anchors * k]; - if (tmp > max_val) max_val = tmp; - } - DType sum = 0.f; - for (int k = 0; k < num_classes; ++k) { - DType tmp = p_cls_preds[j + num_anchors * k]; - sum += std::exp(tmp - max_val); - } - DType prob = std::exp(p_cls_preds[j] - max_val) / sum; - // loss should be -log(x), but value does not matter, skip log - temp.push_back(SortElemDescend(-prob, j)); - } - } // end iterate anchors - - CHECK_GE(temp.size(), num_negative); - std::stable_sort(temp.begin(), temp.end()); - for (int i = 0; i < num_negative; ++i) { - anchor_flags[temp[i].index] = 0; // mark as negative sample - } - } - } else { - // use all negative samples - for (int i = 0; i < num_anchors; ++i) { - if (anchor_flags[i] != 1) { - anchor_flags[i] = 0; - } - } - } - - // assign training targets - DType *p_loc_target = loc_target.dptr_ + nbatch * num_anchors * 4; - DType *p_loc_mask = loc_mask.dptr_ + nbatch * num_anchors * 4; - DType *p_cls_target = cls_target.dptr_ + nbatch * num_anchors; - for (int i = 0; i < num_anchors; ++i) { - if (anchor_flags[i] == 1) { - // positive sample - CHECK_GE(max_matches[i].second, 0); - // 0 reserved for background - *(p_cls_target + i) = *(p_label + label_width * max_matches[i].second) + 1; - int offset = i * 4; - *(p_loc_mask + offset) = 1; - *(p_loc_mask + offset + 1) = 1; - *(p_loc_mask + offset + 2) = 1; - *(p_loc_mask + offset + 3) = 1; - AssignLocTargets(p_anchor + i * 4, - p_label + label_width * max_matches[i].second + 1, p_loc_target + offset, - variances[0], variances[1], variances[2], variances[3]); - } else if (anchor_flags[i] == 0) { - // negative sample - *(p_cls_target + i) = 0; - int offset = i * 4; - *(p_loc_mask + offset) = 0; - *(p_loc_mask + offset + 1) = 0; - *(p_loc_mask + offset + 2) = 0; - *(p_loc_mask + offset + 3) = 0; - } - } // end iterate anchors - } - } // end iterate batches -} -} // namespace mshadow - -namespace mxnet { -namespace op { -template<> -Operator *CreateOp(MultiBoxTargetParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new MultiBoxTargetOp(param); - }); - return op; -} - -Operator* MultiBoxTargetProp::CreateOperatorEx(Context ctx, std::vector *in_shape, - std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); - CHECK(InferType(in_type, &out_type, &aux_type)); - DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0)); -} - -DMLC_REGISTER_PARAMETER(MultiBoxTargetParam); -MXNET_REGISTER_OP_PROPERTY(MultiBoxTarget, MultiBoxTargetProp) -.describe("Compute Multibox training targets") -.add_argument("anchor", "Symbol", "Generated anchor boxes.") -.add_argument("label", "Symbol", "Object detection labels.") -.add_argument("cls_pred", "Symbol", "Class predictions.") -.add_arguments(MultiBoxTargetParam::__FIELDS__()); -} // namespace op -} // namespace mxnet diff --git a/operator/multibox_target.cu b/operator/multibox_target.cu deleted file mode 100644 index adcfcf2..0000000 --- a/operator/multibox_target.cu +++ /dev/null @@ -1,411 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file multibox_target.cu - * \brief MultiBoxTarget op - * \author Joshua Zhang -*/ -#include "./multibox_target-inl.h" -#include - -#define MULTIBOX_TARGET_CUDA_CHECK(condition) \ - /* Code block avoids redefinition of cudaError_t error */ \ - do { \ - cudaError_t error = condition; \ - CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \ - } while (0) - -namespace mshadow { -namespace cuda { -template -__global__ void InitGroundTruthFlags(DType *gt_flags, const DType *labels, - const int num_batches, - const int num_labels, - const int label_width) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index >= num_batches * num_labels) return; - int b = index / num_labels; - int l = index % num_labels; - if (*(labels + b * num_labels * label_width + l * label_width) == -1.f) { - *(gt_flags + b * num_labels + l) = 0; - } else { - *(gt_flags + b * num_labels + l) = 1; - } -} - -template -__global__ void FindBestMatches(DType *best_matches, DType *gt_flags, - DType *anchor_flags, const DType *overlaps, - const int num_anchors, const int num_labels) { - int nbatch = blockIdx.x; - gt_flags += nbatch * num_labels; - overlaps += nbatch * num_anchors * num_labels; - best_matches += nbatch * num_anchors; - anchor_flags += nbatch * num_anchors; - const int num_threads = kMaxThreadsPerBlock; - __shared__ int max_indices_y[kMaxThreadsPerBlock]; - __shared__ int max_indices_x[kMaxThreadsPerBlock]; - __shared__ float max_values[kMaxThreadsPerBlock]; - - while (1) { - // check if all done. - bool finished = true; - for (int i = 0; i < num_labels; ++i) { - if (gt_flags[i] > .5) { - finished = false; - break; - } - } - if (finished) break; // all done. - - // finding max indices in different threads - int max_x = -1; - int max_y = -1; - DType max_value = 1e-6; // start with very small overlap - for (int i = threadIdx.x; i < num_anchors; i += num_threads) { - if (anchor_flags[i] > .5) continue; - for (int j = 0; j < num_labels; ++j) { - if (gt_flags[j] > .5) { - DType temp = overlaps[i * num_labels + j]; - if (temp > max_value) { - max_x = j; - max_y = i; - max_value = temp; - } - } - } - } - max_indices_x[threadIdx.x] = max_x; - max_indices_y[threadIdx.x] = max_y; - max_values[threadIdx.x] = max_value; - __syncthreads(); - - if (threadIdx.x == 0) { - // merge results and assign best match - int max_x = -1; - int max_y = -1; - DType max_value = -1; - for (int k = 0; k < num_threads; ++k) { - if (max_indices_y[k] < 0 || max_indices_x[k] < 0) continue; - float temp = max_values[k]; - if (temp > max_value) { - max_x = max_indices_x[k]; - max_y = max_indices_y[k]; - max_value = temp; - } - } - if (max_x >= 0 && max_y >= 0) { - best_matches[max_y] = max_x; - // mark flags as visited - gt_flags[max_x] = 0.f; - anchor_flags[max_y] = 1.f; - } else { - // no more good matches - for (int i = 0; i < num_labels; ++i) { - gt_flags[i] = 0.f; - } - } - } - __syncthreads(); - } -} - -template -__global__ void FindGoodMatches(DType *best_matches, DType *anchor_flags, - const DType *overlaps, const int num_anchors, - const int num_labels, - const float overlap_threshold) { - int nbatch = blockIdx.x; - overlaps += nbatch * num_anchors * num_labels; - best_matches += nbatch * num_anchors; - anchor_flags += nbatch * num_anchors; - const int num_threads = kMaxThreadsPerBlock; - - for (int i = threadIdx.x; i < num_anchors; i += num_threads) { - if (anchor_flags[i] < 0) { - int idx = -1; - float max_value = -1.f; - for (int j = 0; j < num_labels; ++j) { - DType temp = overlaps[i * num_labels + j]; - if (temp > max_value) { - max_value = temp; - idx = j; - } - } - if (max_value > overlap_threshold && (idx >= 0)) { - best_matches[i] = idx; - anchor_flags[i] = 0.9f; - } - } - } -} - -template -__global__ void UseAllNegatives(DType *anchor_flags, const int num) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= num) return; - if (anchor_flags[idx] < 0.5) { - anchor_flags[idx] = 0; // regard all non-positive as negatives - } -} - -template -__global__ void NegativeMining(const DType *overlaps, const DType *cls_preds, - DType *anchor_flags, DType *buffer, - const float negative_mining_ratio, - const float negative_mining_thresh, - const int minimum_negative_samples, - const int num_anchors, - const int num_labels, const int num_classes) { - int nbatch = blockIdx.x; - overlaps += nbatch * num_anchors * num_labels; - cls_preds += nbatch * num_classes * num_anchors; - anchor_flags += nbatch * num_anchors; - buffer += nbatch * num_anchors * 3; - const int num_threads = kMaxThreadsPerBlock; - int num_positive; - __shared__ int num_negative; - - if (threadIdx.x == 0) { - num_positive = 0; - for (int i = 0; i < num_anchors; ++i) { - if (anchor_flags[i] > .5) { - ++num_positive; - } - } - num_negative = num_positive * negative_mining_ratio; - if (num_negative < minimum_negative_samples) { - num_negative = minimum_negative_samples; - } - if (num_negative > (num_anchors - num_positive)) { - num_negative = num_anchors - num_positive; - } - } - __syncthreads(); - - if (num_negative < 1) return; - - for (int i = threadIdx.x; i < num_anchors; i += num_threads) { - buffer[i] = -1.f; - if (anchor_flags[i] < 0) { - // compute max class prediction score - DType max_val = cls_preds[i]; - for (int j = 1; j < num_classes; ++j) { - DType temp = cls_preds[i + num_anchors * j]; - if (temp > max_val) max_val = temp; - } - DType sum = 0.f; - for (int j = 0; j < num_classes; ++j) { - DType temp = cls_preds[i + num_anchors * j]; - sum += exp(temp - max_val); - } - DType prob = exp(cls_preds[i] - max_val) / sum; - DType max_iou = -1.f; - for (int j = 0; j < num_labels; ++j) { - DType temp = overlaps[i * num_labels + j]; - if (temp > max_iou) max_iou = temp; - } - if (max_iou < negative_mining_thresh) { - // only do it for anchors with iou < thresh - buffer[i] = -prob; // -log(x) actually, but value does not matter - } - } - } - __syncthreads(); - - // descend merge sorting for negative mining - DType *index_src = buffer + num_anchors; - DType *index_dst = buffer + num_anchors * 2; - DType *src = index_src; - DType *dst = index_dst; - for (int i = threadIdx.x; i < num_anchors; i += num_threads) { - index_src[i] = i; - } - __syncthreads(); - - for (int width = 2; width < (num_anchors << 1); width <<= 1) { - int slices = (num_anchors - 1) / (num_threads * width) + 1; - int start = width * threadIdx.x * slices; - for (int slice = 0; slice < slices; ++slice) { - if (start >= num_anchors) break; - int middle = start + (width >> 1); - if (num_anchors < middle) middle = num_anchors; - int end = start + width; - if (num_anchors < end) end = num_anchors; - int i = start; - int j = middle; - for (int k = start; k < end; ++k) { - int idx_i = static_cast(src[i]); - int idx_j = static_cast(src[j]); - if (i < middle && (j >= end || buffer[idx_i] > buffer[idx_j])) { - dst[k] = src[i]; - ++i; - } else { - dst[k] = src[j]; - ++j; - } - } - start += width; - } - __syncthreads(); - // swap src/dst - src = src == index_src? index_dst : index_src; - dst = dst == index_src? index_dst : index_src; - } - __syncthreads(); - - for (int i = threadIdx.x; i < num_negative; i += num_threads) { - int idx = static_cast(src[i]); - if (anchor_flags[idx] < 0) { - anchor_flags[idx] = 0; - } - } -} - -template -__global__ void AssignTrainigTargets(DType *loc_target, DType *loc_mask, - DType *cls_target, DType *anchor_flags, - DType *best_matches, DType *labels, - DType *anchors, const int num_anchors, - const int num_labels, const int label_width, - const float vx, const float vy, - const float vw, const float vh) { - const int nbatch = blockIdx.x; - loc_target += nbatch * num_anchors * 4; - loc_mask += nbatch * num_anchors * 4; - cls_target += nbatch * num_anchors; - anchor_flags += nbatch * num_anchors; - best_matches += nbatch * num_anchors; - labels += nbatch * num_labels * label_width; - const int num_threads = kMaxThreadsPerBlock; - - for (int i = threadIdx.x; i < num_anchors; i += num_threads) { - if (anchor_flags[i] > 0.5) { - // positive sample - int offset_l = static_cast(best_matches[i]) * label_width; - cls_target[i] = labels[offset_l] + 1; // 0 reserved for background - int offset = i * 4; - loc_mask[offset] = 1; - loc_mask[offset + 1] = 1; - loc_mask[offset + 2] = 1; - loc_mask[offset + 3] = 1; - // regression targets - float al = anchors[offset]; - float at = anchors[offset + 1]; - float ar = anchors[offset + 2]; - float ab = anchors[offset + 3]; - float aw = ar - al; - float ah = ab - at; - float ax = (al + ar) * 0.5; - float ay = (at + ab) * 0.5; - float gl = labels[offset_l + 1]; - float gt = labels[offset_l + 2]; - float gr = labels[offset_l + 3]; - float gb = labels[offset_l + 4]; - float gw = gr - gl; - float gh = gb - gt; - float gx = (gl + gr) * 0.5; - float gy = (gt + gb) * 0.5; - loc_target[offset] = DType((gx - ax) / aw / vx); // xmin - loc_target[offset + 1] = DType((gy - ay) / ah / vy); // ymin - loc_target[offset + 2] = DType(log(gw / aw) / vw); // xmax - loc_target[offset + 3] = DType(log(gh / ah) / vh); // ymax - } else if (anchor_flags[i] < 0.5 && anchor_flags[i] > -0.5) { - // background - cls_target[i] = 0; - } - } -} -} // namespace cuda - -template -inline void MultiBoxTargetForward(const Tensor &loc_target, - const Tensor &loc_mask, - const Tensor &cls_target, - const Tensor &anchors, - const Tensor &labels, - const Tensor &cls_preds, - const Tensor &temp_space, - const float overlap_threshold, - const float background_label, - const float negative_mining_ratio, - const float negative_mining_thresh, - const int minimum_negative_samples, - const nnvm::Tuple &variances) { - const int num_batches = labels.size(0); - const int num_labels = labels.size(1); - const int label_width = labels.size(2); - const int num_anchors = anchors.size(0); - const int num_classes = cls_preds.size(1); - CHECK_GE(num_batches, 1); - CHECK_GT(num_labels, 2); - CHECK_GE(num_anchors, 1); - CHECK_EQ(variances.ndim(), 4); - - // init ground-truth flags, by checking valid labels - temp_space[1] = 0.f; - DType *gt_flags = temp_space[1].dptr_; - const int num_threads = cuda::kMaxThreadsPerBlock; - dim3 init_thread_dim(num_threads); - dim3 init_block_dim((num_batches * num_labels - 1) / num_threads + 1); - cuda::CheckLaunchParam(init_block_dim, init_thread_dim, "MultiBoxTarget Init"); - cuda::InitGroundTruthFlags<<>>( - gt_flags, labels.dptr_, num_batches, num_labels, label_width); - MULTIBOX_TARGET_CUDA_CHECK(cudaPeekAtLastError()); - - // compute best matches - temp_space[2] = -1.f; - temp_space[3] = -1.f; - DType *anchor_flags = temp_space[2].dptr_; - DType *best_matches = temp_space[3].dptr_; - const DType *overlaps = temp_space[0].dptr_; - cuda::CheckLaunchParam(num_batches, num_threads, "MultiBoxTarget Matching"); - cuda::FindBestMatches<<>>(best_matches, - gt_flags, anchor_flags, overlaps, num_anchors, num_labels); - MULTIBOX_TARGET_CUDA_CHECK(cudaPeekAtLastError()); - - // find good matches with overlap > threshold - if (overlap_threshold > 0) { - cuda::FindGoodMatches<<>>(best_matches, - anchor_flags, overlaps, num_anchors, num_labels, - overlap_threshold); - MULTIBOX_TARGET_CUDA_CHECK(cudaPeekAtLastError()); - } - - // do negative mining or not - if (negative_mining_ratio > 0) { - CHECK_GT(negative_mining_thresh, 0); - temp_space[4] = 0; - DType *buffer = temp_space[4].dptr_; - cuda::NegativeMining<<>>(overlaps, - cls_preds.dptr_, anchor_flags, buffer, negative_mining_ratio, - negative_mining_thresh, minimum_negative_samples, - num_anchors, num_labels, num_classes); - MULTIBOX_TARGET_CUDA_CHECK(cudaPeekAtLastError()); - } else { - int num_blocks = (num_batches * num_anchors - 1) / num_threads + 1; - cuda::CheckLaunchParam(num_blocks, num_threads, "MultiBoxTarget Negative"); - cuda::UseAllNegatives<<>>(anchor_flags, - num_batches * num_anchors); - MULTIBOX_TARGET_CUDA_CHECK(cudaPeekAtLastError()); - } - - cuda::AssignTrainigTargets<<>>( - loc_target.dptr_, loc_mask.dptr_, cls_target.dptr_, anchor_flags, - best_matches, labels.dptr_, anchors.dptr_, num_anchors, num_labels, - label_width, variances[0], variances[1], variances[2], variances[3]); - MULTIBOX_TARGET_CUDA_CHECK(cudaPeekAtLastError()); -} -} // namespace mshadow - -namespace mxnet { -namespace op { -template<> -Operator *CreateOp(MultiBoxTargetParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new MultiBoxTargetOp(param); - }); - return op; -} -} // namespace op -} // namespace mxnet diff --git a/symbol/common.py b/symbol/common.py index dab2b68..12ea718 100644 --- a/symbol/common.py +++ b/symbol/common.py @@ -172,7 +172,7 @@ def multibox_layer(from_layers, num_classes, sizes=[.2, .95], step = (steps[k], steps[k]) else: step = '(-1.0, -1.0)' - anchors = mx.symbol.MultiBoxPrior(from_layer, sizes=size_str, ratios=ratio_str, \ + anchors = mx.contrib.symbol.MultiBoxPrior(from_layer, sizes=size_str, ratios=ratio_str, \ clip=clip, name="{}_anchors".format(from_name), steps=step) anchors = mx.symbol.Flatten(data=anchors) anchor_layers.append(anchors) diff --git a/symbol/symbol_vgg16_ssd_300.py b/symbol/symbol_vgg16_ssd_300.py index 7c3baaa..e406746 100644 --- a/symbol/symbol_vgg16_ssd_300.py +++ b/symbol/symbol_vgg16_ssd_300.py @@ -126,7 +126,7 @@ def get_symbol_train(num_classes=20, nms_thresh=0.5, force_suppress=False, nms_t num_classes, sizes=sizes, ratios=ratios, normalization=normalizations, \ num_channels=num_channels, clip=False, interm_layer=0, steps=steps) - tmp = mx.symbol.MultiBoxTarget( + tmp = mx.contrib.symbol.MultiBoxTarget( *[anchor_boxes, label, cls_preds], overlap_threshold=.5, \ ignore_label=-1, negative_mining_ratio=3, minimum_negative_samples=0, \ negative_mining_thresh=.5, variances=(0.1, 0.1, 0.2, 0.2), @@ -145,7 +145,7 @@ def get_symbol_train(num_classes=20, nms_thresh=0.5, force_suppress=False, nms_t # monitoring training status cls_label = mx.symbol.MakeLoss(data=cls_target, grad_scale=0, name="cls_label") - det = mx.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ + det = mx.contrib.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ name="detection", nms_threshold=nms_thresh, force_suppress=force_suppress, variances=(0.1, 0.1, 0.2, 0.2), nms_topk=nms_topk) det = mx.symbol.MakeLoss(data=det, grad_scale=0, name="det_out") @@ -183,7 +183,7 @@ def get_symbol(num_classes=20, nms_thresh=0.5, force_suppress=False, nms_topk=40 cls_prob = mx.symbol.SoftmaxActivation(data=cls_preds, mode='channel', \ name='cls_prob') - out = mx.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ + out = mx.contrib.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ name="detection", nms_threshold=nms_thresh, force_suppress=force_suppress, variances=(0.1, 0.1, 0.2, 0.2), nms_topk=nms_topk) return out diff --git a/symbol/symbol_vgg16_ssd_512.py b/symbol/symbol_vgg16_ssd_512.py index e712017..e223d42 100644 --- a/symbol/symbol_vgg16_ssd_512.py +++ b/symbol/symbol_vgg16_ssd_512.py @@ -114,7 +114,7 @@ def get_symbol_train(num_classes=20, nms_thresh=0.5, force_suppress=False, nms_t stride=(1,1), act_type="relu", use_batchnorm=False) conv12_1, relu12_1 = conv_act_layer(relu11_2, "12_1", 128, kernel=(1,1), pad=(0,0), \ stride=(1,1), act_type="relu", use_batchnorm=False) - conv12_2, relu12_2 = conv_act_layer(relu12_1, "12_2", 256, kernel=(3,3), pad=(0,0), \ + conv12_2, relu12_2 = conv_act_layer(relu12_1, "12_2", 256, kernel=(4,4), pad=(1,1), \ stride=(1,1), act_type="relu", use_batchnorm=False) # specific parameters for VGG16 network @@ -131,7 +131,7 @@ def get_symbol_train(num_classes=20, nms_thresh=0.5, force_suppress=False, nms_t num_classes, sizes=sizes, ratios=ratios, normalization=normalizations, \ num_channels=num_channels, clip=False, interm_layer=0, steps=steps) - tmp = mx.symbol.MultiBoxTarget( + tmp = mx.contrib.symbol.MultiBoxTarget( *[anchor_boxes, label, cls_preds], overlap_threshold=.5, \ ignore_label=-1, negative_mining_ratio=3, minimum_negative_samples=0, \ negative_mining_thresh=.5, variances=(0.1, 0.1, 0.2, 0.2), @@ -150,7 +150,7 @@ def get_symbol_train(num_classes=20, nms_thresh=0.5, force_suppress=False, nms_t # monitoring training status cls_label = mx.symbol.MakeLoss(data=cls_target, grad_scale=0, name="cls_label") - det = mx.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ + det = mx.contrib.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ name="detection", nms_threshold=nms_thresh, force_suppress=force_suppress, variances=(0.1, 0.1, 0.2, 0.2), nms_topk=nms_topk) det = mx.symbol.MakeLoss(data=det, grad_scale=0, name="det_out") @@ -188,7 +188,7 @@ def get_symbol(num_classes=20, nms_thresh=0.5, force_suppress=False, nms_topk=40 cls_prob = mx.symbol.SoftmaxActivation(data=cls_preds, mode='channel', \ name='cls_prob') - out = mx.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ + out = mx.contrib.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ name="detection", nms_threshold=nms_thresh, force_suppress=force_suppress, variances=(0.1, 0.1, 0.2, 0.2), nms_topk=nms_topk) return out diff --git a/tools/caffe_converter/README.md b/tools/caffe_converter/README.md index 92cd2aa..17b2b23 100644 --- a/tools/caffe_converter/README.md +++ b/tools/caffe_converter/README.md @@ -4,7 +4,7 @@ Either [Caffe's python package](http://caffe.berkeleyvision.org/installation.html) or [Google protobuf](https://developers.google.com/protocol-buffers/?hl=en) is required. The latter is often much easier to install: -1. We first install the protobuf compiler. If you compiled mxnet with `USE_DIST_KVSTORE = 1` then it is already built. Otherwise, install `protobuf-compiler` by your favor package manager, e.g. `sudo apt-get install protobuf-compiler` for ubuntu and `sudo yum install protobuf-compiler` for redhat/fedora. +1. We first install the protobuf compiler. If you compiled mxnet with `USE_DIST_KVSTORE = 1` then it is already built. Otherwise, install `protobuf-compiler` by your favor package manager, e.g. `sudo apt-get install protobuf-compiler` for ubuntu and `sudo yum install protobuf-compiler` for redhat/fedora. 2. Then install the protobuf's python binding. For example `sudo pip install protobuf` @@ -23,6 +23,7 @@ so we install the bindings first, and then install the corresponding compiler. ### How to use +To convert ssd caffemodels, Use: `python convert_model.py prototxt caffemodel outputprefix` Linux: Use `./run.sh model_name` to download and convert a model. E.g. `./run.sh vgg19` @@ -37,4 +38,4 @@ For example: `python convert_model.py VGG_ILSVRC_16_layers_deploy.prototxt VGG_I * The tool can only work with the L2LayerParameter in Caffe. * Caffe uses a convention for multi-strided pooling output shape inconsistent with MXNet * This importer doesn't handle this problem properly yet - * And example of this failure is importing bvlc_Googlenet. The user needs to add padding to stride-2 pooling to make this work right now. \ No newline at end of file + * And example of this failure is importing bvlc_Googlenet. The user needs to add padding to stride-2 pooling to make this work right now. diff --git a/tools/caffe_converter/convert_model.py b/tools/caffe_converter/convert_model.py index 57e2284..a06b655 100644 --- a/tools/caffe_converter/convert_model.py +++ b/tools/caffe_converter/convert_model.py @@ -124,7 +124,7 @@ def main(): if first_conv and (layer_type == 'Convolution' or layer_type == 4): first_conv = False - model = mx.mod.Module(symbol=prob, label_names=['prob_label', ]) + model = mx.mod.Module(symbol=prob, label_names=None) model.bind(data_shapes=[('data', tuple(input_dim))]) model.init_params(arg_params=arg_params, aux_params={}) diff --git a/tools/caffe_converter/convert_symbol.py b/tools/caffe_converter/convert_symbol.py index eefd72e..63b044a 100644 --- a/tools/caffe_converter/convert_symbol.py +++ b/tools/caffe_converter/convert_symbol.py @@ -248,7 +248,7 @@ def proto2script(proto_file): finput_dim = float(input_dim[2]) step = '(%f, %f)' % (step_h / finput_dim, step_w / finput_dim) assert param.offset == 0.5, "currently only support offset = 0.5" - symbol_string += '%s = mx.symbol.MultiBoxPrior(%s, sizes=%s, ratios=%s, clip=%s, steps=%s, name="%s")\n' % \ + symbol_string += '%s = mx.contrib.symbol.MultiBoxPrior(%s, sizes=%s, ratios=%s, clip=%s, steps=%s, name="%s")\n' % \ (name, mapping[layer[i].bottom[0]], sizes, ratios_string, clip, step, name) symbol_string += '%s = mx.symbol.Flatten(data=%s)\n' % (name, name) type_string = 'split' @@ -264,7 +264,7 @@ def proto2script(proto_file): assert param.share_location == True assert param.background_label_id == 0 nms_param = param.nms_param - type_string = 'mx.symbol.MultiBoxDetection' + type_string = 'mx.contrib.symbol.MultiBoxDetection' param_string = "nms_threshold=%f, nms_topk=%d" % \ (nms_param.nms_threshold, nms_param.top_k) if type_string == '': diff --git a/tools/visualize_net.py b/tools/visualize_net.py index 8c71c0a..714806f 100644 --- a/tools/visualize_net.py +++ b/tools/visualize_net.py @@ -1,3 +1,4 @@ +from __future__ import print_function import find_mxnet import mxnet as mx import importlib @@ -24,4 +25,4 @@ a.render("ssd_" + args.network) else: net = importlib.import_module("symbol_" + args.network).get_symbol_train(args.num_classes) - print net.tojson() + print(net.tojson()) diff --git a/train.py b/train.py index 4927dea..fcd5fb9 100644 --- a/train.py +++ b/train.py @@ -114,7 +114,7 @@ def parse_class_names(args): # start training train_net(args.network, args.train_path, args.num_class, args.batch_size, - args.data_shape, (args.mean_r, args.mean_g, args.mean_b), + args.data_shape, [args.mean_r, args.mean_g, args.mean_b], args.resume, args.finetune, args.pretrained, args.epoch, args.prefix, ctx, args.begin_epoch, args.end_epoch, args.frequent, args.learning_rate, args.momentum, args.weight_decay, diff --git a/train/train_net.py b/train/train_net.py index 6be080b57a3bf870ec8080baf8f5bd3c81e7b452..2e389d314804ea5186c96c8c6c5af337e8c3d93f 100644 GIT binary patch delta 12 TcmbQ|JKuMMHTz~O_7r&lAXo%J delta 12 TcmbR5JI{B6HTz~8_GEbgAYcSS