diff --git a/mmocr/datasets/pipelines/dbnet_transforms.py b/mmocr/datasets/pipelines/dbnet_transforms.py index 3be9ed8ea..40801600c 100644 --- a/mmocr/datasets/pipelines/dbnet_transforms.py +++ b/mmocr/datasets/pipelines/dbnet_transforms.py @@ -84,36 +84,47 @@ def __call__(self, results): def may_augment_annotation(self, aug, shape, target_shape, results): if aug is None: return results + + # augment polygon mask for key in results['mask_fields']: - # augment polygon mask - masks = [] - for mask in results[key]: - masks.append( - [self.may_augment_poly(aug, shape, target_shape, mask[0])]) + masks = self.may_augment_poly(aug, shape, results[key]) if len(masks) > 0: results[key] = PolygonMasks(masks, *target_shape[:2]) + # augment bbox for key in results['bbox_fields']: - # augment bbox - bboxes = [] - for bbox in results[key]: - bbox = self.may_augment_poly(aug, shape, target_shape, bbox) - bboxes.append(bbox) + bboxes = self.may_augment_poly( + aug, shape, results[key], mask_flag=False) results[key] = np.zeros(0) if len(bboxes) > 0: results[key] = np.stack(bboxes) return results - def may_augment_poly(self, aug, img_shape, target_shape, poly): - # poly n x 2 - poly = poly.reshape(-1, 2) - keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly] - keypoints = aug.augment_keypoints( - [imgaug.KeypointsOnImage(keypoints, shape=img_shape)])[0].keypoints - poly = [[p.x, p.y] for p in keypoints] - poly = np.array(poly).flatten() - return poly + def may_augment_poly(self, aug, img_shape, polys, mask_flag=True): + key_points, poly_point_nums = [], [] + for poly in polys: + if mask_flag: + poly = poly[0] + poly = poly.reshape(-1, 2) + key_points.extend([imgaug.Keypoint(p[0], p[1]) for p in poly]) + poly_point_nums.append(poly.shape[0]) + key_points = aug.augment_keypoints( + [imgaug.KeypointsOnImage(keypoints=key_points, + shape=img_shape)])[0].keypoints + + new_polys = [] + start_idx = 0 + for poly_point_num in poly_point_nums: + new_poly = [] + for key_point in key_points[start_idx:(start_idx + + poly_point_num)]: + new_poly.append([key_point.x, key_point.y]) + start_idx += poly_point_num + new_poly = np.array(new_poly).flatten() + new_polys.append([new_poly] if mask_flag else new_poly) + + return new_polys def __repr__(self): repr_str = self.__class__.__name__ diff --git a/mmocr/datasets/pipelines/textdet_targets/dbnet_targets.py b/mmocr/datasets/pipelines/textdet_targets/dbnet_targets.py index 1e21cfc01..69c664b35 100644 --- a/mmocr/datasets/pipelines/textdet_targets/dbnet_targets.py +++ b/mmocr/datasets/pipelines/textdet_targets/dbnet_targets.py @@ -100,8 +100,9 @@ def ignore_texts(self, results, ignore_tags): mask for i, mask in enumerate(results['gt_labels']) if not ignore_tags[i] ]) + new_ignore_tags = [ignore for ignore in ignore_tags if not ignore] - return results + return results, new_ignore_tags def generate_thr_map(self, img_size, polygons): """Generate threshold map. @@ -149,12 +150,15 @@ def draw_border_map(self, polygon, canvas, mask): else: print(f'padding {polygon} with {distance} gets {padded_polygon}') padded_polygon = polygon.copy().astype(np.int32) - cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) x_min = padded_polygon[:, 0].min() x_max = padded_polygon[:, 0].max() y_min = padded_polygon[:, 1].min() y_max = padded_polygon[:, 1].max() + + if x_max <= 0 or y_max <= 0: + return + width = x_max - x_min + 1 height = y_max - y_min + 1 @@ -180,6 +184,16 @@ def draw_border_map(self, polygon, canvas, mask): x_max_valid = min(max(0, x_max), canvas.shape[1] - 1) y_min_valid = min(max(0, y_min), canvas.shape[0] - 1) y_max_valid = min(max(0, y_max), canvas.shape[0] - 1) + + if x_min_valid - x_min >= distance_map.shape[ + 1] or y_min_valid - y_min >= distance_map.shape[0]: + return + if x_max_valid - x_max + width <= 0: + return + if y_max_valid - y_max + height <= 0: + return + + cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) canvas[y_min_valid:y_max_valid + 1, x_min_valid:x_max_valid + 1] = np.fmax( 1 - distance_map[y_min_valid - y_min:y_max_valid - y_max + @@ -198,25 +212,29 @@ def generate_targets(self, results): results (dict): The output result dictionary. """ assert isinstance(results, dict) - polygons = results['gt_masks'].masks + if 'bbox_fields' in results: results['bbox_fields'].clear() + ignore_tags = self.find_invalid(results) + results, ignore_tags = self.ignore_texts(results, ignore_tags) + h, w, _ = results['img_shape'] + polygons = results['gt_masks'].masks + # generate gt_shrink_kernel gt_shrink, ignore_tags = self.generate_kernels((h, w), polygons, self.shrink_ratio, ignore_tags=ignore_tags) - results = self.ignore_texts(results, ignore_tags) - - # polygons and polygons_ignore reassignment. - polygons = results['gt_masks'].masks + results, ignore_tags = self.ignore_texts(results, ignore_tags) + # genenrate gt_shrink_mask polygons_ignore = results['gt_masks_ignore'].masks - gt_shrink_mask = self.generate_effective_mask((h, w), polygons_ignore) + # generate gt_threshold and gt_threshold_mask + polygons = results['gt_masks'].masks gt_thr, gt_thr_mask = self.generate_thr_map((h, w), polygons) results['mask_fields'].clear() # rm gt_masks encoded by polygons