Skip to content

Commit

Permalink
[brief] Updates the unit tests to ensure the MAE metrics are tested p…
Browse files Browse the repository at this point in the history
…roperly.

[detailed]
  • Loading branch information
marovira committed Nov 15, 2024
1 parent 2ed2965 commit a938f8c
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions test/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,21 @@ def test_mAP(self) -> None:
mAP = metrics.CalculateMAP()
self.check_almost_equal(mAP(targs, preds), functional.calculate_mAP(targs, preds))

def test_MAE(self) -> None:
def check_mae(self, scale: float) -> None:
rng.seed_rngs()

pred = torch.rand((32, 32))
gt = torch.rand((32, 32))
mae = metrics.CalculateMAE()
pred = torch.randn((1, 3, 32, 32)) * scale
gt = torch.randn((1, 3, 32, 32)) * scale
mae = metrics.CalculateMAE(scale)

self.check_almost_equal(mae(pred, gt), functional.calculate_mae_torch(pred, gt))
self.check_almost_equal(
mae(pred, gt), functional.calculate_mae_torch(pred, gt, scale)
)

pred = pred.numpy() # type: ignore[assignment]
gt = gt.numpy() # type: ignore[assignment]
self.check_almost_equal(mae(pred, gt), functional.calculate_mae(pred, gt)) # type: ignore[arg-type]
self.check_almost_equal(mae(pred, gt), functional.calculate_mae(pred, gt, scale)) # type: ignore[arg-type]

def test_MAE(self) -> None:
self.check_mae(1.0)
self.check_mae(255.0)

0 comments on commit a938f8c

Please sign in to comment.