Skip to content

Commit

Permalink
optimize code and reformat code.
Browse files Browse the repository at this point in the history
  • Loading branch information
yangze0930 committed Aug 18, 2019
1 parent dc5c41a commit cac17da
Show file tree
Hide file tree
Showing 8 changed files with 298 additions and 189 deletions.
3 changes: 2 additions & 1 deletion src/reppoints_assigner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@
from .point_assigner import PointAssigner

__all__ = [
'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult', 'PointAssigner'
'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
'PointAssigner'
]
66 changes: 37 additions & 29 deletions src/reppoints_assigner/point_assigner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch

from .base_assigner import BaseAssigner
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


class PointAssigner(BaseAssigner):
Expand All @@ -20,20 +20,19 @@ def __init__(self, scale=4, pos_num=3):
self.pos_num = pos_num

def assign(self, points, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
"""Assign gt to bboxes.
"""Assign gt to points.
This method assign a gt bbox to every point, each bbox
This method assign a gt bbox to every points set, each points set
will be assigned with 0, or a positive number.
0 means negative sample, positive number is the index (1-based) of
assigned gt.
The assignment is done in following steps, the order matters.
1. assign every points to 0
2. for each gt box, we find the k most closest points to the
box center and assign the gt bbox to those points, we also record
the minimum distance from each point to the closest gt box. When we
assign the bbox to the points, we check whether its distance to the
points is closest.
2. A point is assigned to some gt bbox if
(i) the point is within the k closest points to the gt bbox
(ii) the distance between this point and the gt is smaller than
other gt bboxes
Args:
points (Tensor): points to be assigned, shape(n, 3) while last
Expand All @@ -48,28 +47,26 @@ def assign(self, points, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
"""
if points.shape[0] == 0 or gt_bboxes.shape[0] == 0:
raise ValueError('No gt or bboxes')
points_range = torch.arange(points.shape[0])
points_xy = points[:, :2]
points_stride = points[:, 2]
points_lvl = torch.log2(points_stride).int() # [3...,4...,5...,6...,7...]
points_lvl = torch.log2(
points_stride).int() # [3...,4...,5...,6...,7...]
lvl_min, lvl_max = points_lvl.min(), points_lvl.max()
num_gts, num_points = gt_bboxes.shape[0], points.shape[0]

# assign gt box
gt_bboxes_x = 0.5 * (gt_bboxes[:, 0] + gt_bboxes[:, 2])
gt_bboxes_y = 0.5 * (gt_bboxes[:, 1] + gt_bboxes[:, 3])
gt_bboxes_xy = torch.stack([gt_bboxes_x, gt_bboxes_y], -1)
gt_bboxes_w = gt_bboxes[:, 2] - gt_bboxes[:, 0]
gt_bboxes_h = gt_bboxes[:, 3] - gt_bboxes[:, 1]
gt_bboxes_wh = torch.stack([gt_bboxes_w, gt_bboxes_h], -1)
gt_bboxes_wh = torch.clamp(gt_bboxes_wh, min=1e-6)
gt_bboxes_lvl = (0.5 * (torch.log2(gt_bboxes_w / self.scale) + torch.log2(gt_bboxes_h / self.scale))).int()
gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2
gt_bboxes_wh = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).clamp(min=1e-6)
scale = self.scale
gt_bboxes_lvl = ((torch.log2(gt_bboxes_wh[:, 0] / scale) +
torch.log2(gt_bboxes_wh[:, 1] / scale)) / 2).int()
gt_bboxes_lvl = torch.clamp(gt_bboxes_lvl, min=lvl_min, max=lvl_max)

# stores the assigned gt index of each point
assigned_gt_inds = points.new_zeros((num_points,), dtype=torch.long)
assigned_gt_inds = points.new_zeros((num_points, ), dtype=torch.long)
# stores the assigned gt dist (to this point) of each point
assigned_gt_dist = points.new_full((num_points,), float('inf'))
assigned_gt_dist = points.new_full((num_points, ), float('inf'))
points_range = torch.arange(points.shape[0])

for idx in range(num_gts):
gt_lvl = gt_bboxes_lvl[idx]
Expand All @@ -82,20 +79,32 @@ def assign(self, points, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
gt_point = gt_bboxes_xy[[idx], :]
# get width and height of gt
gt_wh = gt_bboxes_wh[[idx], :]
# compute the distance between gt center and all points in this level
points_gt_dist = ((lvl_points-gt_point)/gt_wh).norm(dim=1)
# compute the distance between gt center and
# all points in this level
points_gt_dist = ((lvl_points - gt_point) / gt_wh).norm(dim=1)
# find the nearest k points to gt center in this level
min_dist, min_dist_index = torch.topk(-points_gt_dist, self.pos_num)
min_dist, min_dist_index = torch.topk(
points_gt_dist, self.pos_num, largest=False)
# the index of nearest k points to gt center in this level
min_dist_points_index = points_index[min_dist_index]
less_than_recorded_index = min_dist < assigned_gt_dist[min_dist_points_index]
min_dist_points_index = min_dist_points_index[less_than_recorded_index]
# The less_than_recorded_index stores the index
# of min_dist that is less then the assigned_gt_dist. Where
# assigned_gt_dist stores the dist from previous assigned gt
# (if exist) to each point.
less_than_recorded_index = min_dist < assigned_gt_dist[
min_dist_points_index]
# The min_dist_points_index stores the index of points satisfy:
# (1) it is k nearest to current gt center in this level.
# (2) it is closer to current gt center than other gt center.
min_dist_points_index = min_dist_points_index[
less_than_recorded_index]
# assign the result
assigned_gt_inds[min_dist_points_index] = idx+1
assigned_gt_dist[min_dist_points_index] = min_dist[less_than_recorded_index]
assigned_gt_inds[min_dist_points_index] = idx + 1
assigned_gt_dist[min_dist_points_index] = min_dist[
less_than_recorded_index]

if gt_labels is not None:
assigned_labels = assigned_gt_inds.new_zeros((num_points,))
assigned_labels = assigned_gt_inds.new_zeros((num_points, ))
pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze()
if pos_inds.numel() > 0:
assigned_labels[pos_inds] = gt_labels[
Expand All @@ -105,4 +114,3 @@ def assign(self, points, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):

return AssignResult(
num_gts, assigned_gt_inds, None, labels=assigned_labels)

2 changes: 1 addition & 1 deletion src/reppoints_detector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from .htc import HybridTaskCascade
from .mask_rcnn import MaskRCNN
from .mask_scoring_rcnn import MaskScoringRCNN
from .reppoints_detector import RepPointsDetector
from .retinanet import RetinaNet
from .rpn import RPN
from .single_stage import SingleStageDetector
from .two_stage import TwoStageDetector
from .reppoints_detector import RepPointsDetector

__all__ = [
'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
Expand Down
22 changes: 15 additions & 7 deletions src/reppoints_detector/reppoints_detector.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import torch
from .single_stage import SingleStageDetector

from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms
from ..registry import DETECTORS
from mmdet.core import bbox2result, multiclass_nms, bbox_mapping_back
from .single_stage import SingleStageDetector


@DETECTORS.register_module
class RepPointsDetector(SingleStageDetector):
"""RepPoints: Point Set Representation for Object Detection.
This detector is the implementation of:
- RepPoints detector (https://arxiv.org/pdf/1904.11490)
"""

def __init__(self,
backbone,
Expand All @@ -14,8 +20,9 @@ def __init__(self,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(RepPointsDetector, self).__init__(backbone, neck, bbox_head, train_cfg,
test_cfg, pretrained)
super(RepPointsDetector,
self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
pretrained)

def merge_aug_results(self, aug_bboxes, aug_scores, img_metas):
"""Merge augmented detection bboxes and scores.
Expand Down Expand Up @@ -60,9 +67,10 @@ def aug_test(self, imgs, img_metas, rescale=False):
# after merging, bboxes will be rescaled to the original image size
merged_bboxes, merged_scores = self.merge_aug_results(
aug_bboxes, aug_scores, img_metas)
det_bboxes, det_labels = multiclass_nms(
merged_bboxes, merged_scores, self.test_cfg.score_thr,
self.test_cfg.nms, self.test_cfg.max_per_img)
det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
self.test_cfg.score_thr,
self.test_cfg.nms,
self.test_cfg.max_per_img)

if rescale:
_det_bboxes = det_bboxes
Expand Down
5 changes: 3 additions & 2 deletions src/reppoints_generator/point_generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch


class PointGenerator(object):

def _meshgrid(self, x, y, row_major=True):
Expand All @@ -15,7 +16,7 @@ def grid_points(self, featmap_size, stride=16, device='cuda'):
shift_x = torch.arange(0., feat_w, device=device) * stride
shift_y = torch.arange(0., feat_h, device=device) * stride
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
stride = shift_x.new_full((shift_xx.shape[0],), stride)
stride = shift_x.new_full((shift_xx.shape[0], ), stride)
shifts = torch.stack([shift_xx, shift_yy, stride], dim=-1)
all_points = shifts.to(device)
return all_points
Expand All @@ -30,4 +31,4 @@ def valid_flags(self, featmap_size, valid_size, device='cuda'):
valid_y[:valid_h] = 1
valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
valid = valid_xx & valid_yy
return valid
return valid
51 changes: 27 additions & 24 deletions src/reppoints_generator/point_target.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from ..bbox import assign_and_sample, build_assigner, PseudoSampler
from ..bbox import PseudoSampler, assign_and_sample, build_assigner
from ..utils import multi_apply


Expand All @@ -14,7 +14,7 @@ def point_target(proposals_list,
label_channels=1,
sampling=True,
unmap_outputs=True):
"""Compute refinement and classification targets for points.
"""Compute corresponding GT box and classification targets for proposals.
Args:
points_list (list[list]): Multi level points of each image.
Expand Down Expand Up @@ -43,29 +43,31 @@ def point_target(proposals_list,
gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
if gt_labels_list is None:
gt_labels_list = [None for _ in range(num_imgs)]
(all_labels, all_label_weights, all_bbox_gt, all_proposals, all_proposal_weights,
pos_inds_list, neg_inds_list) = multi_apply(
point_target_single,
proposals_list,
valid_flag_list,
gt_bboxes_list,
gt_bboxes_ignore_list,
gt_labels_list,
cfg=cfg,
label_channels=label_channels,
sampling=sampling,
unmap_outputs=unmap_outputs)
(all_labels, all_label_weights, all_bbox_gt, all_proposals,
all_proposal_weights, pos_inds_list, neg_inds_list) = multi_apply(
point_target_single,
proposals_list,
valid_flag_list,
gt_bboxes_list,
gt_bboxes_ignore_list,
gt_labels_list,
cfg=cfg,
label_channels=label_channels,
sampling=sampling,
unmap_outputs=unmap_outputs)
# no valid points
if any([labels is None for labels in all_labels]):
return None
# sampled points of all images
num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
labels_list = images_to_levels(all_labels, num_level_proposals)
label_weights_list = images_to_levels(all_label_weights, num_level_proposals)
label_weights_list = images_to_levels(all_label_weights,
num_level_proposals)
bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals)
proposals_list = images_to_levels(all_proposals, num_level_proposals)
proposal_weights_list = images_to_levels(all_proposal_weights, num_level_proposals)
proposal_weights_list = images_to_levels(all_proposal_weights,
num_level_proposals)
return (labels_list, label_weights_list, bbox_gt_list, proposals_list,
proposal_weights_list, num_total_pos, num_total_neg)

Expand Down Expand Up @@ -96,8 +98,8 @@ def point_target_single(flat_proposals,
unmap_outputs=True):
inside_flags = valid_flags
if not inside_flags.any():
return (None,) * 7
# assign gt and sample points
return (None, ) * 7
# assign gt and sample proposals
proposals = flat_proposals[inside_flags, :]

if sampling:
Expand Down Expand Up @@ -136,27 +138,28 @@ def point_target_single(flat_proposals,
if len(neg_inds) > 0:
label_weights[neg_inds] = 1.0

# map up to original set of grids
# map up to original set of proposals
if unmap_outputs:
num_total_proposals = flat_proposals.size(0)
labels = unmap(labels, num_total_proposals, inside_flags)
label_weights = unmap(label_weights, num_total_proposals, inside_flags)
bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags)
pos_proposals = unmap(pos_proposals, num_total_proposals, inside_flags)
proposals_weightss = unmap(proposals_weights, num_total_proposals, inside_flags)
proposals_weights = unmap(proposals_weights, num_total_proposals,
inside_flags)

return (labels, label_weights, bbox_gt, pos_proposals, proposals_weightss, pos_inds,
neg_inds)
return (labels, label_weights, bbox_gt, pos_proposals, proposals_weights,
pos_inds, neg_inds)


def unmap(data, count, inds, fill=0):
""" Unmap a subset of item (data) back to the original set of items (of
size count) """
if data.dim() == 1:
ret = data.new_full((count,), fill)
ret = data.new_full((count, ), fill)
ret[inds] = data
else:
new_size = (count,) + data.size()[1:]
new_size = (count, ) + data.size()[1:]
ret = data.new_full(new_size, fill)
ret[inds, :] = data
return ret
2 changes: 1 addition & 1 deletion src/reppoints_head/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from .ga_retina_head import GARetinaHead
from .ga_rpn_head import GARPNHead
from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
from .reppoints_head import RepPointsHead
from .retina_head import RetinaHead
from .rpn_head import RPNHead
from .ssd_head import SSDHead
from .reppoints_head import RepPointsHead

__all__ = [
'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
Expand Down
Loading

0 comments on commit cac17da

Please sign in to comment.