Skip to content

Commit

Permalink
points_in_boxes_all和points_in_boxes_part的mmcv兼容npu判断 (#3189)
Browse files Browse the repository at this point in the history
* points_in_boxes_all和points_in_boxes_part的mmcv兼容npu判断

* Update points_in_boxes.py

去除判断
  • Loading branch information
ZrBac authored Nov 25, 2024
1 parent e1aab12 commit 03ce920
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions mmcv/ops/points_in_boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,11 @@ def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor:
points_device = points.get_device()
assert points_device == boxes.get_device(), \
'Points and boxes should be put on the same device'
if torch.cuda.current_device() != points_device:
torch.cuda.set_device(points_device)
if points.device.type != 'npu':
if torch.cuda.current_device() != points_device:
torch.cuda.set_device(points_device)
else:
boxes[:, :, 2] += boxes[:, :, 5] / 2.0

ext_module.points_in_boxes_part_forward(boxes.contiguous(),
points.contiguous(),
Expand Down Expand Up @@ -127,8 +130,9 @@ def points_in_boxes_all(points: Tensor, boxes: Tensor) -> Tensor:
points_device = points.get_device()
assert points_device == boxes.get_device(), \
'Points and boxes should be put on the same device'
if torch.cuda.current_device() != points_device:
torch.cuda.set_device(points_device)
if points.device.type != 'npu':
if torch.cuda.current_device() != points_device:
torch.cuda.set_device(points_device)

ext_module.points_in_boxes_all_forward(boxes.contiguous(),
points.contiguous(),
Expand Down

0 comments on commit 03ce920

Please sign in to comment.