diff --git a/visualDet3D/networks/heads/detection_3d_head.py b/visualDet3D/networks/heads/detection_3d_head.py index 8e66da5..6e950b2 100644 --- a/visualDet3D/networks/heads/detection_3d_head.py +++ b/visualDet3D/networks/heads/detection_3d_head.py @@ -465,10 +465,13 @@ def loss(self, cls_scores, reg_preds, anchors, annotations, P2s): pos_anchor = anchor[pos_inds] pos_alpha_score = alpha_score[pos_inds] if self.decode_before_loss: - pos_prediction_decoded = self._decode(pos_anchor, reg_pred[pos_inds], anchors_3d_mean_std, label_index, pos_alpha_score) - pos_target_decoded = self._decode(pos_anchor, pos_bbox_targets, anchors_3d_mean_std, label_index, pos_alpha_score) - - reg_loss.append((self.loss_bbox(pos_prediction_decoded, pos_target_decoded)* self.regression_weight).mean(dim=0)) + pos_prediction_decoded, mask = self._decode(pos_anchor, reg_pred[pos_inds], anchor_mean_std_3d_j[pos_inds], label_index, pos_alpha_score) + pos_target_decoded, _ = self._decode(pos_anchor, pos_bbox_targets, anchor_mean_std_3d_j[pos_inds], label_index, pos_alpha_score) + reg_loss_j = self.loss_bbox(pos_prediction_decoded[mask], pos_target_decoded[mask]) + alpha_loss_j = self.alpha_loss(pos_alpha_score[mask], targets_alpha_cls[mask]) + loss_j = torch.cat([reg_loss_j, alpha_loss_j], dim=1) * self.regression_weight #[N, 12] + reg_loss.append(loss_j.mean(dim=0)) #[13] + number_of_positives.append(bbox_annotation.shape[0]) else: reg_loss_j = self.loss_bbox(pos_bbox_targets, reg_pred[pos_inds]) alpha_loss_j = self.alpha_loss(pos_alpha_score, targets_alpha_cls)