Skip to content

Commit

Permalink
fix loss in multibox_target
Browse files Browse the repository at this point in the history
  • Loading branch information
zhreshold committed Mar 26, 2017
1 parent 02bc851 commit 27cf0b5
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 13 deletions.
10 changes: 4 additions & 6 deletions operator/multibox_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,20 +198,18 @@ inline void MultiBoxTargetForward(const Tensor<cpu, 2, DType> &loc_target,
anchor_flags[j] == -1) {
// calculate class predictions
DType max_val = p_cls_preds[j];
DType max_val_pos = max_val;
for (int k = 1; k < num_classes; ++k) {
DType tmp = p_cls_preds[j + num_anchors * k];
if (tmp > max_val_pos) max_val_pos = tmp;
if (tmp > max_val) max_val = tmp;
}
DType sum = 0.f;
for (int k = 0; k < num_classes; ++k) {
DType tmp = p_cls_preds[j + num_anchors * k];
sum += std::exp(tmp - max_val);
}
max_val_pos = std::exp(max_val_pos - max_val) / sum;
// loss should be -log(x), but value does not matter, so skip log
DType loss = -max_val_pos;
temp.push_back(SortElemDescend(loss, j));
DType prob = std::exp(p_cls_preds[j] - max_val) / sum;
// loss should be -log(prob), but the exact value does not matter, so skip the log
temp.push_back(SortElemDescend(-prob, j));
}
} // end iterate anchors

Expand Down
7 changes: 3 additions & 4 deletions operator/multibox_target.cu
Original file line number Diff line number Diff line change
Expand Up @@ -189,25 +189,24 @@ __global__ void NegativeMining(const DType *overlaps, const DType *cls_preds,
if (anchor_flags[i] < 0) {
// compute max class prediction score
DType max_val = cls_preds[i];
DType max_val_pos = max_val; // regarding background
for (int j = 1; j < num_classes; ++j) {
DType temp = cls_preds[i + num_anchors * j];
if (temp > max_val_pos) max_val_pos = temp;
if (temp > max_val) max_val = temp;
}
DType sum = 0.f;
for (int j = 0; j < num_classes; ++j) {
DType temp = cls_preds[i + num_anchors * j];
sum += exp(temp - max_val);
}
max_val_pos = exp(max_val_pos - max_val) / sum;
DType prob = exp(cls_preds[i] - max_val) / sum;
DType max_iou = -1.f;
for (int j = 0; j < num_labels; ++j) {
DType temp = overlaps[i * num_labels + j];
if (temp > max_iou) max_iou = temp;
}
if (max_iou < negative_mining_thresh) {
// only do it for anchors with iou < thresh
buffer[i] = -max_val_pos; // -log(x) actually, but value does not matter
buffer[i] = -prob; // -log(x) actually, but value does not matter
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion symbol/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def multibox_layer(from_layers, num_classes, sizes=[.2, .95],
else:
step = '(-1.0, -1.0)'
anchors = mx.symbol.MultiBoxPrior(from_layer, sizes=size_str, ratios=ratio_str, \
clip=clip, name="{}_anchors".format(from_name), steps=steps)
clip=clip, name="{}_anchors".format(from_name), steps=step)
anchors = mx.symbol.Flatten(data=anchors)
anchor_layers.append(anchors)

Expand Down
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def parse_args():
default=os.path.join(os.getcwd(), 'data', 'val.rec'), type=str)
parser.add_argument('--val-list', dest='val_list', help='validation list to use',
default="", type=str)
parser.add_argument('--network', dest='network', type=str, default='ssd_300',
choices=['vgg16_reduced', 'ssd_300'], help='which network to use')
parser.add_argument('--network', dest='network', type=str, default='vgg16_ssd_300',
choices=['vgg16_ssd_300', 'vgg16_ssd_512'], help='which network to use')
parser.add_argument('--batch-size', dest='batch_size', type=int, default=32,
help='training batch size')
parser.add_argument('--resume', dest='resume', type=int, default=-1,
Expand Down
Binary file modified train/train_net.py
Binary file not shown.

0 comments on commit 27cf0b5

Please sign in to comment.