From 023c813a004e153027c7ed6acd99664f1d7356fb Mon Sep 17 00:00:00 2001 From: Mountchicken Date: Fri, 10 Feb 2023 15:48:43 +0800 Subject: [PATCH 1/2] support multi scale train --- mmocr/models/textdet/module_losses/db_module_loss.py | 7 ++++--- mmocr/models/textdet/module_losses/drrg_module_loss.py | 2 +- mmocr/models/textdet/module_losses/fce_module_loss.py | 2 +- mmocr/models/textdet/module_losses/pan_module_loss.py | 4 ++-- .../models/textdet/module_losses/textsnake_module_loss.py | 8 ++++---- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/mmocr/models/textdet/module_losses/db_module_loss.py b/mmocr/models/textdet/module_losses/db_module_loss.py index ba8487310..bfa9846ff 100644 --- a/mmocr/models/textdet/module_losses/db_module_loss.py +++ b/mmocr/models/textdet/module_losses/db_module_loss.py @@ -240,7 +240,7 @@ def _get_target_single(self, data_sample: TextDetDataSample) -> Tuple: if self._is_poly_invalid(polygon): ignore_flags[idx] = True gt_shrink, ignore_flags = self._generate_kernels( - data_sample.img_shape, + data_sample.batch_input_shape, gt_instances.polygons, self.shrink_ratio, ignore_flags=ignore_flags) @@ -249,9 +249,10 @@ def _get_target_single(self, data_sample: TextDetDataSample) -> Tuple: gt_shrink = gt_shrink > 0 gt_shrink_mask = self._generate_effective_mask( - data_sample.img_shape, gt_instances[ignore_flags].polygons) + data_sample.batch_input_shape, gt_instances[ignore_flags].polygons) gt_thr, gt_thr_mask = self._generate_thr_map( - data_sample.img_shape, gt_instances[~ignore_flags].polygons) + data_sample.batch_input_shape, + gt_instances[~ignore_flags].polygons) # to_tensor gt_shrink = torch.from_numpy(gt_shrink).unsqueeze(0).float() diff --git a/mmocr/models/textdet/module_losses/drrg_module_loss.py b/mmocr/models/textdet/module_losses/drrg_module_loss.py index 51923ef0a..3757654f4 100644 --- a/mmocr/models/textdet/module_losses/drrg_module_loss.py +++ b/mmocr/models/textdet/module_losses/drrg_module_loss.py @@ -282,7 +282,7 @@ def _get_target_single(self, data_sample: TextDetDataSample) -> Tuple: polygons = gt_instances[~ignore_flags].polygons ignored_polygons = gt_instances[ignore_flags].polygons - h, w = data_sample.img_shape + h, w = data_sample.batch_input_shape gt_text_mask = self._generate_text_region_mask((h, w), polygons) gt_mask = self._generate_effective_mask((h, w), ignored_polygons) diff --git a/mmocr/models/textdet/module_losses/fce_module_loss.py b/mmocr/models/textdet/module_losses/fce_module_loss.py index c833c1778..da088ba93 100644 --- a/mmocr/models/textdet/module_losses/fce_module_loss.py +++ b/mmocr/models/textdet/module_losses/fce_module_loss.py @@ -211,7 +211,7 @@ def _get_target_single(self, data_sample: TextDetDataSample) -> Tuple: tuple[Tensor]: A tuple of three tensors from three different feature level as the targets of one prediction. """ - img_size = data_sample.img_shape[:2] + img_size = data_sample.batch_input_shape[:2] text_polys = data_sample.gt_instances.polygons ignore_flags = data_sample.gt_instances.ignored diff --git a/mmocr/models/textdet/module_losses/pan_module_loss.py b/mmocr/models/textdet/module_losses/pan_module_loss.py index 6a5a6685a..7c36d2992 100644 --- a/mmocr/models/textdet/module_losses/pan_module_loss.py +++ b/mmocr/models/textdet/module_losses/pan_module_loss.py @@ -157,14 +157,14 @@ def _get_target_single(self, data_sample: TextDetDataSample for ratio in self.shrink_ratio: # TODO pass `gt_ignored` to `_generate_kernels` gt_kernel, _ = self._generate_kernels( - data_sample.img_shape, + data_sample.batch_input_shape, gt_polygons, ratio, ignore_flags=None, max_shrink_dist=self.max_shrink_dist) gt_kernels.append(gt_kernel) gt_polygons_ignored = data_sample.gt_instances[gt_ignored].polygons - gt_mask = self._generate_effective_mask(data_sample.img_shape, + gt_mask = self._generate_effective_mask(data_sample.batch_input_shape, gt_polygons_ignored) gt_kernels = np.stack(gt_kernels, axis=0) diff --git a/mmocr/models/textdet/module_losses/textsnake_module_loss.py b/mmocr/models/textdet/module_losses/textsnake_module_loss.py index 651a74755..ae630ffeb 100644 --- a/mmocr/models/textdet/module_losses/textsnake_module_loss.py +++ b/mmocr/models/textdet/module_losses/textsnake_module_loss.py @@ -203,14 +203,14 @@ def _get_target_single(self, data_sample: TextDetDataSample) -> Tuple: polygons = gt_instances[~ignore_flags].polygons ignored_polygons = gt_instances[ignore_flags].polygons - gt_text_mask = self._generate_text_region_mask(data_sample.img_shape, - polygons) - gt_mask = self._generate_effective_mask(data_sample.img_shape, + gt_text_mask = self._generate_text_region_mask( + data_sample.batch_input_shape, polygons) + gt_mask = self._generate_effective_mask(data_sample.batch_input_shape, ignored_polygons) (gt_center_region_mask, gt_radius_map, gt_sin_map, gt_cos_map) = self._generate_center_mask_attrib_maps( - data_sample.img_shape, polygons) + data_sample.batch_input_shape, polygons) return (gt_text_mask, gt_mask, gt_center_region_mask, gt_radius_map, gt_sin_map, gt_cos_map) From 0e2bd7553fd1c640f30b1349d99cc2b65e62ba00 Mon Sep 17 00:00:00 2001 From: Mountchicken Date: Fri, 10 Feb 2023 16:37:24 +0800 Subject: [PATCH 2/2] fix lint --- .../test_textdet/test_module_losses/test_db_module_loss.py | 2 +- .../test_module_losses/test_drrg_module_loss.py | 6 +++--- .../test_textdet/test_module_losses/test_fce_module_loss.py | 3 ++- .../test_textdet/test_module_losses/test_pan_module_loss.py | 2 +- .../test_textdet/test_module_losses/test_pse_module_loss.py | 2 +- .../test_module_losses/test_textsnake_module_loss.py | 2 +- 6 files changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/test_models/test_textdet/test_module_losses/test_db_module_loss.py b/tests/test_models/test_textdet/test_module_losses/test_db_module_loss.py index a9882e7c1..dae91f63e 100644 --- a/tests/test_models/test_textdet/test_module_losses/test_db_module_loss.py +++ b/tests/test_models/test_textdet/test_module_losses/test_db_module_loss.py @@ -15,7 +15,7 @@ def setUp(self) -> None: self.db_loss = DBModuleLoss(thr_min=0.3, thr_max=0.7) self.data_samples = [ TextDetDataSample( - metainfo=dict(img_shape=(40, 40)), + metainfo=dict(img_shape=(40, 40), batch_input_shape=(40, 40)), gt_instances=InstanceData( polygons=np.array([ [0, 0, 10, 0, 10, 10, 0, 10], diff --git a/tests/test_models/test_textdet/test_module_losses/test_drrg_module_loss.py b/tests/test_models/test_textdet/test_module_losses/test_drrg_module_loss.py index 7edfdbcc4..dceeee736 100644 --- a/tests/test_models/test_textdet/test_module_losses/test_drrg_module_loss.py +++ b/tests/test_models/test_textdet/test_module_losses/test_drrg_module_loss.py @@ -18,7 +18,7 @@ def setUp(self) -> None: self.preds = (preds_maps, gcn_pred, gt_labels) self.data_samples = [ TextDetDataSample( - metainfo=dict(img_shape=(64, 64)), + metainfo=dict(img_shape=(64, 64), batch_input_shape=(64, 64)), gt_instances=InstanceData( polygons=[ np.array([4, 2, 30, 2, 30, 10, 4, 10]), @@ -55,7 +55,7 @@ def test_get_targets(self): # test generate_targets with blank polygon masks blank_data_samples = [ TextDetDataSample( - metainfo=dict(img_shape=(20, 20)), + metainfo=dict(img_shape=(20, 20), batch_input_shape=(20, 20)), gt_instances=InstanceData( polygons=[], ignored=torch.BoolTensor([]))) ] @@ -77,7 +77,7 @@ def test_get_targets(self): # test generate_targets with one proposed text component data_samples = [ TextDetDataSample( - metainfo=dict(img_shape=(20, 30)), + metainfo=dict(img_shape=(20, 30), batch_input_shape=(20, 30)), gt_instances=InstanceData( polygons=[np.array([13, 6, 17, 6, 17, 14, 13, 14])], ignored=torch.BoolTensor([False]))) diff --git a/tests/test_models/test_textdet/test_module_losses/test_fce_module_loss.py b/tests/test_models/test_textdet/test_module_losses/test_fce_module_loss.py index c656c1e58..5d1a9992f 100644 --- a/tests/test_models/test_textdet/test_module_losses/test_fce_module_loss.py +++ b/tests/test_models/test_textdet/test_module_losses/test_fce_module_loss.py @@ -15,7 +15,8 @@ def setUp(self) -> None: self.fce_loss = FCEModuleLoss(fourier_degree=5, num_sample=400) self.data_samples = [ TextDetDataSample( - metainfo=dict(img_shape=(320, 320)), + metainfo=dict( + img_shape=(320, 320), batch_input_shape=(320, 320)), gt_instances=InstanceData( polygons=np.array([ [0, 0, 10, 0, 10, 10, 0, 10], diff --git a/tests/test_models/test_textdet/test_module_losses/test_pan_module_loss.py b/tests/test_models/test_textdet/test_module_losses/test_pan_module_loss.py index 3e7d43965..201dc4ce0 100644 --- a/tests/test_models/test_textdet/test_module_losses/test_pan_module_loss.py +++ b/tests/test_models/test_textdet/test_module_losses/test_pan_module_loss.py @@ -17,7 +17,7 @@ def setUp(self) -> None: self.data_samples = [ TextDetDataSample( - metainfo=dict(img_shape=(40, 40)), + metainfo=dict(img_shape=(40, 40), batch_input_shape=(40, 40)), gt_instances=InstanceData( polygons=np.array([ [0, 0, 10, 0, 10, 10, 0, 10], diff --git a/tests/test_models/test_textdet/test_module_losses/test_pse_module_loss.py b/tests/test_models/test_textdet/test_module_losses/test_pse_module_loss.py index 5a2d591f8..ed5e83d60 100644 --- a/tests/test_models/test_textdet/test_module_losses/test_pse_module_loss.py +++ b/tests/test_models/test_textdet/test_module_losses/test_pse_module_loss.py @@ -16,7 +16,7 @@ class TestPSEModuleLoss(TestCase): def setUp(self) -> None: self.data_samples = [ TextDetDataSample( - metainfo=dict(img_shape=(40, 40)), + metainfo=dict(img_shape=(40, 40), batch_input_shape=(40, 40)), gt_instances=InstanceData( polygons=np.array([ [0, 0, 10, 0, 10, 10, 0, 10], diff --git a/tests/test_models/test_textdet/test_module_losses/test_textsnake_module_loss.py b/tests/test_models/test_textdet/test_module_losses/test_textsnake_module_loss.py index 8cc342315..c2be638f6 100644 --- a/tests/test_models/test_textdet/test_module_losses/test_textsnake_module_loss.py +++ b/tests/test_models/test_textdet/test_module_losses/test_textsnake_module_loss.py @@ -16,7 +16,7 @@ def setUp(self) -> None: self.data_samples = [ TextDetDataSample( - metainfo=dict(img_shape=(3, 10)), + metainfo=dict(img_shape=(3, 10), batch_input_shape=(3, 10)), gt_instances=InstanceData( polygons=np.array([ [0, 0, 1, 0, 1, 1, 0, 1],