diff --git a/src/reppoints_assigner/__init__.py b/src/reppoints_assigner/__init__.py
index 05d52f6..93eebb7 100644
--- a/src/reppoints_assigner/__init__.py
+++ b/src/reppoints_assigner/__init__.py
@@ -5,5 +5,6 @@ from .point_assigner import PointAssigner
 
 __all__ = [
-    'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult', 'PointAssigner'
+    'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
+    'PointAssigner'
 ]
diff --git a/src/reppoints_assigner/point_assigner.py b/src/reppoints_assigner/point_assigner.py
index 610acb7..fe81e7d 100644
--- a/src/reppoints_assigner/point_assigner.py
+++ b/src/reppoints_assigner/point_assigner.py
@@ -1,7 +1,7 @@
 import torch
 
-from .base_assigner import BaseAssigner
 from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
 
 
 class PointAssigner(BaseAssigner):
@@ -20,20 +20,19 @@ def __init__(self, scale=4, pos_num=3):
         self.pos_num = pos_num
 
     def assign(self, points, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
-        """Assign gt to bboxes.
+        """Assign gt to points.
 
-        This method assign a gt bbox to every point, each bbox
+        This method assigns a gt bbox to every point set; each point set
         will be assigned with 0, or a positive number.
         0 means negative sample, positive number is the index (1-based) of
         assigned gt.
         The assignment is done in following steps, the order matters.
 
         1. assign every points to 0
-        2. for each gt box, we find the k most closest points to the
-            box center and assign the gt bbox to those points, we also record
-            the minimum distance from each point to the closest gt box. When we
-            assign the bbox to the points, we check whether its distance to the
-            points is closest.
+        2. A point is assigned to some gt bbox if
+            (i) the point is within the k closest points to the gt bbox
+            (ii) the distance between this point and the gt is smaller than
+            the distance to every other gt bbox
 
         Args:
             points (Tensor): points to be assigned, shape(n, 3) while last
@@ -48,28 +47,26 @@ def assign(self, points, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
         """
         if points.shape[0] == 0 or gt_bboxes.shape[0] == 0:
             raise ValueError('No gt or bboxes')
-        points_range = torch.arange(points.shape[0])
         points_xy = points[:, :2]
         points_stride = points[:, 2]
-        points_lvl = torch.log2(points_stride).int() # [3...,4...,5...,6...,7...]
+        points_lvl = torch.log2(
+            points_stride).int()  # [3...,4...,5...,6...,7...]
         lvl_min, lvl_max = points_lvl.min(), points_lvl.max()
 
         num_gts, num_points = gt_bboxes.shape[0], points.shape[0]
 
         # assign gt box
-        gt_bboxes_x = 0.5 * (gt_bboxes[:, 0] + gt_bboxes[:, 2])
-        gt_bboxes_y = 0.5 * (gt_bboxes[:, 1] + gt_bboxes[:, 3])
-        gt_bboxes_xy = torch.stack([gt_bboxes_x, gt_bboxes_y], -1)
-        gt_bboxes_w = gt_bboxes[:, 2] - gt_bboxes[:, 0]
-        gt_bboxes_h = gt_bboxes[:, 3] - gt_bboxes[:, 1]
-        gt_bboxes_wh = torch.stack([gt_bboxes_w, gt_bboxes_h], -1)
-        gt_bboxes_wh = torch.clamp(gt_bboxes_wh, min=1e-6)
-        gt_bboxes_lvl = (0.5 * (torch.log2(gt_bboxes_w / self.scale) + torch.log2(gt_bboxes_h / self.scale))).int()
+        gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2
+        gt_bboxes_wh = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).clamp(min=1e-6)
+        scale = self.scale
+        gt_bboxes_lvl = ((torch.log2(gt_bboxes_wh[:, 0] / scale) +
+                          torch.log2(gt_bboxes_wh[:, 1] / scale)) / 2).int()
         gt_bboxes_lvl = torch.clamp(gt_bboxes_lvl, min=lvl_min, max=lvl_max)
 
         # stores the assigned gt index of each point
-        assigned_gt_inds = points.new_zeros((num_points,), dtype=torch.long)
+        assigned_gt_inds = points.new_zeros((num_points, ), dtype=torch.long)
         # stores the assigned gt dist (to this point) of each point
-        assigned_gt_dist = points.new_full((num_points,), float('inf'))
+        assigned_gt_dist = points.new_full((num_points, ), float('inf'))
+        points_range = torch.arange(points.shape[0])
 
         for idx in range(num_gts):
             gt_lvl = gt_bboxes_lvl[idx]
@@ -82,20 +79,32 @@ def assign(self, points, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
             gt_point = gt_bboxes_xy[[idx], :]
             # get width and height of gt
             gt_wh = gt_bboxes_wh[[idx], :]
-            # compute the distance between gt center and all points in this level
-            points_gt_dist = ((lvl_points-gt_point)/gt_wh).norm(dim=1)
+            # compute the distance between gt center and
+            # all points in this level
+            points_gt_dist = ((lvl_points - gt_point) / gt_wh).norm(dim=1)
             # find the nearest k points to gt center in this level
-            min_dist, min_dist_index = torch.topk(-points_gt_dist, self.pos_num)
+            min_dist, min_dist_index = torch.topk(
+                points_gt_dist, self.pos_num, largest=False)
             # the index of nearest k points to gt center in this level
             min_dist_points_index = points_index[min_dist_index]
-            less_than_recorded_index = min_dist < assigned_gt_dist[min_dist_points_index]
-            min_dist_points_index = min_dist_points_index[less_than_recorded_index]
+            # The less_than_recorded_index stores the index
+            # of min_dist that is less than the assigned_gt_dist. Where
+            # assigned_gt_dist stores the dist from previous assigned gt
+            # (if exist) to each point.
+            less_than_recorded_index = min_dist < assigned_gt_dist[
+                min_dist_points_index]
+            # The min_dist_points_index stores the index of points satisfying:
+            # (1) it is k nearest to current gt center in this level.
+            # (2) it is closer to current gt center than other gt centers.
+            min_dist_points_index = min_dist_points_index[
+                less_than_recorded_index]
             # assign the result
-            assigned_gt_inds[min_dist_points_index] = idx+1
-            assigned_gt_dist[min_dist_points_index] = min_dist[less_than_recorded_index]
+            assigned_gt_inds[min_dist_points_index] = idx + 1
+            assigned_gt_dist[min_dist_points_index] = min_dist[
+                less_than_recorded_index]
 
         if gt_labels is not None:
-            assigned_labels = assigned_gt_inds.new_zeros((num_points,))
+            assigned_labels = assigned_gt_inds.new_zeros((num_points, ))
             pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze()
             if pos_inds.numel() > 0:
                 assigned_labels[pos_inds] = gt_labels[
@@ -105,4 +114,3 @@ def assign(self, points, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
 
         return AssignResult(
             num_gts, assigned_gt_inds, None, labels=assigned_labels)
-
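The rewritten topk call above is worth a second look: `largest=False` returns the k smallest normalized distances as positive values, where the old `topk(-dist)` form returned negated ones. A minimal standalone sketch of this distance-and-selection step (toy shapes and illustrative names, not part of the patch):

```python
import torch

# Toy inputs: five candidate points (x, y) at one pyramid level, one gt box.
lvl_points = torch.tensor([[10., 10.], [20., 20.], [30., 30.],
                           [40., 40.], [50., 50.]])
gt_point = torch.tensor([[22., 22.]])  # gt center, shape (1, 2)
gt_wh = torch.tensor([[40., 40.]])     # gt width/height, shape (1, 2)

# Distance normalized by the gt size, as in PointAssigner.assign.
points_gt_dist = ((lvl_points - gt_point) / gt_wh).norm(dim=1)

# largest=False yields the k smallest distances, already positive, so they
# can be compared against the running assigned_gt_dist buffer directly.
min_dist, min_dist_index = torch.topk(points_gt_dist, 3, largest=False)
print(min_dist)        # ascending distances of the 3 nearest points
print(min_dist_index)  # their indices into lvl_points
```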
diff --git a/src/reppoints_detector/__init__.py b/src/reppoints_detector/__init__.py
index a374298..189c823 100644
--- a/src/reppoints_detector/__init__.py
+++ b/src/reppoints_detector/__init__.py
@@ -8,11 +8,11 @@ from .htc import HybridTaskCascade
 from .mask_rcnn import MaskRCNN
 from .mask_scoring_rcnn import MaskScoringRCNN
+from .reppoints_detector import RepPointsDetector
 from .retinanet import RetinaNet
 from .rpn import RPN
 from .single_stage import SingleStageDetector
 from .two_stage import TwoStageDetector
-from .reppoints_detector import RepPointsDetector
 
 __all__ = [
     'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
diff --git a/src/reppoints_detector/reppoints_detector.py b/src/reppoints_detector/reppoints_detector.py
index 02de73a..1e35116 100644
--- a/src/reppoints_detector/reppoints_detector.py
+++ b/src/reppoints_detector/reppoints_detector.py
@@ -1,11 +1,17 @@
 import torch
-from .single_stage import SingleStageDetector
+
+from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms
 from ..registry import DETECTORS
-from mmdet.core import bbox2result, multiclass_nms, bbox_mapping_back
+from .single_stage import SingleStageDetector
 
 
 @DETECTORS.register_module
 class RepPointsDetector(SingleStageDetector):
+    """RepPoints: Point Set Representation for Object Detection.
+
+    This detector is the implementation of:
+    - RepPoints detector (https://arxiv.org/pdf/1904.11490)
+    """
 
     def __init__(self,
                  backbone,
@@ -14,8 +20,9 @@ def __init__(self,
                  train_cfg=None,
                  test_cfg=None,
                  pretrained=None):
-        super(RepPointsDetector, self).__init__(backbone, neck, bbox_head, train_cfg,
-                                                test_cfg, pretrained)
+        super(RepPointsDetector,
+              self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
+                             pretrained)
 
     def merge_aug_results(self, aug_bboxes, aug_scores, img_metas):
         """Merge augmented detection bboxes and scores.
@@ -60,9 +67,10 @@ def aug_test(self, imgs, img_metas, rescale=False):
         # after merging, bboxes will be rescaled to the original image size
         merged_bboxes, merged_scores = self.merge_aug_results(
             aug_bboxes, aug_scores, img_metas)
-        det_bboxes, det_labels = multiclass_nms(
-            merged_bboxes, merged_scores, self.test_cfg.score_thr,
-            self.test_cfg.nms, self.test_cfg.max_per_img)
+        det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
+                                                self.test_cfg.score_thr,
+                                                self.test_cfg.nms,
+                                                self.test_cfg.max_per_img)
 
         if rescale:
             _det_bboxes = det_bboxes
diff --git a/src/reppoints_generator/point_generator.py b/src/reppoints_generator/point_generator.py
index 2ac0bf0..c1a34dd 100644
--- a/src/reppoints_generator/point_generator.py
+++ b/src/reppoints_generator/point_generator.py
@@ -1,5 +1,6 @@
 import torch
 
+
 class PointGenerator(object):
 
     def _meshgrid(self, x, y, row_major=True):
@@ -15,7 +16,7 @@ def grid_points(self, featmap_size, stride=16, device='cuda'):
         shift_x = torch.arange(0., feat_w, device=device) * stride
         shift_y = torch.arange(0., feat_h, device=device) * stride
         shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
-        stride = shift_x.new_full((shift_xx.shape[0],), stride)
+        stride = shift_x.new_full((shift_xx.shape[0], ), stride)
         shifts = torch.stack([shift_xx, shift_yy, stride], dim=-1)
         all_points = shifts.to(device)
         return all_points
@@ -30,4 +31,4 @@ def valid_flags(self, featmap_size, valid_size, device='cuda'):
         valid_y[:valid_h] = 1
         valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
         valid = valid_xx & valid_yy
-        return valid
\ No newline at end of file
+        return valid
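For context on what the head consumes later, a runnable sketch of the grid that `grid_points` produces; `_meshgrid` is not shown in this diff, so its row-major expansion below is an assumption based on the surrounding code:

```python
import torch

def grid_points(featmap_size, stride=16, device='cpu'):
    # Mirrors PointGenerator.grid_points: one (x, y, stride) row per
    # feature-map location, in row-major order (x varies fastest).
    feat_h, feat_w = featmap_size
    shift_x = torch.arange(0., feat_w, device=device) * stride
    shift_y = torch.arange(0., feat_h, device=device) * stride
    shift_xx = shift_x.repeat(feat_h)
    shift_yy = shift_y.view(-1, 1).repeat(1, feat_w).view(-1)
    stride_col = shift_x.new_full((shift_xx.shape[0], ), stride)
    return torch.stack([shift_xx, shift_yy, stride_col], dim=-1)

points = grid_points((2, 3), stride=16)
print(points)
# tensor([[ 0.,  0., 16.],
#         [16.,  0., 16.],
#         [32.,  0., 16.],
#         [ 0., 16., 16.],
#         [16., 16., 16.],
#         [32., 16., 16.]])
```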
diff --git a/src/reppoints_generator/point_target.py b/src/reppoints_generator/point_target.py
index a9370bf..1ab8d02 100644
--- a/src/reppoints_generator/point_target.py
+++ b/src/reppoints_generator/point_target.py
@@ -1,6 +1,6 @@
 import torch
 
-from ..bbox import assign_and_sample, build_assigner, PseudoSampler
+from ..bbox import PseudoSampler, assign_and_sample, build_assigner
 from ..utils import multi_apply
 
 
@@ -14,7 +14,7 @@ def point_target(proposals_list,
                  label_channels=1,
                  sampling=True,
                  unmap_outputs=True):
-    """Compute refinement and classification targets for points.
+    """Compute corresponding GT box and classification targets for proposals.
 
     Args:
         points_list (list[list]): Multi level points of each image.
@@ -43,18 +43,18 @@ def point_target(proposals_list,
        gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
     if gt_labels_list is None:
         gt_labels_list = [None for _ in range(num_imgs)]
-    (all_labels, all_label_weights, all_bbox_gt, all_proposals, all_proposal_weights, pos_inds_list, neg_inds_list) = multi_apply(
-        point_target_single,
-        proposals_list,
-        valid_flag_list,
-        gt_bboxes_list,
-        gt_bboxes_ignore_list,
-        gt_labels_list,
-        cfg=cfg,
-        label_channels=label_channels,
-        sampling=sampling,
-        unmap_outputs=unmap_outputs)
+    (all_labels, all_label_weights, all_bbox_gt, all_proposals,
+     all_proposal_weights, pos_inds_list, neg_inds_list) = multi_apply(
+         point_target_single,
+         proposals_list,
+         valid_flag_list,
+         gt_bboxes_list,
+         gt_bboxes_ignore_list,
+         gt_labels_list,
+         cfg=cfg,
+         label_channels=label_channels,
+         sampling=sampling,
+         unmap_outputs=unmap_outputs)
     # no valid points
     if any([labels is None for labels in all_labels]):
         return None
@@ -62,10 +62,12 @@ def point_target(proposals_list,
     num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
     num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
     labels_list = images_to_levels(all_labels, num_level_proposals)
-    label_weights_list = images_to_levels(all_label_weights, num_level_proposals)
+    label_weights_list = images_to_levels(all_label_weights,
+                                          num_level_proposals)
     bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals)
     proposals_list = images_to_levels(all_proposals, num_level_proposals)
-    proposal_weights_list = images_to_levels(all_proposal_weights, num_level_proposals)
+    proposal_weights_list = images_to_levels(all_proposal_weights,
+                                             num_level_proposals)
     return (labels_list, label_weights_list, bbox_gt_list, proposals_list,
             proposal_weights_list, num_total_pos, num_total_neg)
 
@@ -96,8 +98,8 @@ def point_target_single(flat_proposals,
                         unmap_outputs=True):
     inside_flags = valid_flags
     if not inside_flags.any():
-        return (None,) * 7
-    # assign gt and sample points
+        return (None, ) * 7
+    # assign gt and sample proposals
     proposals = flat_proposals[inside_flags, :]
 
     if sampling:
@@ -136,27 +138,28 @@ def point_target_single(flat_proposals,
     if len(neg_inds) > 0:
         label_weights[neg_inds] = 1.0
 
-    # map up to original set of grids
+    # map up to original set of proposals
     if unmap_outputs:
         num_total_proposals = flat_proposals.size(0)
         labels = unmap(labels, num_total_proposals, inside_flags)
         label_weights = unmap(label_weights, num_total_proposals, inside_flags)
         bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags)
         pos_proposals = unmap(pos_proposals, num_total_proposals, inside_flags)
-        proposals_weightss = unmap(proposals_weights, num_total_proposals, inside_flags)
+        proposals_weights = unmap(proposals_weights, num_total_proposals,
+                                  inside_flags)
 
-    return (labels, label_weights, bbox_gt, pos_proposals, proposals_weightss, pos_inds,
-            neg_inds)
+    return (labels, label_weights, bbox_gt, pos_proposals, proposals_weights,
+            pos_inds, neg_inds)
 
 
 def unmap(data, count, inds, fill=0):
     """ Unmap a subset of item (data) back to the original set of items (of
     size count) """
     if data.dim() == 1:
-        ret = data.new_full((count,), fill)
+        ret = data.new_full((count, ), fill)
         ret[inds] = data
     else:
-        new_size = (count,) + data.size()[1:]
+        new_size = (count, ) + data.size()[1:]
         ret = data.new_full(new_size, fill)
         ret[inds, :] = data
     return ret
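To see what `unmap` does at the end of `point_target_single`, a toy run (the function body is copied from the hunk above; the inputs are illustrative):

```python
import torch

def unmap(data, count, inds, fill=0):
    # Same as the function above: scatter a subset back to full size.
    if data.dim() == 1:
        ret = data.new_full((count, ), fill)
        ret[inds] = data
    else:
        new_size = (count, ) + data.size()[1:]
        ret = data.new_full(new_size, fill)
        ret[inds, :] = data
    return ret

inside_flags = torch.tensor([True, False, True, False])
labels_subset = torch.tensor([3, 7])  # computed only for the valid points
labels_full = unmap(labels_subset, 4, inside_flags)
print(labels_full)  # tensor([3, 0, 7, 0])
```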
diff --git a/src/reppoints_head/__init__.py b/src/reppoints_head/__init__.py
index 0f1daf7..5df25d0 100644
--- a/src/reppoints_head/__init__.py
+++ b/src/reppoints_head/__init__.py
@@ -3,10 +3,10 @@ from .ga_retina_head import GARetinaHead
 from .ga_rpn_head import GARPNHead
 from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
+from .reppoints_head import RepPointsHead
 from .retina_head import RetinaHead
 from .rpn_head import RPNHead
 from .ssd_head import SSDHead
-from .reppoints_head import RepPointsHead
 
 __all__ = [
     'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
diff --git a/src/reppoints_head/reppoints_head.py b/src/reppoints_head/reppoints_head.py
index 50e4bf9..1ce7abd 100644
--- a/src/reppoints_head/reppoints_head.py
+++ b/src/reppoints_head/reppoints_head.py
@@ -4,13 +4,13 @@ import torch
 import torch.nn as nn
 from mmcv.cnn import normal_init
 
-from mmdet.ops import DeformConv
-from mmdet.core import (PointGenerator, point_target,
-                        multi_apply, multiclass_nms)
+from mmdet.core import (PointGenerator, multi_apply, multiclass_nms,
+                        point_target)
+from mmdet.ops import DeformConv
 from ..builder import build_loss
 from ..registry import HEADS
-from ..utils import bias_init_with_prob, ConvModule
+from ..utils import ConvModule, bias_init_with_prob
 
 
 @HEADS.register_module
@@ -22,12 +26,16 @@ class RepPointsHead(nn.Module):
         feat_channels (int): Number of channels of the feature map.
         point_feat_channels (int): Number of channels of points features.
         stacked_convs (int): How many conv layers are used.
-        gradient_mul (float): The multiplier to gradients from points refinement and recognition.
+        gradient_mul (float): The multiplier to gradients from
+            points refinement and recognition.
         point_strides (Iterable): points strides.
         point_base_scale (int): bbox scale for assigning labels.
         loss_cls (dict): Config of classification loss.
         loss_bbox_init (dict): Config of initial points loss.
         loss_bbox_refine (dict): Config of points loss in refinement.
+        use_grid_points (bool): If we use the bounding box representation,
+            the reppoints are represented as grid points on the bounding box.
+        center_init (bool): Whether to use center point assignment.
         transform_method (str): The methods to transform RepPoints to bbox.
     """  # noqa: W605
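The `gradient_mul` argument documented above implements a partial-gradient trick that reappears below as `moment_mul` for `moment_transfer`. A self-contained sketch (values are illustrative):

```python
import torch

gradient_mul = 0.1
pts_out_init = torch.ones(3, requires_grad=True)

# Only gradient_mul of the gradient flows back through pts_out_init; the
# detached copy carries the remaining (1 - gradient_mul) of the value, so
# the forward result is unchanged while the backward signal is scaled.
pts_grad_mul = (1 - gradient_mul) * pts_out_init.detach() \
    + gradient_mul * pts_out_init

pts_grad_mul.sum().backward()
print(pts_out_init.grad)  # tensor([0.1000, 0.1000, 0.1000])
```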
""" # noqa: W605 @@ -43,12 +47,20 @@ def __init__(self, point_base_scale=4, conv_cfg=None, norm_cfg=None, - loss_cls=dict(type='FocalLoss', use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0), - loss_bbox_init=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5), - loss_bbox_refine=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox_init=dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5), + loss_bbox_refine=dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0), use_grid_points=False, center_init=True, - transform_method='moment'): + transform_method='moment', + moment_mul=0.01): super(RepPointsHead, self).__init__() self.in_channels = in_channels self.num_classes = num_classes @@ -70,35 +82,34 @@ def __init__(self, self.center_init = center_init self.transform_method = transform_method if self.transform_method == 'moment': - self.moment_transfer = nn.Parameter(data=torch.zeros(2), requires_grad=True) + self.moment_transfer = nn.Parameter( + data=torch.zeros(2), requires_grad=True) + self.moment_mul = moment_mul if self.use_sigmoid_cls: self.cls_out_channels = self.num_classes - 1 else: self.cls_out_channels = self.num_classes - self.point_generators = [] - for _ in self.point_strides: - self.point_generators.append(PointGenerator()) - self._init_layers() - - def _init_dcn_offset(self, num_points): + self.point_generators = [PointGenerator() for _ in self.point_strides] + # we use deformable conv to extract points features self.dcn_kernel = int(np.sqrt(num_points)) self.dcn_pad = int((self.dcn_kernel - 1) / 2) - assert self.dcn_kernel * self.dcn_kernel == num_points, "The points number should be a square number." - assert self.dcn_kernel % 2 == 1, "The points number should be an odd square number." - dcn_base = np.arange(-self.dcn_pad, self.dcn_pad + 1).astype(np.float) + assert self.dcn_kernel * self.dcn_kernel == num_points, \ + "The points number should be a square number." + assert self.dcn_kernel % 2 == 1, \ + "The points number should be an odd square number." 
+        dcn_base = np.arange(-self.dcn_pad,
+                             self.dcn_pad + 1).astype(np.float64)
         dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
         dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
-        dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape((-1))
+        dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape(
+            (-1))
         self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1)
+        self._init_layers()
 
     def _init_layers(self):
-        self._init_dcn_offset(self.num_points)
         self.relu = nn.ReLU(inplace=True)
         self.cls_convs = nn.ModuleList()
         self.reg_convs = nn.ModuleList()
-        self.reppoints_cls = nn.ModuleList()
-        self.reppoints_pts_init = nn.ModuleList()
-        self.reppoints_pts_refine = nn.ModuleList()
         for i in range(self.stacked_convs):
             chn = self.in_channels if i == 0 else self.feat_channels
             self.cls_convs.append(
@@ -120,14 +131,22 @@ def _init_layers(self):
                     conv_cfg=self.conv_cfg,
                     norm_cfg=self.norm_cfg))
         pts_out_dim = 4 if self.use_grid_points else 2 * self.num_points
-        self.reppoints_cls_conv = DeformConv(self.feat_channels, self.point_feat_channels, self.dcn_kernel, 1,
-                                             self.dcn_pad)
-        self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels, self.cls_out_channels, 1, 1, 0)
-        self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels, self.point_feat_channels, 3, 1, 1)
-        self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels, pts_out_dim, 1, 1, 0)
-        self.reppoints_pts_refine_conv = DeformConv(self.feat_channels, self.point_feat_channels, self.dcn_kernel, 1,
+        self.reppoints_cls_conv = DeformConv(self.feat_channels,
+                                             self.point_feat_channels,
+                                             self.dcn_kernel, 1, self.dcn_pad)
+        self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels,
+                                           self.cls_out_channels, 1, 1, 0)
+        self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels,
+                                                 self.point_feat_channels, 3,
+                                                 1, 1)
+        self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels,
+                                                pts_out_dim, 1, 1, 0)
+        self.reppoints_pts_refine_conv = DeformConv(self.feat_channels,
+                                                    self.point_feat_channels,
+                                                    self.dcn_kernel, 1,
                                                     self.dcn_pad)
-        self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels, pts_out_dim, 1, 1, 0)
+        self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels,
+                                                  pts_out_dim, 1, 1, 0)
 
     def init_weights(self):
         for m in self.cls_convs:
@@ -142,77 +161,101 @@ def init_weights(self):
         normal_init(self.reppoints_pts_refine_conv, std=0.01)
         normal_init(self.reppoints_pts_refine_out, std=0.01)
 
-    def transform_box(self, pts, y_first=True):
+    def points2bbox(self, pts, y_first=True):
+        """
+        Convert the points set into bounding box.
+        :param pts: the input points sets (fields), each points
+            set (fields) is represented as 2n scalars.
+        :param y_first: if y_first=True, the point set is represented as
+            [y1, x1, y2, x2 ... yn, xn], otherwise the point set is
+            represented as [x1, y1, x2, y2 ... xn, yn].
+        :return: each points set is converted to a bbox [x1, y1, x2, y2].
+        """
+        pts_reshape = pts.view(pts.shape[0], -1, 2, *pts.shape[2:])
+        pts_y = pts_reshape[:, :, 0, ...] if y_first else pts_reshape[:, :, 1,
+                                                                      ...]
+        pts_x = pts_reshape[:, :, 1, ...] if y_first else pts_reshape[:, :, 0,
+                                                                      ...]
         if self.transform_method == 'minmax':
-            pts_reshape = pts.view(pts.shape[0], -1, 2, *pts.shape[2:])
-            pts_y = pts_reshape[:, :, 0, ...] if y_first else pts_reshape[:, :, 1, ...]
-            pts_x = pts_reshape[:, :, 1, ...] if y_first else pts_reshape[:, :, 0, ...]
             bbox_left = pts_x.min(dim=1, keepdim=True)[0]
             bbox_right = pts_x.max(dim=1, keepdim=True)[0]
             bbox_up = pts_y.min(dim=1, keepdim=True)[0]
             bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
-            bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom], dim=1)
+            bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
+                             dim=1)
         elif self.transform_method == 'partial_minmax':
-            pts_reshape = pts.view(pts.shape[0], -1, 2, *pts.shape[2:])
-            pts_reshape = pts_reshape[:, :4, ...]
-            pts_y = pts_reshape[:, :, 0, ...] if y_first else pts_reshape[:, :, 1, ...]
-            pts_x = pts_reshape[:, :, 1, ...] if y_first else pts_reshape[:, :, 0, ...]
+            pts_y = pts_y[:, :4, ...]
+            pts_x = pts_x[:, :4, ...]
             bbox_left = pts_x.min(dim=1, keepdim=True)[0]
             bbox_right = pts_x.max(dim=1, keepdim=True)[0]
             bbox_up = pts_y.min(dim=1, keepdim=True)[0]
             bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
-            bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom], dim=1)
+            bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
+                             dim=1)
         elif self.transform_method == 'moment':
-            pts_reshape = pts.view(pts.shape[0], -1, 2, *pts.shape[2:])
-            pts_y = pts_reshape[:, :, 0, ...] if y_first else pts_reshape[:, :, 1, ...]
-            pts_x = pts_reshape[:, :, 1, ...] if y_first else pts_reshape[:, :, 0, ...]
             pts_y_mean = pts_y.mean(dim=1, keepdim=True)
             pts_x_mean = pts_x.mean(dim=1, keepdim=True)
             pts_y_std = torch.std(pts_y - pts_y_mean, dim=1, keepdim=True)
             pts_x_std = torch.std(pts_x - pts_x_mean, dim=1, keepdim=True)
-            moment_transfer = self.moment_transfer * 0.01 + self.moment_transfer.detach() * 0.99
+            moment_transfer = (self.moment_transfer * self.moment_mul) + (
+                self.moment_transfer.detach() * (1 - self.moment_mul))
             moment_width_transfer = moment_transfer[0]
             moment_height_transfer = moment_transfer[1]
             half_width = pts_x_std * torch.exp(moment_width_transfer)
             half_height = pts_y_std * torch.exp(moment_height_transfer)
-            bbox = torch.cat([pts_x_mean - half_width, pts_y_mean - half_height,
-                              pts_x_mean + half_width, pts_y_mean + half_height], dim=1)
+            bbox = torch.cat([
+                pts_x_mean - half_width, pts_y_mean - half_height,
+                pts_x_mean + half_width, pts_y_mean + half_height
+            ],
+                             dim=1)
         else:
             raise NotImplementedError
         return bbox
 
     def gen_grid_from_reg(self, reg, previous_boxes):
+        """
+        Based on the previous bboxes and regression values, we compute the
+        regressed bboxes and generate the grids on the bboxes.
+        :param reg: the regression value to previous bboxes.
+        :param previous_boxes: previous bboxes.
+        :return: the grids generated on the regressed bboxes.
+        """
         b, _, h, w = reg.shape
-        tx = reg[:, [0], ...]
-        ty = reg[:, [1], ...]
-        tw = reg[:, [2], ...]
-        th = reg[:, [3], ...]
-        bx = (previous_boxes[:, [0], ...] + previous_boxes[:, [2], ...]) / 2.
-        by = (previous_boxes[:, [1], ...] + previous_boxes[:, [3], ...]) / 2.
-        bw = (previous_boxes[:, [2], ...] - previous_boxes[:, [0], ...]).clamp(min=1e-6)
-        bh = (previous_boxes[:, [3], ...] - previous_boxes[:, [1], ...]).clamp(min=1e-6)
-        grid_left = bx + bw * tx - 0.5 * bw * torch.exp(tw)
-        grid_width = bw * torch.exp(tw)
-        grid_up = by + bh * ty - 0.5 * bh * torch.exp(th)
-        grid_height = bh * torch.exp(th)
-        intervel = torch.linspace(0., 1., self.dcn_kernel).view(1, self.dcn_kernel, 1, 1).type_as(reg)
+        bxy = (previous_boxes[:, :2, ...] + previous_boxes[:, 2:, ...]) / 2.
+        bwh = (previous_boxes[:, 2:, ...] -
+               previous_boxes[:, :2, ...]).clamp(min=1e-6)
+        grid_topleft = bxy + bwh * reg[:, :2, ...] - 0.5 * bwh * torch.exp(
+            reg[:, 2:, ...])
+        grid_wh = bwh * torch.exp(reg[:, 2:, ...])
+        grid_left = grid_topleft[:, [0], ...]
+        grid_top = grid_topleft[:, [1], ...]
+        grid_width = grid_wh[:, [0], ...]
+        grid_height = grid_wh[:, [1], ...]
+        intervel = torch.linspace(0., 1., self.dcn_kernel).view(
+            1, self.dcn_kernel, 1, 1).type_as(reg)
         grid_x = grid_left + grid_width * intervel
         grid_x = grid_x.unsqueeze(1).repeat(1, self.dcn_kernel, 1, 1, 1)
         grid_x = grid_x.view(b, -1, h, w)
-        grid_y = grid_up + grid_height * intervel
+        grid_y = grid_top + grid_height * intervel
         grid_y = grid_y.unsqueeze(2).repeat(1, 1, self.dcn_kernel, 1, 1)
         grid_y = grid_y.view(b, -1, h, w)
         grid_yx = torch.stack([grid_y, grid_x], dim=2)
         grid_yx = grid_yx.view(b, -1, h, w)
-        regressed_bbox = torch.cat([grid_left, grid_up, grid_left + grid_width, grid_up + grid_height], 1)
+        regressed_bbox = torch.cat([
+            grid_left, grid_top, grid_left + grid_width,
+            grid_top + grid_height
+        ], 1)
         return grid_yx, regressed_bbox
 
     def forward_single(self, x):
         dcn_base_offset = self.dcn_base_offset.type_as(x)
+        # If we use center_init, the initial reppoints are from center points.
+        # If we use the bounding bbox representation, the initial reppoints
+        # are from a regular grid placed on a pre-defined bbox.
         if self.use_grid_points or not self.center_init:
             scale = self.point_base_scale / 2
             points_init = dcn_base_offset / dcn_base_offset.max() * scale
-            bbox_init = torch.tensor([-scale, -scale, scale, scale]).view(1, 4, 1, 1).type_as(x)
+            bbox_init = x.new_tensor([-scale, -scale, scale,
+                                      scale]).view(1, 4, 1, 1)
         else:
             points_init = 0
         cls_feat = x
@@ -222,19 +265,24 @@ def forward_single(self, x):
         for reg_conv in self.reg_convs:
             pts_feat = reg_conv(pts_feat)
         # initialize reppoints
-        pts_out_init = self.reppoints_pts_init_out(self.relu(self.reppoints_pts_init_conv(pts_feat)))
+        pts_out_init = self.reppoints_pts_init_out(
+            self.relu(self.reppoints_pts_init_conv(pts_feat)))
         if self.use_grid_points:
-            pts_out_init, bbox_out_init = self.gen_grid_from_reg(pts_out_init, bbox_init.detach())
+            pts_out_init, bbox_out_init = self.gen_grid_from_reg(
+                pts_out_init, bbox_init.detach())
         else:
             pts_out_init = pts_out_init + points_init
         # refine and classify reppoints
-        pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach() + self.gradient_mul * pts_out_init
+        pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach(
+        ) + self.gradient_mul * pts_out_init
         dcn_offset = pts_out_init_grad_mul - dcn_base_offset
-        cls_out = self.reppoints_cls_out(self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset)))
-        pts_out_refine = self.reppoints_pts_refine_out(self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset)))
+        cls_out = self.reppoints_cls_out(
+            self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset)))
+        pts_out_refine = self.reppoints_pts_refine_out(
+            self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset)))
         if self.use_grid_points:
-            bbox_out_init = self.transform_box(pts_out_init)
-            pts_out_refine, bbox_out_refine = self.gen_grid_from_reg(pts_out_refine, bbox_out_init.detach())
+            pts_out_refine, bbox_out_refine = self.gen_grid_from_reg(
+                pts_out_refine, bbox_out_init.detach())
         else:
             pts_out_refine = pts_out_refine + pts_out_init.detach()
         return cls_out, pts_out_init, pts_out_refine
@@ -259,9 +307,11 @@ def get_points(self, featmap_sizes, img_metas):
         # points center for one time
         multi_level_points = []
         for i in range(num_levels):
-            points = self.point_generators[i].grid_points(featmap_sizes[i], self.point_strides[i])
+            points = self.point_generators[i].grid_points(
+                featmap_sizes[i], self.point_strides[i])
             multi_level_points.append(points)
-        points_list = [[point.clone() for point in multi_level_points] for _ in range(num_imgs)]
+        points_list = [[point.clone() for point in multi_level_points]
+                       for _ in range(num_imgs)]
 
         # for each image, we compute valid flags of multi level grids
         valid_flag_list = []
@@ -273,7 +323,8 @@ def get_points(self, featmap_sizes, img_metas):
                 h, w, _ = img_meta['pad_shape']
                 valid_feat_h = min(int(np.ceil(h / point_stride)), feat_h)
                 valid_feat_w = min(int(np.ceil(w / point_stride)), feat_w)
-                flags = self.point_generators[i].valid_flags((feat_h, feat_w), (valid_feat_h, valid_feat_w))
+                flags = self.point_generators[i].valid_flags(
+                    (feat_h, feat_w), (valid_feat_h, valid_feat_w))
                 multi_level_flags.append(flags)
             valid_flag_list.append(multi_level_flags)
 
@@ -287,21 +338,14 @@ def centers_to_bboxes(self, point_list):
             bbox = []
             for i_lvl in range(len(self.point_strides)):
                 scale = self.point_base_scale * self.point_strides[i_lvl] * 0.5
-                bbox_shift = torch.Tensor([-scale, -scale, scale, scale]).view(1, 4).type_as(point[0])
-                bbox_center = torch.cat([point[i_lvl][:, :2], point[i_lvl][:, :2]], dim=1)
+                bbox_shift = torch.Tensor([-scale, -scale, scale,
+                                           scale]).view(1, 4).type_as(point[0])
+                bbox_center = torch.cat(
+                    [point[i_lvl][:, :2], point[i_lvl][:, :2]], dim=1)
                 bbox.append(bbox_center + bbox_shift)
             bbox_list.append(bbox)
         return bbox_list
 
-    def yx_to_xy(self, pts):
-        """Change the points offset from y first to x first.
-        """
-        pts_y = pts[..., 0::2]
-        pts_x = pts[..., 1::2]
-        pts_xy = torch.stack([pts_x, pts_y], -1)
-        pts = pts_xy.view(*pts.shape[:-1], -1)
-        return pts
-
     def offset_to_pts(self, center_list, pred_list):
         """Change from point offset to point coordinate.
         """
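The removed `yx_to_xy` helper is now inlined in `offset_to_pts` (next hunk). A toy run of that reindexing, assuming one location with two points:

```python
import torch

# One location with y-first offsets [y1, x1, y2, x2].
yx_pts_shift = torch.tensor([[1., 10., 2., 20.]])

y_pts_shift = yx_pts_shift[..., 0::2]  # (1, 2): [1., 2.]
x_pts_shift = yx_pts_shift[..., 1::2]  # (1, 2): [10., 20.]
# Interleave back with x first, then flatten the trailing pair dimension.
xy_pts_shift = torch.stack([x_pts_shift, y_pts_shift], -1)
xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1)
print(xy_pts_shift)  # tensor([[10.,  1., 20.,  2.]]) -- x-first layout
```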
""" @@ -309,33 +353,45 @@ def offset_to_pts(self, center_list, pred_list): for i_lvl in range(len(self.point_strides)): pts_lvl = [] for i_img in range(len(center_list)): - pts_center = center_list[i_img][i_lvl][:, :2].repeat(1, self.num_points) + pts_center = center_list[i_img][i_lvl][:, :2].repeat( + 1, self.num_points) pts_shift = pred_list[i_lvl][i_img] - yx_pts_shift = pts_shift.permute(1, 2, 0).view(-1, 2 * self.num_points) - xy_pts_shift = self.yx_to_xy(yx_pts_shift) + yx_pts_shift = pts_shift.permute(1, 2, 0).view( + -1, 2 * self.num_points) + y_pts_shift = yx_pts_shift[..., 0::2] + x_pts_shift = yx_pts_shift[..., 1::2] + xy_pts_shift = torch.stack([x_pts_shift, y_pts_shift], -1) + xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1) pts = xy_pts_shift * self.point_strides[i_lvl] + pts_center pts_lvl.append(pts) pts_lvl = torch.stack(pts_lvl, 0) pts_list.append(pts_lvl) return pts_list - def loss_single(self, cls_score, pts_pred_init, pts_pred_refine, labels, label_weights, - bbox_gt_init, bbox_weights_init, bbox_gt_refine, bbox_weights_refine, - stride, num_total_samples_init, num_total_samples_refine): + def loss_single(self, cls_score, pts_pred_init, pts_pred_refine, labels, + label_weights, bbox_gt_init, bbox_weights_init, + bbox_gt_refine, bbox_weights_refine, stride, + num_total_samples_init, num_total_samples_refine): # classification loss labels = labels.reshape(-1) label_weights = label_weights.reshape(-1) - cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels) + cls_score = cls_score.permute(0, 2, 3, + 1).reshape(-1, self.cls_out_channels) loss_cls = self.loss_cls( - cls_score, labels, label_weights, avg_factor=num_total_samples_refine) + cls_score, + labels, + label_weights, + avg_factor=num_total_samples_refine) # points loss bbox_gt_init = bbox_gt_init.reshape(-1, 4) bbox_weights_init = bbox_weights_init.reshape(-1, 4) - bbox_pred_init = self.transform_box(pts_pred_init.reshape(-1, 2 * self.num_points), y_first=False) + bbox_pred_init = self.points2bbox( + pts_pred_init.reshape(-1, 2 * self.num_points), y_first=False) bbox_gt_refine = bbox_gt_refine.reshape(-1, 4) bbox_weights_refine = bbox_weights_refine.reshape(-1, 4) - bbox_pred_refine = self.transform_box(pts_pred_refine.reshape(-1, 2 * self.num_points), y_first=False) + bbox_pred_refine = self.points2bbox( + pts_pred_refine.reshape(-1, 2 * self.num_points), y_first=False) normalize_term = self.point_base_scale * stride loss_pts_init = self.loss_bbox_init( bbox_pred_init / normalize_term, @@ -363,12 +419,20 @@ def loss(self, label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 # target for initial stage - proposal_list, valid_flag_list = self.get_points(featmap_sizes, img_metas) - pts_coordinate_preds_init = self.offset_to_pts(proposal_list, pts_preds_init) - if cfg.init.assigner['type'] != 'PointAssigner': - proposal_list = self.centers_to_bboxes(proposal_list) + center_list, valid_flag_list = self.get_points(featmap_sizes, + img_metas) + pts_coordinate_preds_init = self.offset_to_pts(center_list, + pts_preds_init) + if cfg.init.assigner['type'] == 'PointAssigner': + # Assign target for center list + candidate_list = center_list + else: + # transform center list to bbox list and + # assign target for bbox list + bbox_list = self.centers_to_bboxes(center_list) + candidate_list = bbox_list cls_reg_targets_init = point_target( - proposal_list, + candidate_list, valid_flag_list, gt_bboxes, img_metas, @@ -377,21 +441,28 @@ def loss(self, gt_labels_list=gt_labels, 
             label_channels=label_channels,
             sampling=self.sampling)
-        (*_, bbox_gt_list_init, proposal_list_init,
-         bbox_weights_list_init, num_total_pos_init, num_total_neg_init) = cls_reg_targets_init
-        num_total_samples_init = (num_total_pos_init + num_total_neg_init if self.sampling else num_total_pos_init)
+        (*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init,
+         num_total_pos_init, num_total_neg_init) = cls_reg_targets_init
+        num_total_samples_init = (
+            num_total_pos_init +
+            num_total_neg_init if self.sampling else num_total_pos_init)
 
         # target for refinement stage
-        proposal_list, valid_flag_list = self.get_points(featmap_sizes, img_metas)
-        pts_coordinate_preds_refine = self.offset_to_pts(proposal_list, pts_preds_refine)
+        center_list, valid_flag_list = self.get_points(featmap_sizes,
+                                                       img_metas)
+        pts_coordinate_preds_refine = self.offset_to_pts(
+            center_list, pts_preds_refine)
         bbox_list = []
-        for i_img, point in enumerate(proposal_list):
+        for i_img, center in enumerate(center_list):
             bbox = []
             for i_lvl in range(len(pts_preds_refine)):
-                bbox_preds_init = self.transform_box(pts_preds_init[i_lvl].detach())
+                bbox_preds_init = self.points2bbox(
+                    pts_preds_init[i_lvl].detach())
                 bbox_shift = bbox_preds_init * self.point_strides[i_lvl]
-                bbox_center = torch.cat([point[i_lvl][:, :2], point[i_lvl][:, :2]], dim=1)
-                bbox.append(bbox_center + bbox_shift[i_img].permute(1, 2, 0).contiguous().view(-1, 4))
+                bbox_center = torch.cat(
+                    [center[i_lvl][:, :2], center[i_lvl][:, :2]], dim=1)
+                bbox.append(bbox_center +
+                            bbox_shift[i_img].permute(1, 2, 0).reshape(-1, 4))
             bbox_list.append(bbox)
         cls_reg_targets_refine = point_target(
             bbox_list,
@@ -403,10 +474,12 @@ def loss(self,
             gt_labels_list=gt_labels,
             label_channels=label_channels,
             sampling=self.sampling)
-        (labels_list, label_weights_list, bbox_gt_list_refine, proposal_list_refine,
-         bbox_weights_list_refine, num_total_pos_refine, num_total_neg_refine) = cls_reg_targets_refine
+        (labels_list, label_weights_list, bbox_gt_list_refine,
+         candidate_list_refine, bbox_weights_list_refine, num_total_pos_refine,
+         num_total_neg_refine) = cls_reg_targets_refine
         num_total_samples_refine = (
-            num_total_pos_refine + num_total_neg_refine if self.sampling else num_total_pos_refine)
+            num_total_pos_refine +
+            num_total_neg_refine if self.sampling else num_total_pos_refine)
 
         # compute loss
         losses_cls, losses_pts_init, losses_pts_refine = multi_apply(
@@ -423,15 +496,26 @@ def loss(self,
             self.point_strides,
             num_total_samples_init=num_total_samples_init,
             num_total_samples_refine=num_total_samples_refine)
-        loss_dict_all = {'loss_cls': losses_cls,
-                         'loss_pts_init': losses_pts_init,
-                         'loss_pts_refine': losses_pts_refine}
+        loss_dict_all = {
+            'loss_cls': losses_cls,
+            'loss_pts_init': losses_pts_init,
+            'loss_pts_refine': losses_pts_refine
+        }
         return loss_dict_all
 
-    def get_bboxes(self, cls_scores, pts_preds_init, pts_preds_refine, img_metas, cfg,
-                   rescale=False, nms=True):
+    def get_bboxes(self,
+                   cls_scores,
+                   pts_preds_init,
+                   pts_preds_refine,
+                   img_metas,
+                   cfg,
+                   rescale=False,
+                   nms=True):
         assert len(cls_scores) == len(pts_preds_refine)
-        bbox_preds_refine = [self.transform_box(pts_pred_refine) for pts_pred_refine in pts_preds_refine]
+        bbox_preds_refine = [
+            self.points2bbox(pts_pred_refine)
+            for pts_pred_refine in pts_preds_refine
+        ]
         num_levels = len(cls_scores)
         mlvl_points = [
             self.point_generators[i].grid_points(cls_scores[i].size()[-2:],
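`points2bbox` is called with the y-first layout throughout this hunk. A standalone sketch of its `minmax` branch for the 2-D case (this mirrors the method above under the assumption of a plain `(n, 2 * num_points)` input; it is not the patched code itself):

```python
import torch

def points2bbox_minmax(pts, y_first=True):
    # pts: (n, 2 * num_points); returns (n, 4) as [x1, y1, x2, y2].
    pts_reshape = pts.view(pts.shape[0], -1, 2)
    pts_y = pts_reshape[:, :, 0] if y_first else pts_reshape[:, :, 1]
    pts_x = pts_reshape[:, :, 1] if y_first else pts_reshape[:, :, 0]
    bbox_left = pts_x.min(dim=1, keepdim=True)[0]
    bbox_right = pts_x.max(dim=1, keepdim=True)[0]
    bbox_up = pts_y.min(dim=1, keepdim=True)[0]
    bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
    return torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom], dim=1)

pts = torch.tensor([[2., 1., 8., 5., 4., 9.]])  # [y1, x1, y2, x2, y3, x3]
print(points2bbox_minmax(pts))  # tensor([[1., 2., 9., 8.]])
```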
@@ -444,12 +528,14 @@ def get_bboxes(self, cls_scores, pts_preds_init, pts_preds_refine, img_metas, cfg,
                 cls_scores[i][img_id].detach() for i in range(num_levels)
             ]
             bbox_pred_list = [
-                bbox_preds_refine[i][img_id].detach() for i in range(num_levels)
+                bbox_preds_refine[i][img_id].detach()
+                for i in range(num_levels)
             ]
             img_shape = img_metas[img_id]['img_shape']
             scale_factor = img_metas[img_id]['scale_factor']
             proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
-                                               mlvl_points, img_shape, scale_factor, cfg, rescale, nms)
+                                               mlvl_points, img_shape,
+                                               scale_factor, cfg, rescale, nms)
             result_list.append(proposals)
         return result_list
 
@@ -465,10 +551,11 @@ def get_bboxes_single(self,
         assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
         mlvl_bboxes = []
         mlvl_scores = []
-        for i_lvl, (cls_score, bbox_pred, points) in enumerate(zip(cls_scores, bbox_preds, mlvl_points)):
+        for i_lvl, (cls_score, bbox_pred, points) in enumerate(
+                zip(cls_scores, bbox_preds, mlvl_points)):
             assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
-            cls_score = cls_score.permute(1, 2, 0).reshape(
-                -1, self.cls_out_channels)
+            cls_score = cls_score.permute(1, 2,
+                                          0).reshape(-1, self.cls_out_channels)
             if self.use_sigmoid_cls:
                 scores = cls_score.sigmoid()
             else:
@@ -501,8 +588,9 @@ def get_bboxes_single(self,
             padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
             mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
         if nms:
-            det_bboxes, det_labels = multiclass_nms(
-                mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms, cfg.max_per_img)
+            det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
+                                                    cfg.score_thr, cfg.nms,
+                                                    cfg.max_per_img)
             return det_bboxes, det_labels
         else:
             return mlvl_bboxes, mlvl_scores
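The zero-column padding before `multiclass_nms` in `get_bboxes_single` exists because sigmoid-based heads emit `num_classes - 1` scores with no background column; a toy illustration of that step (values are illustrative):

```python
import torch

# With use_sigmoid_cls, scores has cls_out_channels = num_classes - 1
# columns (no background class).
mlvl_scores = torch.tensor([[0.9, 0.1],
                            [0.2, 0.7]])

# Prepend a zero background column so class indexing matches the softmax
# layout expected downstream.
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
print(mlvl_scores)
# tensor([[0.0000, 0.9000, 0.1000],
#         [0.0000, 0.2000, 0.7000]])
```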