Skip to content

Commit

Permalink
fix: #38
Browse files Browse the repository at this point in the history
  • Loading branch information
Owen-Liuyuxuan committed Sep 27, 2021
1 parent 5a54e45 commit 82d0c2c
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions visualDet3D/networks/heads/detection_3d_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,10 +465,13 @@ def loss(self, cls_scores, reg_preds, anchors, annotations, P2s):
pos_anchor = anchor[pos_inds]
pos_alpha_score = alpha_score[pos_inds]
if self.decode_before_loss:
pos_prediction_decoded = self._decode(pos_anchor, reg_pred[pos_inds], anchors_3d_mean_std, label_index, pos_alpha_score)
pos_target_decoded = self._decode(pos_anchor, pos_bbox_targets, anchors_3d_mean_std, label_index, pos_alpha_score)

reg_loss.append((self.loss_bbox(pos_prediction_decoded, pos_target_decoded)* self.regression_weight).mean(dim=0))
pos_prediction_decoded, mask = self._decode(pos_anchor, reg_pred[pos_inds], anchor_mean_std_3d_j[pos_inds], label_index, pos_alpha_score)
pos_target_decoded, _ = self._decode(pos_anchor, pos_bbox_targets, anchor_mean_std_3d_j[pos_inds], label_index, pos_alpha_score)
reg_loss_j = self.loss_bbox(pos_prediction_decoded[mask], pos_target_decoded[mask])
alpha_loss_j = self.alpha_loss(pos_alpha_score[mask], targets_alpha_cls[mask])
loss_j = torch.cat([reg_loss_j, alpha_loss_j], dim=1) * self.regression_weight #[N, 12]
reg_loss.append(loss_j.mean(dim=0)) #[13]
number_of_positives.append(bbox_annotation.shape[0])
else:
reg_loss_j = self.loss_bbox(pos_bbox_targets, reg_pred[pos_inds])
alpha_loss_j = self.alpha_loss(pos_alpha_score, targets_alpha_cls)
Expand Down

0 comments on commit 82d0c2c

Please sign in to comment.