forked from open-mmlab/mmocr
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add TextOCR dataset converter (open-mmlab#293)
* textocr converter for text recog * textocr converter for text detection * update documentation * remove unnecessary garbage collection lines * multi-processing textocr converter * json->mmcv, fix documentation
- Loading branch information
1 parent
8befec3
commit 7b072b0
Showing
3 changed files
with
228 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import argparse | ||
import math | ||
import os.path as osp | ||
|
||
import mmcv | ||
|
||
from mmocr.utils import convert_annotations | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser( | ||
description='Generate training and validation set of TextOCR ') | ||
parser.add_argument('root_path', help='Root dir path of TextOCR') | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def collect_textocr_info(root_path, annotation_filename, print_every=1000): | ||
|
||
annotation_path = osp.join(root_path, annotation_filename) | ||
if not osp.exists(annotation_path): | ||
raise Exception( | ||
f'{annotation_path} not exists, please check and try again.') | ||
|
||
annotation = mmcv.load(annotation_path) | ||
|
||
# img_idx = img_start_idx | ||
img_infos = [] | ||
for i, img_info in enumerate(annotation['imgs'].values()): | ||
if i > 0 and i % print_every == 0: | ||
print(f'{i}/{len(annotation["imgs"].values())}') | ||
|
||
img_info['segm_file'] = annotation_path | ||
ann_ids = annotation['imgToAnns'][img_info['id']] | ||
anno_info = [] | ||
for ann_id in ann_ids: | ||
ann = annotation['anns'][ann_id] | ||
|
||
# Ignore illegible or non-English words | ||
text_label = ann['utf8_string'] | ||
iscrowd = 1 if text_label == '.' else 0 | ||
|
||
x, y, w, h = ann['bbox'] | ||
x, y = max(0, math.floor(x)), max(0, math.floor(y)) | ||
w, h = math.ceil(w), math.ceil(h) | ||
bbox = [x, y, w, h] | ||
segmentation = [max(0, int(x)) for x in ann['points']] | ||
anno = dict( | ||
iscrowd=iscrowd, | ||
category_id=1, | ||
bbox=bbox, | ||
area=ann['area'], | ||
segmentation=[segmentation]) | ||
anno_info.append(anno) | ||
img_info.update(anno_info=anno_info) | ||
img_infos.append(img_info) | ||
return img_infos | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
root_path = args.root_path | ||
print('Processing training set...') | ||
training_infos = collect_textocr_info(root_path, 'TextOCR_0.1_train.json') | ||
convert_annotations(training_infos, | ||
osp.join(root_path, 'instances_training.json')) | ||
print('Processing validation set...') | ||
val_infos = collect_textocr_info(root_path, 'TextOCR_0.1_val.json') | ||
convert_annotations(val_infos, osp.join(root_path, 'instances_val.json')) | ||
print('Finish') | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import argparse | ||
import math | ||
import os | ||
import os.path as osp | ||
from functools import partial | ||
|
||
import mmcv | ||
|
||
from mmocr.utils.fileio import list_to_file | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser( | ||
description='Generate training and validation set of TextOCR ' | ||
'by cropping box image.') | ||
parser.add_argument('root_path', help='Root dir path of TextOCR') | ||
parser.add_argument( | ||
'n_proc', default=1, type=int, help='Number of processes to run') | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def process_img(args, src_image_root, dst_image_root): | ||
# Dirty hack for multi-processing | ||
img_idx, img_info, anns = args | ||
src_img = mmcv.imread(osp.join(src_image_root, img_info['file_name'])) | ||
labels = [] | ||
for ann_idx, ann in enumerate(anns): | ||
text_label = ann['utf8_string'] | ||
|
||
# Ignore illegible or non-English words | ||
if text_label == '.': | ||
continue | ||
|
||
x, y, w, h = ann['bbox'] | ||
x, y = max(0, math.floor(x)), max(0, math.floor(y)) | ||
w, h = math.ceil(w), math.ceil(h) | ||
dst_img = src_img[y:y + h, x:x + w] | ||
dst_img_name = f'img_{img_idx}_{ann_idx}.jpg' | ||
dst_img_path = osp.join(dst_image_root, dst_img_name) | ||
mmcv.imwrite(dst_img, dst_img_path) | ||
labels.append(f'{osp.basename(dst_image_root)}/{dst_img_name}' | ||
f' {text_label}') | ||
return labels | ||
|
||
|
||
def convert_textocr(root_path, | ||
dst_image_path, | ||
dst_label_filename, | ||
annotation_filename, | ||
img_start_idx=0, | ||
nproc=1): | ||
|
||
annotation_path = osp.join(root_path, annotation_filename) | ||
if not osp.exists(annotation_path): | ||
raise Exception( | ||
f'{annotation_path} not exists, please check and try again.') | ||
src_image_root = root_path | ||
|
||
# outputs | ||
dst_label_file = osp.join(root_path, dst_label_filename) | ||
dst_image_root = osp.join(root_path, dst_image_path) | ||
os.makedirs(dst_image_root, exist_ok=True) | ||
|
||
annotation = mmcv.load(annotation_path) | ||
|
||
process_img_with_path = partial( | ||
process_img, | ||
src_image_root=src_image_root, | ||
dst_image_root=dst_image_root) | ||
tasks = [] | ||
for img_idx, img_info in enumerate(annotation['imgs'].values()): | ||
ann_ids = annotation['imgToAnns'][img_info['id']] | ||
anns = [annotation['anns'][ann_id] for ann_id in ann_ids] | ||
tasks.append((img_idx + img_start_idx, img_info, anns)) | ||
labels_list = mmcv.track_parallel_progress( | ||
process_img_with_path, tasks, keep_order=True, nproc=nproc) | ||
final_labels = [] | ||
for label_list in labels_list: | ||
final_labels += label_list | ||
list_to_file(dst_label_file, final_labels) | ||
return len(annotation['imgs']) | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
root_path = args.root_path | ||
print('Processing training set...') | ||
num_train_imgs = convert_textocr( | ||
root_path=root_path, | ||
dst_image_path='image', | ||
dst_label_filename='train_label.txt', | ||
annotation_filename='TextOCR_0.1_train.json', | ||
nproc=args.n_proc) | ||
print('Processing validation set...') | ||
convert_textocr( | ||
root_path=root_path, | ||
dst_image_path='image', | ||
dst_label_filename='val_label.txt', | ||
annotation_filename='TextOCR_0.1_val.json', | ||
img_start_idx=num_train_imgs, | ||
nproc=args.n_proc) | ||
print('Finish') | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |