From 68965057dc8807e34409996ccb3769effd1edede Mon Sep 17 00:00:00 2001 From: Matt Clifford Date: Wed, 7 Aug 2024 11:49:24 +0100 Subject: [PATCH] RMSE added --- IQM_Vis/examples/kodak.py | 3 ++- IQM_Vis/metrics/IQMs.py | 30 ++++++++++++++++++++++++++++++ IQM_Vis/metrics/__init__.py | 3 +++ IQM_Vis/version.py | 2 +- 4 files changed, 36 insertions(+), 2 deletions(-) diff --git a/IQM_Vis/examples/kodak.py b/IQM_Vis/examples/kodak.py index 7e924fe..770e6bd 100644 --- a/IQM_Vis/examples/kodak.py +++ b/IQM_Vis/examples/kodak.py @@ -15,7 +15,8 @@ def run(): image_list.sort() metrs = IQM_Vis.metrics.get_all_metrics() - metrs.pop('1-MS_SSIM') + if '1-MS_SSIM' in metrs: + metrs.pop('1-MS_SSIM') data = IQM_Vis.dataset_holder(image_list, metrs, # IQM_Vis.metrics.get_all_metric_images() diff --git a/IQM_Vis/metrics/IQMs.py b/IQM_Vis/metrics/IQMs.py index ffe23f4..43fd771 100644 --- a/IQM_Vis/metrics/IQMs.py +++ b/IQM_Vis/metrics/IQMs.py @@ -79,6 +79,36 @@ def __call__(self, im_ref, im_comp, **kwargs): return L2 else: return L2.mean() + +class RMSE: + '''Root Mean Squared Error between two images. Images must have the same + dimensions + + Args: + return_image (bool): Whether to return the image (Defaults to False which + will return a scalar value) + ''' + def __init__(self, return_image=False): + self.return_image = return_image + + def __call__(self, im_ref, im_comp, **kwargs): + '''When an instance is called + + Args: + im_ref (np.array): Reference image + im_comp (np.array): Comparison image + **kwargs: Arbitrary keyword arguments + + Returns: + score (np.array): RMSE (scalar if return_image is False, image if + return_image is True) + ''' + _check_shapes(im_ref, im_comp) + L2 = np.square(im_ref - im_comp) + if self.return_image: + return np.sqrt(L2) + else: + return np.sqrt(L2.mean()) class SSIM: '''Structural Similarity Index Measure between two images. Images must have diff --git a/IQM_Vis/metrics/__init__.py b/IQM_Vis/metrics/__init__.py index c269e2d..a8d8671 100644 --- a/IQM_Vis/metrics/__init__.py +++ b/IQM_Vis/metrics/__init__.py @@ -1,5 +1,6 @@ from IQM_Vis.metrics.IQMs import (MAE, MSE, + RMSE, SSIM, MS_SSIM, LPIPS, @@ -23,6 +24,7 @@ def get_all_metrics(): 'DISTS': DISTS(), 'LPIPS': LPIPS(), 'MAE': MAE(), + 'RMSE': RMSE(), } return all_metrics @@ -38,6 +40,7 @@ def get_all_metric_images(): 'SSIM': SSIM(return_image=True), # 'MS_SSIM': MS_SSIM(return_image=True), 'MAE': MAE(return_image=True), + 'RMSE': RMSE(return_image=True), } return all_metrics diff --git a/IQM_Vis/version.py b/IQM_Vis/version.py index bbcf011..0bd67ef 100644 --- a/IQM_Vis/version.py +++ b/IQM_Vis/version.py @@ -2,4 +2,4 @@ # License: BSD 3-Clause License # Changing the version number will action GitHub to push to PyPi the new version -__version__ = '0.2.5.91' \ No newline at end of file +__version__ = '0.2.5.92' \ No newline at end of file