Skip to content

Commit

Permalink
actual torchmetric instance
Browse files Browse the repository at this point in the history
  • Loading branch information
JSabadin committed Nov 27, 2024
1 parent 02ccbfa commit 985233a
Showing 1 changed file with 14 additions and 25 deletions.
39 changes: 14 additions & 25 deletions luxonis_train/attached_modules/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,22 +69,12 @@ def __init__(
"Multiple tasks detected in self.node.tasks. Only one task is allowed."
)

if self.is_classification:
self.add_state(
"classification_cm",
default=torch.zeros(
self.n_classes, self.n_classes, dtype=torch.int64
),
dist_reduce_fx="sum",
)
if self.is_segmentation:
self.add_state(
"segmentation_cm",
default=torch.zeros(
self.n_classes, self.n_classes, dtype=torch.int64
),
dist_reduce_fx="sum",
self.metric_cm = None
if self.is_classification or self.is_segmentation:
self.metric_cm = MulticlassConfusionMatrix(
num_classes=self.n_classes
)

if self.is_detection:
self.add_state(
"detection_cm",
Expand Down Expand Up @@ -168,17 +158,14 @@ def update(
target = targets["classification"]
pred_classes = preds[0].argmax(dim=1) # [B]
target_classes = target.argmax(dim=1) # [B]
self.classification_cm += self.compute_confusion_matrix(
pred_classes, target_classes
)
self.metric_cm.update(pred_classes, target_classes)

Check failure on line 161 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View workflow job for this annotation

GitHub Actions / type-check

"update" is not a known attribute of "None" (reportOptionalMemberAccess)

if "segmentation" in predictions and "segmentation" in targets:
preds = predictions["segmentation"]
target = targets["segmentation"]
pred_masks = preds[0].argmax(dim=1) # [B, H, W]
target_masks = target.argmax(dim=1) # [B, H, W]
self.segmentation_cm += self.compute_confusion_matrix(
pred_masks.view(-1), target_masks.view(-1)
)
self.metric_cm.update(pred_masks.view(-1), target_masks.view(-1))

Check failure on line 168 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View workflow job for this annotation

GitHub Actions / type-check

"update" is not a known attribute of "None" (reportOptionalMemberAccess)

if "detection" in predictions and "detection" in targets:
preds = predictions["detection"] # type: ignore
Expand All @@ -192,10 +179,12 @@ def compute(self) -> dict[str, Tensor]:
"""Compute confusion matrices for classification, segmentation,
and detection tasks."""
results = {}
if self.is_classification:
results["classification_confusion_matrix"] = self.classification_cm
if self.is_segmentation:
results["segmentation_confusion_matrix"] = self.segmentation_cm
if self.metric_cm:
task_type = (
"classification" if self.is_classification else "segmentation"
)
results[f"{task_type}_confusion_matrix"] = self.metric_cm.compute()
self.metric_cm.reset()
if self.is_detection:
results["detection_confusion_matrix"] = self.detection_cm

Expand Down

0 comments on commit 985233a

Please sign in to comment.