Skip to content

Commit

Permalink
RMSE added
Browse files Browse the repository at this point in the history
  • Loading branch information
mattclifford1 committed Aug 7, 2024
1 parent d130a99 commit 6896505
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 2 deletions.
3 changes: 2 additions & 1 deletion IQM_Vis/examples/kodak.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def run():
image_list.sort()

metrs = IQM_Vis.metrics.get_all_metrics()
metrs.pop('1-MS_SSIM')
if '1-MS_SSIM' in metrs:
metrs.pop('1-MS_SSIM')
data = IQM_Vis.dataset_holder(image_list,
metrs,
# IQM_Vis.metrics.get_all_metric_images()
Expand Down
30 changes: 30 additions & 0 deletions IQM_Vis/metrics/IQMs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,36 @@ def __call__(self, im_ref, im_comp, **kwargs):
return L2
else:
return L2.mean()

class RMSE:
'''Root Mean Squared Error between two images. Images must have the same
dimensions
Args:
return_image (bool): Whether to return the image (Defaults to False which
will return a scalar value)
'''
def __init__(self, return_image=False):
self.return_image = return_image

def __call__(self, im_ref, im_comp, **kwargs):
'''When an instance is called
Args:
im_ref (np.array): Reference image
im_comp (np.array): Comparison image
**kwargs: Arbitrary keyword arguments
Returns:
score (np.array): RMSE (scalar if return_image is False, image if
return_image is True)
'''
_check_shapes(im_ref, im_comp)
L2 = np.square(im_ref - im_comp)
if self.return_image:
return np.sqrt(L2)
else:
return np.sqrt(L2.mean())

class SSIM:
'''Structural Similarity Index Measure between two images. Images must have
Expand Down
3 changes: 3 additions & 0 deletions IQM_Vis/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from IQM_Vis.metrics.IQMs import (MAE,
MSE,
RMSE,
SSIM,
MS_SSIM,
LPIPS,
Expand All @@ -23,6 +24,7 @@ def get_all_metrics():
'DISTS': DISTS(),
'LPIPS': LPIPS(),
'MAE': MAE(),
'RMSE': RMSE(),
}
return all_metrics

Expand All @@ -38,6 +40,7 @@ def get_all_metric_images():
'SSIM': SSIM(return_image=True),
# 'MS_SSIM': MS_SSIM(return_image=True),
'MAE': MAE(return_image=True),
'RMSE': RMSE(return_image=True),
}
return all_metrics

Expand Down
2 changes: 1 addition & 1 deletion IQM_Vis/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# License: BSD 3-Clause License

# Changing the version number will action GitHub to push to PyPi the new version
__version__ = '0.2.5.91'
__version__ = '0.2.5.92'

0 comments on commit 6896505

Please sign in to comment.