diff --git a/luxonis_train/attached_modules/metrics/__init__.py b/luxonis_train/attached_modules/metrics/__init__.py index b1dc40ea..cdd0b3ac 100644 --- a/luxonis_train/attached_modules/metrics/__init__.py +++ b/luxonis_train/attached_modules/metrics/__init__.py @@ -1,4 +1,5 @@ from .base_metric import BaseMetric +from .confusion_matrix import ConfusionMatrix from .mean_average_precision import MeanAveragePrecision from .mean_average_precision_keypoints import MeanAveragePrecisionKeypoints from .object_keypoint_similarity import ObjectKeypointSimilarity @@ -14,4 +15,5 @@ "ObjectKeypointSimilarity", "Precision", "Recall", + "ConfusionMatrix", ] diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix.py new file mode 100644 index 00000000..079a9315 --- /dev/null +++ b/luxonis_train/attached_modules/metrics/confusion_matrix.py @@ -0,0 +1,275 @@ +from typing import Literal + +import torch +from torch import Tensor +from torchmetrics.classification import MulticlassConfusionMatrix +from torchvision.ops import box_convert, box_iou + +from luxonis_train.enums import TaskType +from luxonis_train.utils import Labels, Packet + +from .base_metric import BaseMetric + + +class ConfusionMatrix(BaseMetric[Tensor, Tensor]): + def __init__( + self, + box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy", + iou_threshold: float = 0.45, + confidence_threshold: float = 0.25, + **kwargs, + ): + """Compute the confusion matrix for classification, + segmentation, and object detection tasks. + + @type box_format: Literal["xyxy", "xywh", "cxcywh"] + @param box_format: The format of the bounding boxes. Can be one + of "xyxy", "xywh", or "cxcywh". + @type iou_threshold: float + @param iou_threshold: The IoU threshold for matching predictions + to ground truth. + @type confidence_threshold: float + @param confidence_threshold: The confidence threshold for + filtering predictions. + """ + super().__init__(**kwargs) + allowed_box_formats = ("xyxy", "xywh", "cxcywh") + if box_format not in allowed_box_formats: + raise ValueError( + f"Expected argument `box_format` to be one of {allowed_box_formats} but got {box_format}" + ) + self.box_format = box_format + self.iou_threshold = iou_threshold + self.confidence_threshold = confidence_threshold + + self.is_classification = ( + self.node.tasks is not None + and TaskType.CLASSIFICATION in self.node.tasks + ) + self.is_detection = ( + self.node.tasks is not None + and TaskType.BOUNDINGBOX in self.node.tasks + ) + self.is_segmentation = ( + self.node.tasks is not None + and TaskType.SEGMENTATION in self.node.tasks + ) + + if ( + sum( + [ + self.is_classification, + self.is_detection, + self.is_segmentation, + ] + ) + > 1 + ): + raise ValueError( + "Multiple tasks detected in self.node.tasks. Only one task is allowed." + ) + + self.metric_cm = None + if self.is_classification or self.is_segmentation: + self.metric_cm = MulticlassConfusionMatrix( + num_classes=self.n_classes + ) + + if self.is_detection: + self.add_state( + "detection_cm", + default=torch.zeros( + self.n_classes + 1, self.n_classes + 1, dtype=torch.int64 + ), # +1 for background + dist_reduce_fx="sum", + ) + + def prepare( + self, inputs: Packet[Tensor], labels: Labels + ) -> tuple[dict[str, Tensor], dict[str, Tensor]]: + """Prepare data for classification, segmentation, and detection + tasks. + + @type inputs: Packet[Tensor] + @param inputs: The inputs to the model. + @type labels: Labels + @param labels: The ground-truth labels. + @return: A tuple of two dictionaries: one for predictions and + one for targets. + """ + predictions = {} + targets = {} + + if self.is_detection: + out_bbox = self.get_input_tensors(inputs, TaskType.BOUNDINGBOX) + bbox = self.get_label(labels, TaskType.BOUNDINGBOX) + bbox = bbox.to(out_bbox[0].device).clone() + bbox[..., 2:6] = box_convert(bbox[..., 2:6], "xywh", "xyxy") + scale_factors = torch.tensor( + [ + self.original_in_shape[2], + self.original_in_shape[1], + self.original_in_shape[2], + self.original_in_shape[1], + ], + device=bbox.device, + ) + bbox[..., 2:6] *= scale_factors + predictions["detection"] = out_bbox + targets["detection"] = bbox + + if self.is_classification: + prediction = self.get_input_tensors( + inputs, TaskType.CLASSIFICATION + ) + target = self.get_label(labels, TaskType.CLASSIFICATION).to( + prediction[0].device + ) + predictions["classification"] = prediction + targets["classification"] = target + + if self.is_segmentation: + prediction = self.get_input_tensors(inputs, TaskType.SEGMENTATION) + target = self.get_label(labels, TaskType.SEGMENTATION).to( + prediction[0].device + ) + predictions["segmentation"] = prediction + targets["segmentation"] = target + + return predictions, targets + + def update( + self, predictions: dict[str, Tensor], targets: dict[str, Tensor] + ) -> None: + """Update the confusion matrices for all tasks using prepared + data. + + @type predictions: dict[str, Tensor] + @param predictions: A dictionary containing predictions for all + tasks. + @type targets: dict[str, Tensor] + @param targets: A dictionary containing targets for all tasks. + """ + if "classification" in predictions and "classification" in targets: + preds = predictions["classification"] + target = targets["classification"] + pred_classes = preds[0].argmax(dim=1) # [B] + target_classes = target.argmax(dim=1) # [B] + if self.metric_cm is not None: + self.metric_cm.update(pred_classes, target_classes) + + if "segmentation" in predictions and "segmentation" in targets: + preds = predictions["segmentation"] + target = targets["segmentation"] + pred_masks = preds[0].argmax(dim=1) # [B, H, W] + target_masks = target.argmax(dim=1) # [B, H, W] + if self.metric_cm is not None: + self.metric_cm.update( + pred_masks.view(-1), target_masks.view(-1) + ) + + if "detection" in predictions and "detection" in targets: + preds = predictions["detection"] # type: ignore + target = targets["detection"] + self.detection_cm += self._compute_detection_confusion_matrix( # type: ignore + preds, # type: ignore + target, # type: ignore + ) + + def compute(self) -> dict[str, Tensor]: + """Compute confusion matrices for classification, segmentation, + and detection tasks.""" + results = {} + if self.metric_cm: + task_type = ( + "classification" if self.is_classification else "segmentation" + ) + results[f"{task_type}_confusion_matrix"] = self.metric_cm.compute() + self.metric_cm.reset() + if self.is_detection: + results["detection_confusion_matrix"] = self.detection_cm + + return results + + def _compute_detection_confusion_matrix( + self, preds: list[Tensor], targets: Tensor + ) -> Tensor: + """Compute a confusion matrix for object detection tasks. + + @type preds: list[Tensor] + @param preds: List of predictions for each image. Each tensor + has shape [N, 6] where 6 is for [x1, y1, x2, y2, score, + class] + @type targets: Tensor + @param targets: Ground truth boxes and classes. Shape [M, 6] + where first column is image index. + """ + cm = torch.zeros( + self.n_classes + 1, + self.n_classes + 1, + dtype=torch.int64, + device=preds[0].device, + ) + + for img_idx, pred in enumerate(preds): + img_targets = targets[targets[:, 0] == img_idx] + + if img_targets.shape[0] == 0: + for pred_class in pred[:, 5].int(): + cm[pred_class, self.n_classes] += 1 + continue + + if pred.shape[0] == 0: + for gt_class in img_targets[:, 1].int(): + cm[self.n_classes, gt_class] += 1 + continue + + pred = pred[pred[:, 4] > self.confidence_threshold] + pred_boxes = pred[:, :4] + pred_classes = pred[:, 5].int() + + gt_boxes = img_targets[:, 2:] + gt_classes = img_targets[:, 1].int() + + iou = box_iou(gt_boxes, pred_boxes) + iou_thresholded = iou > self.iou_threshold + + if iou_thresholded.any(): + iou_max, pred_max_idx = torch.max(iou, dim=1) + iou_gt_mask = iou_max > self.iou_threshold + gt_match_idx = torch.arange( + len(gt_boxes), device=gt_boxes.device + )[iou_gt_mask] + pred_match_idx = pred_max_idx[iou_gt_mask] + + for gt_idx, pred_idx in zip(gt_match_idx, pred_match_idx): + gt_class = gt_classes[gt_idx] + pred_class = pred_classes[pred_idx] + cm[pred_class, gt_class] += 1 + + unmatched_gt_mask = ~torch.isin( + torch.arange(len(gt_boxes), device=gt_boxes.device), + gt_match_idx, + ) + for gt_idx in torch.arange( + len(gt_boxes), device=gt_boxes.device + )[unmatched_gt_mask]: + gt_class = gt_classes[gt_idx] + cm[self.n_classes, gt_class] += 1 + + unmatched_pred_mask = ~torch.isin( + torch.arange(len(pred_boxes), device=gt_boxes.device), + pred_match_idx, + ) + for pred_idx in torch.arange( + len(pred_boxes), device=gt_boxes.device + )[unmatched_pred_mask]: + pred_class = pred_classes[pred_idx] + cm[pred_class, self.n_classes] += 1 + else: + for gt_class in gt_classes: + cm[self.n_classes, gt_class] += 1 + for pred_class in pred_classes: + cm[pred_class, self.n_classes] += 1 + + return cm diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index 941cd649..baf654c3 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -151,15 +151,24 @@ def check_predefined_model(self) -> Self: def check_main_metric(self) -> Self: for metric in self.metrics: if metric.is_main_metric: + if "matrix" in metric.name.lower(): + raise ValueError( + f"Main metric cannot contain 'matrix' in its name: `{metric.name}`" + ) logger.info(f"Main metric: `{metric.name}`") return self logger.warning("No main metric specified.") if self.metrics: - metric = self.metrics[0] - metric.is_main_metric = True - name = metric.alias or metric.name - logger.info(f"Setting '{name}' as main metric.") + for metric in self.metrics: + if "matrix" not in metric.name.lower(): + metric.is_main_metric = True + name = metric.alias or metric.name + logger.info(f"Setting '{name}' as main metric.") + return self + raise ValueError( + "[Configuration Error] No valid main metric can be set as all metrics contain 'matrix' in their names." + ) else: logger.warning( "[Ignore if using predefined model] " diff --git a/luxonis_train/models/luxonis_lightning.py b/luxonis_train/models/luxonis_lightning.py index 08d0066f..2cb00e7e 100644 --- a/luxonis_train/models/luxonis_lightning.py +++ b/luxonis_train/models/luxonis_lightning.py @@ -810,14 +810,21 @@ def _evaluation_epoch_end(self, mode: Literal["test", "val"]) -> None: logger.info("Metrics computed.") for node_name, metrics in computed_metrics.items(): for metric_name, metric_value in metrics.items(): - metric_results[node_name][metric_name] = ( - metric_value.cpu().item() - ) - self.log( - f"{mode}/metric/{node_name}/{metric_name}", - metric_value, - sync_dist=True, - ) + if "matrix" in metric_name.lower(): + self.logger.log_matrix( + matrix=metric_value.cpu().numpy(), + name=f"{mode}/metrics/{self.current_epoch}/{metric_name}", + step=self.current_epoch, + ) + else: + metric_results[node_name][metric_name] = ( + metric_value.cpu().item() + ) + self.log( + f"{mode}/metric/{node_name}/{metric_name}", + metric_value, + sync_dist=True, + ) if self.cfg.trainer.verbose: self._print_results( diff --git a/requirements-config.txt b/requirements-config.txt index 0a7b2625..f8498e7d 100644 --- a/requirements-config.txt +++ b/requirements-config.txt @@ -1 +1 @@ -luxonis-ml[data,utils]@git+https://github.com/luxonis/luxonis-ml.git@dev +luxonis-ml[data,utils]@git+https://github.com/luxonis/luxonis-ml.git@main diff --git a/requirements.txt b/requirements.txt index 5d0fcb28..5ef87b3a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ blobconverter>=1.4.2 lightning>=2.4.0 -luxonis-ml[data,tracker]>=0.5.0 +luxonis-ml[data,tracker]>=0.5.1 onnx>=1.12.0 onnxruntime>=1.13.1 onnxsim>=0.4.10 diff --git a/tests/unittests/test_metrics/test_confusion_matrix.py b/tests/unittests/test_metrics/test_confusion_matrix.py new file mode 100644 index 00000000..07429d8b --- /dev/null +++ b/tests/unittests/test_metrics/test_confusion_matrix.py @@ -0,0 +1,58 @@ +import torch + +from luxonis_train.attached_modules.metrics.confusion_matrix import ( + ConfusionMatrix, +) +from luxonis_train.enums import TaskType +from luxonis_train.nodes import BaseNode + + +def test_compute_detection_confusion_matrix_specific_case(): + class DummyNodeDetection(BaseNode): + tasks = [TaskType.BOUNDINGBOX] + + def forward(self, _): + pass + + metric = ConfusionMatrix( + node=DummyNodeDetection(n_classes=3), iou_threshold=0.5 + ) + + preds = [torch.empty((0, 6)) for _ in range(3)] + preds.append( + torch.tensor( + [ + [10, 20, 30, 50, 0.8, 2], + [10, 21, 30, 50, 0.8, 1], + [10, 20, 30, 50, 0.8, 1], + [51, 61, 71, 78, 0.9, 2], + ] + ) + ) + + # Targets: ground truth for 4 images + targets = torch.tensor( + [ + [3, 1, 10, 20, 30, 50], + [0, 1, 10, 20, 30, 40], + [1, 2, 50, 60, 70, 80], + [2, 2, 10, 60, 70, 80], + [3, 2, 50, 60, 70, 80], + ] + ) + + expected_cm = torch.tensor( + [ + [0, 0, 0, 0], + [0, 0, 0, 2], + [0, 1, 1, 0], + [0, 1, 2, 0], + ], + dtype=torch.int64, + ) + + computed_cm = metric._compute_detection_confusion_matrix(preds, targets) + + assert torch.equal( + computed_cm, expected_cm + ), f"Expected {expected_cm}, but got {computed_cm}"