Skip to content

Commit

Permalink
Merge pull request #77 from JYYCaN/patch-1
Browse files Browse the repository at this point in the history
scatter points bug fix
  • Loading branch information
hust17yixuan authored Nov 7, 2024
2 parents 7ee5a41 + eedad49 commit 1c2c238
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mmcv/ops/scatter_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ def backward(ctx: Any,
grad_voxel_feats: torch.Tensor,
grad_voxel_coors: Optional[torch.Tensor] = None) -> tuple:
if ctx.device == 'npu':
import ads_c
import mx_driving._C
prefix_sum, argsort_coor, compare_mask = ctx.saved_tensors
grad_point_feats = torch.zeros(
ctx.feats_shape,
dtype=grad_voxel_feats.dtype,
device=grad_voxel_feats.device)
ads_c.npu_dynamic_scatter_grad(grad_point_feats,
mx_driving._C.npu_dynamic_scatter_grad(grad_point_feats,
grad_voxel_feats.contiguous(),
prefix_sum, argsort_coor,
compare_mask, ctx.reduce_type)
Expand Down

0 comments on commit 1c2c238

Please sign in to comment.