-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add OCR Decoding support - WIP #113
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from .letterbox_resize import LetterboxResize | ||
from .mixup import MixUp | ||
from .mosaic import Mosaic4 | ||
from .ocr import OCRAugmentation | ||
|
||
__all__ = ["LetterboxResize", "MixUp", "Mosaic4"] | ||
__all__ = ["LetterboxResize", "MixUp", "Mosaic4", "OCRAugmentation"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
from typing import Tuple | ||
import random | ||
import albumentations as A | ||
|
||
from albumentations.core.transforms_interface import ImageOnlyTransform | ||
|
||
import numpy as np | ||
from ..utils import AUGMENTATIONS | ||
|
||
|
||
@AUGMENTATIONS.register_module()
class OCRAugmentation(ImageOnlyTransform):
    """Image-only augmentation pipeline for OCR model inputs.

    In training mode two optional random stages (one geometric/blur, one
    noise/color) run before the image is resized, padded and normalized. In
    val/test mode only the resize, pad and normalize steps run.
    """

    def __init__(self, image_size: Tuple[int, int], is_rgb: bool, is_train: bool):
        """Build the albumentations pipeline.

        @param image_size: OCR model input shape, used as (height, width).
        @type image_size: Tuple[int, int]
        @param is_rgb: True if image is RGB. False if image is GRAY.
        @type is_rgb: bool
        @param is_train: True if image is train. False if image is val/test.
        @type is_train: bool
        """
        # Force p=1.0: with the base-class default probability the whole
        # pipeline (including resize/pad/normalize) would be skipped on some
        # calls, which is never wanted here, least of all for val/test.
        super().__init__(p=1.0)

        pipeline = self._train_only_transforms() if is_train else []
        pipeline.extend(self._resize_and_normalize(image_size, is_rgb))
        self.transforms = A.Compose(transforms=pipeline)

    @staticmethod
    def _train_only_transforms() -> list:
        """Return the random stages that are applied only during training."""
        # NOTE(review): a reviewer asked whether this is a standard OCR
        # augmentation set; the origin of these parameter choices is not
        # documented in this file.
        geometric_or_blur = A.OneOf(
            transforms=[
                A.RandomScale(
                    # (1 + low, 1 + high) => (0.5, 1): random downscale
                    # (min half size)
                    scale_limit=(-0.5, 0),
                    always_apply=True,
                    p=1,
                ),
                A.MotionBlur(
                    blur_limit=(19, 21),
                    p=1.0,
                    always_apply=True,
                    allow_shifted=False,
                ),
                A.Affine(
                    translate_percent=(0.05, 0.07),
                    scale=(0.7, 1.0),
                    rotate=(-7, 7),
                    shear=(5, 35),
                    always_apply=True,
                    p=1,
                ),
            ],
            p=0.2,
        )
        noise_or_color = A.OneOf(
            transforms=[
                A.ISONoise(
                    color_shift=(0.01, 0.1),
                    intensity=(0.1, 1.0),
                    always_apply=True,
                ),
                A.GaussianBlur(
                    blur_limit=(7, 9),  # kernel
                    sigma_limit=(0.1, 0.5),
                    always_apply=True,
                    p=0.2,
                ),
                A.ColorJitter(
                    brightness=(0.11, 1.0),
                    contrast=0.5,
                    saturation=0.5,
                ),
            ],
            p=0.2,
        )
        return [geometric_or_blur, noise_or_color]

    @staticmethod
    def _resize_and_normalize(image_size: Tuple[int, int], is_rgb: bool) -> list:
        """Return the resize, pad and normalize steps shared by train and val."""
        # NOTE(review): a reviewer pointed out that resize and normalize may
        # already be part of the default augmentations applied around this
        # transform; confirm that repeating them here is intended.
        # NOTE(review): the longest side is scaled to max(image_size) and
        # PadIfNeeded only pads, so for non-square targets the output may
        # exceed the smaller target dimension; confirm this is acceptable.
        resize_and_pad = A.Compose(  # resize with aspect ratio, pad if needed
            transforms=[
                A.LongestMaxSize(max_size=max(image_size), interpolation=1),
                A.PadIfNeeded(
                    min_height=image_size[0],
                    min_width=image_size[1],
                    border_mode=0,
                    value=(0, 0, 0),
                ),
            ]
        )
        if is_rgb:
            normalize = A.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            )
        else:
            normalize = A.Normalize(
                mean=0.4453,
                std=0.2692,
            )
        return [resize_and_pad, normalize]

    def apply(
        self,
        img: np.ndarray,
        **kwargs,
    ) -> np.ndarray:
        """Applies a series of OCR augmentations.

        @param img: Input image to run through the OCR pipeline.
        @type img: np.ndarray
        @return: Image with applied OCR augmentations.
        @rtype: np.ndarray
        """
        return self.transforms(image=img)["image"]
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,6 +44,7 @@ def load_annotation(name: str, js: str, data: Dict[str, Any]) -> "Annotation": | |
"PolylineSegmentationAnnotation": PolylineSegmentationAnnotation, | ||
"ArrayAnnotation": ArrayAnnotation, | ||
"LabelAnnotation": LabelAnnotation, | ||
"TextAnnotation": TextAnnotation, | ||
}[name](**json.loads(js), **data) | ||
|
||
|
||
|
@@ -92,6 +93,39 @@ def combine_to_numpy( | |
pass | ||
|
||
|
||
class TextAnnotation(Annotation):
    """Annotation carrying a text string that is encoded character by character."""

    _label_type = LabelType.TEXT
    type_: Literal["text"] = Field("text", alias="type")

    # Raw text of the annotation.
    text: str
    # Length of the encoded label vector produced by ``to_numpy``.
    max_len: int

    def to_numpy(
        self,
        class_mapping: Dict[str, int],
        **_,
    ) -> np.ndarray:
        """Encode ``text`` as a zero-padded vector of class indices.

        Characters missing from ``class_mapping`` are encoded as 0, the same
        value that is used for padding.

        @param class_mapping: Maps each character to its class index.
        @type class_mapping: Dict[str, int]
        @return: Array of length ``max_len`` holding one index per character.
        @rtype: np.ndarray
        """
        if len(self.text) > self.max_len:
            # Previously this overflowed the label array with an IndexError;
            # keep the leading characters and make the data problem visible.
            import warnings

            warnings.warn(
                f"Text annotation of length {len(self.text)} exceeds "
                f"max_len={self.max_len}; it will be truncated."
            )

        text_label = np.zeros(self.max_len)
        for idx, char in enumerate(self.text[: self.max_len]):
            text_label[idx] = class_mapping.get(char, 0)
        return text_label

    @staticmethod
    def combine_to_numpy(
        annotations: List["TextAnnotation"],
        class_mapping: Dict[str, int],
        **_,
    ) -> np.ndarray:
        """Stack the encoded labels of several annotations into one array.

        @param annotations: Annotations to encode.
        @type annotations: List[TextAnnotation]
        @param class_mapping: Maps each character to its class index.
        @type class_mapping: Dict[str, int]
        @return: Array with one row per annotation. The row width is the
            largest ``max_len`` among the annotations; shorter rows are
            zero-padded. An empty input yields an empty ``(0, 0)`` array.
        @rtype: np.ndarray
        """
        if not annotations:
            return np.zeros((0, 0))

        row_width = max(ann.max_len for ann in annotations)
        text_labels = np.zeros((len(annotations), row_width))
        for idx, ann in enumerate(annotations):
            text_labels[idx, : ann.max_len] = ann.to_numpy(
                class_mapping=class_mapping
            )
        return text_labels
|
||
|
||
class ClassificationAnnotation(Annotation): | ||
_label_type = LabelType.CLASSIFICATION | ||
type_: Literal["classification"] = Field("classification", alias="type") | ||
|
@@ -360,6 +394,7 @@ class DatasetRecord(BaseModelExtraForbid): | |
PolylineSegmentationAnnotation, | ||
ArrayAnnotation, | ||
LabelAnnotation, | ||
TextAnnotation | ||
] | ||
] = Field(None, discriminator="type_") | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
from collections import defaultdict | ||
from contextlib import contextmanager | ||
from pathlib import Path | ||
from typing import Dict, List, Optional, Set, Tuple | ||
from typing import Any, Dict, List, Optional, Set, Tuple | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
@@ -124,6 +124,7 @@ def __init__( | |
rich.progress.MofNCompleteColumn(), | ||
rich.progress.TimeRemainingColumn(), | ||
) | ||
self.global_metadata = {} | ||
|
||
@property | ||
def identifier(self) -> str: | ||
|
@@ -292,6 +293,9 @@ def set_classes(self, classes: List[str], task: Optional[str] = None) -> None: | |
self.fs.put_file(local_file, "metadata/classes.json") | ||
self._remove_temp_dir() | ||
|
||
def set_global_metadata(self, metadata: Dict[str, Any]) -> None:
    """Store dataset-wide metadata on this instance.

    The mapping is assigned by reference to ``self.global_metadata``; this
    method does not write anything to storage.

    @type metadata: Dict[str, Any]
    @param metadata: Metadata mapping to keep on the instance.
    """
    self.global_metadata = metadata
Comment on lines
+296
to
+297
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Due to GCS datasets, I think we need a way to persist this via storage instead of just memory? Perhaps we could use the existing |
||
|
||
def get_classes( | ||
self, sync_mode: bool = False | ||
) -> Tuple[List[str], Dict[str, List[str]]]: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -174,6 +174,7 @@ def _load_image_with_annotations(self, idx: int) -> Tuple[np.ndarray, Labels]: | |
|
||
uuid = self.instances[idx] | ||
df = self.df.loc[uuid] | ||
print(df) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. forgotten |
||
if self.dataset.bucket_storage == BucketStorage.LOCAL: | ||
matched = self.file_index[self.file_index["uuid"] == uuid] | ||
img_path = list(matched["original_filepath"])[0] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from typing import Any, Iterable | ||
from typing import Any, Iterable, List, Tuple | ||
import warnings | ||
|
||
import numpy as np | ||
|
||
|
@@ -27,3 +28,44 @@ def _check_valid_array(path: str) -> bool: | |
) | ||
if not _check_valid_array(value): | ||
raise Exception(f"Array at path {value} is not a valid numpy array (.npy)") | ||
|
||
|
||
def validate_text_value(
    value: Tuple[str, int],
    classes: List[str],
) -> Tuple[str, int]:
    """Validates a text value to only contain valid classes.

    @type value: Tuple[str, int]
    @param value: Pair of (text, max_len). Only the text is validated;
        max_len is returned unchanged.
    @type classes: List[str]
    @param classes: Valid character classes.
    @rtype: Tuple[str, int]
    @return: Pair of (cleaned text, max_len). Characters that are not in
        C{classes} are removed from the text and a warning is issued for
        each of them.
    """
    text_value, max_len = value

    kept_chars = []
    for char in text_value:
        if char in classes:
            kept_chars.append(char)
        else:
            warnings.warn(
                f"Text annotations contain invalid char ({char}): default behaviour is to exclude undefined classes, "
                f"make sure to add it to your dataset classes."
            )
    return "".join(kept_chars), max_len
|
||
|
||
# def encode_text_value( | ||
# value: str, | ||
# max_len: int, | ||
# classes: List[str] | ||
# ) -> [np.ndarray, int, int]: | ||
# | ||
# text = value | ||
# text_label = np.zeros(max_len) | ||
# for char_idx, char in enumerate(text): | ||
# cls = classes.index(char) | ||
# text_label[char_idx] = cls | ||
# | ||
# return [text_label, len(text), max_len] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's keep it just
super().__init__()
. The arguments in super
are a relic from Python 2.