Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature](mluOpRoiAlignRotatedForward): Refactor roi align rotated forward #1138

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/design_docs/roi_align_rotated/roi_align_rotated.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ roi_align_rotated算子应用于FOTS网络结构中,以双线性插值的方
|是否需要支持原位 | 否 |
| 是否需要支持stride机制 | 否 |
| 是否需要支持广播 | 否 |
| 0元素检查是否直接返回 | 是 | |
| 0元素检查是否直接返回 | 是 |

### 1.2 算子功能和应用场景描述

Expand Down Expand Up @@ -208,7 +208,7 @@ mluOpStatus_t MLUOP_WIN_API mluOpRoiAlignRotatedBackward(mluOpHandle_t handle,
|---------|---------|---------|---------|
| input | output | input | output |
|---------|---------|---------|---------|

```
与前向类似,反向计算时把空间均分为4部分,保持每次处理的input和output的大小相同。
### 3.4 性能优化设计
Expand Down Expand Up @@ -252,7 +252,7 @@ RoiAlign类的算子是IO瓶颈,在一个bin中需要处理多个采样点,
6、output的HW维度需要分别与参数中的pooled_height和pooled_width保持一致。

反向:

1、指针为空防呆;

2、0元素检查防呆,VLOG(5)打印信息;
Expand Down
438 changes: 438 additions & 0 deletions docs/design_docs/roi_align_rotated/roi_align_rotated_forward_vector.md

Large diffs are not rendered by default.

53 changes: 33 additions & 20 deletions kernels/roi_align_rotated/roi_align_rotated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,6 @@ mluOpStatus_t MLUOP_WIN_API mluOpRoiAlignRotatedForward(
return MLUOP_STATUS_BAD_PARAM;
}

const int channel = features_desc->dims[3];
const int width = features_desc->dims[2];
const int height = features_desc->dims[1];
const int batch = features_desc->dims[0];
const int rois_nums = rois_desc->dims[0];
mluOpDataType_t data_type = features_desc->dtype;

PARAM_CHECK_GT(API, pooled_height, 0);
PARAM_CHECK_GT(API, pooled_width, 0);
PARAM_CHECK_GE(API, spatial_scale, 0);
Expand All @@ -117,10 +110,15 @@ mluOpStatus_t MLUOP_WIN_API mluOpRoiAlignRotatedForward(
PARAM_CHECK(API, output != nullptr);
PARAM_CHECK(API, rois != nullptr);

const int channel = features_desc->dims[3];
const int width = features_desc->dims[2];
const int height = features_desc->dims[1];
const int batch = features_desc->dims[0];
const int rois_nums = rois_desc->dims[0];
VLOG(5) << "pool_height: " << pooled_height << ",pool_width: " << pooled_width
<< ",channel: " << channel << ",roi nums: " << rois_nums << ".";
VLOG(5) << "batch: " << batch << ",height: " << height << ",width: " << width
<< ".";
<< ",sample_ratio: " << sample_ratio << ".";

if (MLUOP_GEN_CASE_ON_NEW) {
GEN_CASE_START("roi_align_rotated_forward", "ROI_ALIGN_ROTATED_FORWARD");
Expand All @@ -142,18 +140,33 @@ mluOpStatus_t MLUOP_WIN_API mluOpRoiAlignRotatedForward(
clockwise);
GEN_CASE_TEST_PARAM_NEW(true, true, false, 0.003, 0.003, 0);
}
mluOpRoiAlignRotatedParams roiAlignRotatedParams{pooled_height, pooled_width,
sample_ratio, spatial_scale,
aligned, clockwise};
mluOpRoiAlignRotatedParams roiAlignRotatedParams{
aligned, clockwise, pooled_height,
pooled_width, sample_ratio, spatial_scale};

cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFunc(handle, rois_nums * pooled_height * pooled_width, &k_dim, &k_type);
VLOG(5) << "[mluOpRoiAlignRotatedForward] launch kernel policyFunc["
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << "].";
CHECK_RETURN(API, KernelRoiAlignRotatedForward(
k_dim, k_type, handle->queue, features_desc->dtype,
features, rois, batch, height, width, channel,
rois_nums, roiAlignRotatedParams, output));

uint32_t sample_ratio_split = 3, channels_split = 1024;
if (handle->arch >= MLUOP_MLU590 && channel <= channels_split &&
(sample_ratio >= sample_ratio_split || sample_ratio <= 0)
) {
VLOG(5) << "[mluOpRoiAlignRotatedForwardVector] launch kernel policyFunc["
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << "].";
CHECK_RETURN(API, KernelRoiAlignRotatedForwardVector(
k_dim, k_type, handle->queue, features_desc->dtype,
features, rois, batch, height, width, channel,
rois_nums, roiAlignRotatedParams, output));
} else {
VLOG(5) << "[mluOpRoiAlignRotatedForward] launch kernel policyFunc["
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << "].";
CHECK_RETURN(API, KernelRoiAlignRotatedForward(
k_dim, k_type, handle->queue, features_desc->dtype,
features, rois, batch, height, width, channel,
rois_nums, roiAlignRotatedParams, output));
}

VLOG(5) << "Kernel KernelRoiAlignRotatedForward.";
GEN_CASE_END();
return MLUOP_STATUS_SUCCESS;
Expand Down Expand Up @@ -259,9 +272,9 @@ mluOpStatus_t MLUOP_WIN_API mluOpRoiAlignRotatedBackward(
clockwise);
GEN_CASE_TEST_PARAM_NEW(true, true, false, 0.003, 0.003, 0);
}
mluOpRoiAlignRotatedParams roiAlignRotatedParams{pooled_height, pooled_width,
sample_ratio, spatial_scale,
aligned, clockwise};
mluOpRoiAlignRotatedParams roiAlignRotatedParams{
aligned, clockwise, pooled_height,
pooled_width, sample_ratio, spatial_scale};

cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
Expand Down
15 changes: 11 additions & 4 deletions kernels/roi_align_rotated/roi_align_rotated.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,29 @@
#include "mlu_op.h"

struct mluOpRoiAlignRotatedParams {
bool aligned;
bool clockwise;
int pooled_height;
int pooled_width;
int sample_ratio;
float spatial_scale;
bool aligned;
bool clockwise;
};

mluOpStatus_t MLUOP_WIN_API KernelRoiAlignRotatedForward(
mluOpStatus_t KernelRoiAlignRotatedForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
mluOpDataType_t d_type, const void *features, const void *rois,
const int batch, const int height, const int width, const int channel,
const int rois_num, const mluOpRoiAlignRotatedParams rroiAlignParams,
void *output);

mluOpStatus_t KernelRoiAlignRotatedForwardVector(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
mluOpDataType_t d_type, const void *features, const void *rois,
const int batch, const int height, const int width, const int channel,
const int rois_num, const mluOpRoiAlignRotatedParams rroiAlignParams,
void *output);

mluOpStatus_t MLUOP_WIN_API KernelRoiAlignRotatedBackward(
mluOpStatus_t KernelRoiAlignRotatedBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
mluOpDataType_t d_type, const void *top_grad, const void *rois,
const int batch, const int height, const int width, const int channel,
Expand Down
4 changes: 2 additions & 2 deletions kernels/roi_align_rotated/roi_align_rotated_block.mlu
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ __mlu_global__ void roiAlignRotatedBackward(
}
}

mluOpStatus_t MLUOP_WIN_API KernelRoiAlignRotatedForward(
mluOpStatus_t KernelRoiAlignRotatedForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
mluOpDataType_t d_type, const void *features, const void *rois,
const int batch, const int height, const int width, const int channel,
Expand All @@ -474,7 +474,7 @@ mluOpStatus_t MLUOP_WIN_API KernelRoiAlignRotatedForward(
return MLUOP_STATUS_SUCCESS;
}

mluOpStatus_t MLUOP_WIN_API KernelRoiAlignRotatedBackward(
mluOpStatus_t KernelRoiAlignRotatedBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
mluOpDataType_t d_type, const void *top_grad, const void *rois,
const int batch, const int height, const int width, const int channel,
Expand Down
Loading
Loading