From 3424ec1daf6e004f3b9ef8e338fdf701f5eafadc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B4=BA=E9=9B=A8=E6=9D=B0?= Date: Wed, 6 Nov 2024 11:15:37 +0800 Subject: [PATCH] add new npu op roiaware_pool3d --- mmcv/ops/csrc/pytorch/npu/roiaware_pool3d.cpp | 86 +++++++++++++++++++ mmcv/ops/scatter_points.py | 10 +-- 2 files changed, 91 insertions(+), 5 deletions(-) create mode 100644 mmcv/ops/csrc/pytorch/npu/roiaware_pool3d.cpp diff --git a/mmcv/ops/csrc/pytorch/npu/roiaware_pool3d.cpp b/mmcv/ops/csrc/pytorch/npu/roiaware_pool3d.cpp new file mode 100644 index 0000000000..50706df867 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/roiaware_pool3d.cpp @@ -0,0 +1,86 @@ +#include "pytorch_npu_helper.hpp" +using namespace NPU_NAME_SPACE; +using namespace std; + +void roiaware_pool3d_forward_npu(int boxes_num, int pts_num, int channels, + int max_pts_each_voxel, int out_x, int out_y, + int out_z, const Tensor rois, const Tensor pts, + const Tensor pts_feature, Tensor argmax, + Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method) { + at::Tensor rois_cast = rois; + at::Tensor pts_cast = pts; + at::Tensor pts_feature_cast = pts_feature; + at::Tensor pooled_features_cast = pooled_features; + + auto dtype = rois.dtype(); + if (dtype == at::kHalf) { + rois_cast = rois_cast.to(at::kFloat); + pts_cast = pts_cast.to(at::kFloat); + pts_feature_cast = pts_feature_cast.to(at::kFloat); + pooled_features_cast = pooled_features_cast.to(at::kFloat); + } + + EXEC_NPU_CMD(aclnnRoiawarePool3d, rois_cast, pts_cast, pts_feature_cast, + pool_method, max_pts_each_voxel, out_x, out_y, out_z, argmax, + pts_idx_of_voxels, pooled_features_cast); + + if (dtype == at::kHalf) { + pooled_features_cast = pooled_features_cast.to(at::kHalf); + } + + pooled_features.copy_(pooled_features_cast); +} + +void roiaware_pool3d_backward_npu(int boxes_num, int out_x, int out_y, + int out_z, int channels, + int max_pts_each_voxel, + const Tensor pts_idx_of_voxels, + const Tensor argmax, const Tensor grad_out, + Tensor grad_in, int pool_method) +{ + int32_t npoints = grad_in.size(0); + + auto dtype = grad_out.dtype(); + at::Tensor grad_out_cast = grad_out; + at::Tensor grad_in_cast = grad_in; + + if (dtype == at::kHalf) { + grad_out_cast = grad_out.to(at::kFloat); + grad_in_cast = grad_in_cast.to(at::kFloat); + } + + if (pool_method == 0) { + // maxpool3d + EXEC_NPU_CMD(aclnnRoiawareMaxpool3dGrad, argmax, grad_out_cast, boxes_num, + out_x, out_y, out_z, channels, npoints, grad_in_cast); + } else if (pool_method == 1) { + // avgpool3d + EXEC_NPU_CMD(aclnnRoiawareAvgpool3dGrad, pts_idx_of_voxels, grad_out_cast, + boxes_num, out_x, out_y, out_z, channels, npoints, + max_pts_each_voxel, grad_in_cast); + } + + if (dtype == at::kHalf) { + grad_in_cast = grad_in_cast.to(at::kHalf); + } + + grad_in.copy_(grad_in_cast); +} + +void roiaware_pool3d_forward_impl(int boxes_num, int pts_num, int channels, + int max_pts_each_voxel, int out_x, int out_y, + int out_z, const Tensor rois, + const Tensor pts, const Tensor pts_feature, + Tensor argmax, Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method); + +void roiaware_pool3d_backward_impl(int boxes_num, int out_x, int out_y, + int out_z, int channels, + int max_pts_each_voxel, + const Tensor pts_idx_of_voxels, + const Tensor argmax, const Tensor grad_out, + Tensor grad_in, int pool_method); + +REGISTER_NPU_IMPL(roiaware_pool3d_forward_impl, roiaware_pool3d_forward_npu); +REGISTER_NPU_IMPL(roiaware_pool3d_backward_impl, roiaware_pool3d_backward_npu); \ No newline at end of file diff --git a/mmcv/ops/scatter_points.py b/mmcv/ops/scatter_points.py index 68bd28319f..d69a87b358 100644 --- a/mmcv/ops/scatter_points.py +++ b/mmcv/ops/scatter_points.py @@ -38,15 +38,15 @@ def forward(ctx: Any, """ ctx.device = feats.device.type if ctx.device == 'npu': - import ads_c - voxel_idx = ads_c.point_to_voxel(coors, [], [], 'XYZ') - unique_res = ads_c.unique_voxel(voxel_idx) + import mx_driving._C + voxel_idx = mx_driving._C.point_to_voxel(coors, [], [], 'XYZ') + unique_res = mx_driving._C.unique_voxel(voxel_idx) num_voxels, uniqued_voxel_idx, prefix_sum, \ argsort_coor, _ = unique_res voxel_coors = \ - ads_c.voxel_to_point(uniqued_voxel_idx, [], [], 'XYZ') + mx_driving._C.voxel_to_point(uniqued_voxel_idx, [], [], 'XYZ') voxel_feats, \ - compare_mask = ads_c.npu_dynamic_scatter(feats, coors, + compare_mask = mx_driving._C.npu_dynamic_scatter(feats, coors, prefix_sum, argsort_coor, num_voxels,