diff --git a/test/test_metrics.py b/test/test_metrics.py index bada9dd..8f7842f 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -62,15 +62,21 @@ def test_mAP(self) -> None: mAP = metrics.CalculateMAP() self.check_almost_equal(mAP(targs, preds), functional.calculate_mAP(targs, preds)) - def test_MAE(self) -> None: + def check_mae(self, scale: float) -> None: rng.seed_rngs() - pred = torch.rand((32, 32)) - gt = torch.rand((32, 32)) - mae = metrics.CalculateMAE() + pred = torch.randn((1, 3, 32, 32)) * scale + gt = torch.randn((1, 3, 32, 32)) * scale + mae = metrics.CalculateMAE(scale) - self.check_almost_equal(mae(pred, gt), functional.calculate_mae_torch(pred, gt)) + self.check_almost_equal( + mae(pred, gt), functional.calculate_mae_torch(pred, gt, scale) + ) pred = pred.numpy() # type: ignore[assignment] gt = gt.numpy() # type: ignore[assignment] - self.check_almost_equal(mae(pred, gt), functional.calculate_mae(pred, gt)) # type: ignore[arg-type] + self.check_almost_equal(mae(pred, gt), functional.calculate_mae(pred, gt, scale)) # type: ignore[arg-type] + + def test_MAE(self) -> None: + self.check_mae(1.0) + self.check_mae(255.0)