Skip to content

Commit

Permalink
Merge pull request #53 from worldcoin/yichen1-ai-5001-add-sharpness-e…
Browse files Browse the repository at this point in the history
…stimation-to-open-iris

Add sharpness estimation and a validator to reject blurry images
  • Loading branch information
wiktorlazarski authored Nov 25, 2024
2 parents 5691724 + 11f18d7 commit 8f4a692
Show file tree
Hide file tree
Showing 23 changed files with 305 additions and 9 deletions.
7 changes: 6 additions & 1 deletion src/iris/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Offgaze,
PupilToIrisProperty,
SegmentationMap,
Sharpness,
)
from iris.io.errors import (
BoundingBoxEstimationError,
Expand All @@ -34,12 +35,14 @@
IRISPipelineError,
IsPupilInsideIrisValidatorError,
LandmarkEstimationError,
MaskTooSmallError,
MatcherError,
NormalizationError,
OcclusionError,
OffgazeEstimationError,
ProbeSchemaError,
PupilIrisPropertyEstimationError,
SharpnessEstimationError,
VectorizationError,
)
from iris.nodes.aggregation.noise_mask_union import NoiseMaskUnion
Expand All @@ -52,6 +55,7 @@
from iris.nodes.eye_properties_estimation.moment_of_area import MomentOfArea
from iris.nodes.eye_properties_estimation.occlusion_calculator import OcclusionCalculator
from iris.nodes.eye_properties_estimation.pupil_iris_property_calculator import PupilIrisPropertyCalculator
from iris.nodes.eye_properties_estimation.sharpness_estimation import SharpnessEstimation
from iris.nodes.geometry_estimation.fusion_extrapolation import FusionExtrapolation
from iris.nodes.geometry_estimation.linear_extrapolation import LinearExtrapolation
from iris.nodes.geometry_estimation.lsq_ellipse_fit_with_refinement import LSQEllipseFitWithRefinement
Expand All @@ -65,8 +69,8 @@
from iris.nodes.iris_response.probe_schemas.regular_probe_schema import RegularProbeSchema
from iris.nodes.iris_response_refinement.fragile_bits_refinement import FragileBitRefinement
from iris.nodes.matcher.hamming_distance_matcher import HammingDistanceMatcher
from iris.nodes.matcher.simple_hamming_distance_matcher import SimpleHammingDistanceMatcher
from iris.nodes.matcher.hamming_distance_matcher_interface import BatchMatcher, Matcher
from iris.nodes.matcher.simple_hamming_distance_matcher import SimpleHammingDistanceMatcher
from iris.nodes.normalization.linear_normalization import LinearNormalization
from iris.nodes.normalization.nonlinear_normalization import NonlinearNormalization
from iris.nodes.normalization.perspective_normalization import PerspectiveNormalization
Expand All @@ -82,6 +86,7 @@
OffgazeValidator,
PolygonsLengthValidator,
Pupil2IrisPropertyValidator,
SharpnessValidator,
)
from iris.nodes.vectorization.contouring import ContouringAlgorithm
from iris.orchestration import error_managers, output_builders, pipeline_dataclasses
Expand Down
26 changes: 26 additions & 0 deletions src/iris/io/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,32 @@ def deserialize(data: float) -> Offgaze:
return Offgaze(score=data)


class Sharpness(ImmutableModel):
"""Data holder for Sharpness score."""

score: float = Field(..., ge=0.0)

def serialize(self) -> float:
"""Serialize Sharpness object.
Returns:
float: Serialized object.
"""
return self.score

@staticmethod
def deserialize(data: float) -> Sharpness:
"""Deserialize Sharpness object.
Args:
data (float): Serialized object to float.
Returns:
Sharpness: Deserialized object.
"""
return Sharpness(score=data)


class PupilToIrisProperty(ImmutableModel):
"""Data holder for pupil-ro-iris ratios."""

Expand Down
12 changes: 12 additions & 0 deletions src/iris/io/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,18 @@ class VectorizationError(Exception):
pass


class SharpnessEstimationError(Exception):
"""SharpnessEstimation Error class."""

pass


class MaskTooSmallError(Exception):
"""Mask is too small Error class."""

pass


class MatcherError(Exception):
"""Matcher module Error class."""

Expand Down
67 changes: 67 additions & 0 deletions src/iris/nodes/eye_properties_estimation/sharpness_estimation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from typing import List, Tuple

import cv2
import numpy as np
from pydantic import Field, validator

import iris.io.validators as pydantic_v
from iris.callbacks.callback_interface import Callback
from iris.io.class_configs import Algorithm
from iris.io.dataclasses import NormalizedIris, Sharpness


class SharpnessEstimation(Algorithm):
"""Calculate sharpness of the normalized iris.
The goal of this algorithm is to calculate the sharpness of the normalized iris using the variance of Laplacian.
LIMITATIONS:
This method may be biased against dark images with inadequate lighting.
"""

class Parameters(Algorithm.Parameters):
"""Parameters class for SharpnessEstimation objects.
lap_ksize (int): Laplacian kernel size, must be odd integer no larger than 31.
erosion_ksize (Tuple[int, int]): Mask erosion kernel size, must be odd integers.
"""

lap_ksize: int = Field(..., gt=0, le=31)
erosion_ksize: Tuple[int, int] = Field(..., gt=0)

_is_odd0 = validator("lap_ksize", allow_reuse=True)(pydantic_v.is_odd)
_is_odd = validator("erosion_ksize", allow_reuse=True, each_item=True)(pydantic_v.is_odd)

__parameters_type__ = Parameters

def __init__(
self,
lap_ksize: int = 11,
erosion_ksize: Tuple[int, int] = (29, 15),
callbacks: List[Callback] = [],
) -> None:
"""Assign parameters.
Args:
lap_ksize (int, optional): kernal size for Laplacian. Defaults to 11.
erosion_ksize (Tuple[int, int], optional): kernal size for mask erosion. Defaults to (29,15).
callbacks (List[Callback]): callbacks list. Defaults to [].
"""
super().__init__(lap_ksize=lap_ksize, erosion_ksize=erosion_ksize, callbacks=callbacks)

def run(self, normalization_output: NormalizedIris) -> Sharpness:
"""Calculate sharpness of the normalized iris.
Args:
normalization_output (NormalizedIris): Normalized iris.
Returns:
Sharpness: Sharpness object.
"""
output_im = cv2.Laplacian(normalization_output.normalized_image / 255, cv2.CV_32F, ksize=self.params.lap_ksize)
mask_im = cv2.erode(
normalization_output.normalized_mask.astype(np.uint8), kernel=np.ones(self.params.erosion_ksize, np.uint8)
)
sharpness_score = output_im[mask_im == 1].std() if np.sum(mask_im == 1) > 0 else 0.0
return Sharpness(score=sharpness_score)
53 changes: 49 additions & 4 deletions src/iris/nodes/validators/object_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import iris.io.errors as E
from iris.callbacks.callback_interface import Callback
from iris.io.class_configs import Algorithm
from iris.io.dataclasses import EyeOcclusion, GeometryPolygons, IrisTemplate, Offgaze, PupilToIrisProperty
from iris.io.dataclasses import EyeOcclusion, GeometryPolygons, IrisTemplate, Offgaze, PupilToIrisProperty, Sharpness
from iris.utils.math import polygon_length


Expand Down Expand Up @@ -316,13 +316,58 @@ def on_execute_start(self, input_polygons: GeometryPolygons, *args, **kwargs) ->
self.run(input_polygons)


class SharpnessValidator(Callback, Algorithm):
"""Validate that the normalized image is not too blurry.
Raises:
E.SharpnessEstimationError: If the sharpness score is below threshold.
"""

class Parameters(Algorithm.Parameters):
"""Parameters class for SharpnessValidator objects."""

min_sharpness: float = Field(..., ge=0.0)

__parameters_type__ = Parameters

def __init__(self, min_sharpness: float = 0.0) -> None:
"""Assign parameters.
Args:
min_sharpness (float): Minimum sharpness score. Sharpness computation min threshold that allows further sample processing. Defaults to 0.0 (by default every check will result in success).
"""
super().__init__(min_sharpness=min_sharpness)

def run(self, val_arguments: Sharpness) -> None:
"""Validate of sharpness estimation algorithm.
Args:
val_arguments (Sharpness): Computed result.
Raises:
E.SharpnessEstimationError: Raised if the sharpness score is below the desired threshold.
"""
if val_arguments.score < self.params.min_sharpness:
raise E.SharpnessEstimationError(
f"sharpness={val_arguments.score} < min_sharpness={self.params.min_sharpness}"
)

def on_execute_end(self, result: Sharpness) -> None:
"""Wrap for validate method so that validator can be used as a Callback.
Args:
result (Sharpness): Sharpness resulted from computations.
"""
self.run(result)


class IsMaskTooSmallValidator(Callback, Algorithm):
"""Validate that the masked part of the IrisTemplate is small enough.
The larger the mask, the less reliable information is available to create a robust identity.
Raises:
E.EncoderError: If the total number of non-masked bits is below threshold.
E.MaskTooSmallError: If the total number of non-masked bits is below threshold.
"""

class Parameters(Algorithm.Parameters):
Expand All @@ -347,12 +392,12 @@ def run(self, val_arguments: IrisTemplate) -> None:
val_arguments (IrisTemplate): IrisTemplate to be validated.
Raises:
E.EncoderError: Raised if the total mask codes size is below the desired threshold.
E.MaskTooSmallError: Raised if the total mask codes size is below the desired threshold.
"""
maskcodes_size = np.sum(val_arguments.mask_codes)

if maskcodes_size < self.params.min_maskcodes_size:
raise E.EncoderError(
raise E.MaskTooSmallError(
f"Valid mask codes size is too small: Got {maskcodes_size} px, min {self.params.min_maskcodes_size} px."
)

Expand Down
1 change: 1 addition & 0 deletions src/iris/orchestration/output_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def __get_metadata(call_trace: PipelineCallTraceStorage) -> Dict[str, Any]:
"occlusion90": __safe_serialize(call_trace["occlusion90_calculator"]),
"occlusion30": __safe_serialize(call_trace["occlusion30_calculator"]),
"iris_bbox": __safe_serialize(call_trace["bounding_box_estimation"]),
"sharpness_score": __safe_serialize(call_trace["sharpness_estimation"]),
}


Expand Down
12 changes: 12 additions & 0 deletions src/iris/pipelines/confs/pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,18 @@ pipeline:
source_node: eye_orientation
callbacks:

- name: sharpness_estimation
algorithm:
class_name: iris.SharpnessEstimation
params: {}
inputs:
- name: normalization_output
source_node: normalization
callbacks:
- class_name: iris.nodes.validators.object_validators.SharpnessValidator
params:
min_sharpness: 461.0

- name: filter_bank
algorithm:
class_name: iris.ConvFilterBank
Expand Down
1 change: 1 addition & 0 deletions src/iris/pipelines/iris_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class IRISPipeline(Algorithm):
iris.nodes.validators.object_validators.OffgazeValidator,
iris.nodes.validators.object_validators.OcclusionValidator,
iris.nodes.validators.object_validators.IsPupilInsideIrisValidator,
iris.nodes.validators.object_validators.SharpnessValidator,
iris.nodes.validators.object_validators.IsMaskTooSmallValidator,
iris.nodes.validators.cross_object_validators.EyeCentersInsideImageValidator,
iris.nodes.validators.cross_object_validators.ExtrapolatedPolygonsInsideImageValidator,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import math
import os
import pickle
from typing import Any

from iris.nodes.eye_properties_estimation.sharpness_estimation import SharpnessEstimation


def load_mock_pickle(name: str) -> Any:
testdir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "iris_response/mocks", "conv_filter_bank")
mock_path = os.path.join(testdir, f"{name}.pickle")
return pickle.load(open(mock_path, "rb"))


def test_sharpness_estimation() -> None:
normalized_iris = load_mock_pickle(name="normalized_iris")

sharpness_obj = SharpnessEstimation()
sharpness = sharpness_obj(normalized_iris)

assert math.isclose(sharpness.score, 880.9419555664062)

sharpness_obj = SharpnessEstimation(lap_ksize=7)
sharpness = sharpness_obj(normalized_iris)

assert math.isclose(sharpness.score, 5.179013252258301)

sharpness_obj = SharpnessEstimation(erosion_ksize=[13, 7])
sharpness = sharpness_obj(normalized_iris)

assert math.isclose(sharpness.score, 1013.1661376953125)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
5 changes: 5 additions & 0 deletions tests/e2e_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ def compare_iris_pipeline_metadata_output(metadata_1: Dict[str, Any], metadata_2
],
decimal=4,
)
np.testing.assert_almost_equal(
metadata_2["sharpness_score"],
metadata_1["sharpness_score"],
decimal=4,
)


def compare_iris_pipeline_template_output(iris_template_1: Dict[str, Any], iris_template_2: Dict[str, Any]) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Tuple

import numpy as np
import pytest
from pydantic import ValidationError

from iris.nodes.eye_properties_estimation.sharpness_estimation import SharpnessEstimation


@pytest.mark.parametrize(
"lap_ksize",
[
pytest.param(0),
pytest.param("a"),
pytest.param(-10),
pytest.param(33),
pytest.param(2),
pytest.param(np.ones(3)),
],
ids=[
"lap_ksize should be larger than zero",
"lap_ksize should be int",
"lap_ksize should not be negative",
"lap_ksize should not be larger than 31",
"lap_ksize should be odd number",
"lap_ksize should not be array",
],
)
def test_sharpness_lap_ksize_raises_an_exception(lap_ksize: int) -> None:
with pytest.raises(ValidationError):
_ = SharpnessEstimation(lap_ksize=lap_ksize)


@pytest.mark.parametrize(
"erosion_ksize",
[
pytest.param((0, 5)),
pytest.param((1, "a")),
pytest.param((-10, 3)),
pytest.param((30, 5)),
pytest.param(np.ones(3)),
],
ids=[
"erosion_ksize should all be larger than zero",
"erosion_ksize should all be int",
"erosion_ksize should not be negative",
"erosion_ksize should be odd number",
"erosion_ksize should be a tuple of integer with length 2",
],
)
def test_sharpness_erosion_ksize_raises_an_exception(erosion_ksize: Tuple[int, int]) -> None:
with pytest.raises(ValidationError):
_ = SharpnessEstimation(erosion_ksize=erosion_ksize)
Loading

0 comments on commit 8f4a692

Please sign in to comment.