Skip to content

Commit

Permalink
[Feature](mluOpRoiAlignRotatedForward): bin_cycle vector
Browse files (browse the repository at this point in the history)
  • Loading branch information
chqy99 committed Nov 5, 2024
1 parent 51000c0 commit bad8afb
Show file tree
Hide file tree
Showing 7 changed files with 1,111 additions and 39 deletions.
6 changes: 3 additions & 3 deletions docs/design_docs/roi_align_rotated/roi_align_rotated.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ roi_align_rotated算子应用于FOTS网络结构中,以双线性插值的方
|是否需要支持原位 ||
| 是否需要支持stride机制 ||
| 是否需要支持广播 ||
| 0元素检查是否直接返回 || |
| 0元素检查是否直接返回 ||

### 1.2 算子功能和应用场景描述

Expand Down Expand Up @@ -208,7 +208,7 @@ mluOpStatus_t MLUOP_WIN_API mluOpRoiAlignRotatedBackward(mluOpHandle_t handle,
|---------|---------|---------|---------|
| input | output | input | output |
|---------|---------|---------|---------|

```
与前向类似,反向计算时把空间均分为4部分,保持每次处理的input和output的大小相同。
### 3.4 性能优化设计
Expand Down Expand Up @@ -252,7 +252,7 @@ RoiAlign类的算子是IO瓶颈,在一个bin中需要处理多个采样点,
6、output的HW维度需要分别与参数中的pooled_height和pooled_width保持一致。

反向:

1、指针为空防呆;

2、0元素检查防呆,VLOG(5)打印信息;
Expand Down
435 changes: 435 additions & 0 deletions docs/design_docs/roi_align_rotated/roi_align_rotated_forward_vector.md

Large diffs are not rendered by default.

56 changes: 36 additions & 20 deletions kernels/roi_align_rotated/roi_align_rotated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,6 @@ mluOpStatus_t MLUOP_WIN_API mluOpRoiAlignRotatedForward(
return MLUOP_STATUS_BAD_PARAM;
}

const int channel = features_desc->dims[3];
const int width = features_desc->dims[2];
const int height = features_desc->dims[1];
const int batch = features_desc->dims[0];
const int rois_nums = rois_desc->dims[0];
mluOpDataType_t data_type = features_desc->dtype;

PARAM_CHECK_GT(API, pooled_height, 0);
PARAM_CHECK_GT(API, pooled_width, 0);
PARAM_CHECK_GE(API, spatial_scale, 0);
Expand All @@ -117,10 +110,15 @@ mluOpStatus_t MLUOP_WIN_API mluOpRoiAlignRotatedForward(
PARAM_CHECK(API, output != nullptr);
PARAM_CHECK(API, rois != nullptr);

const int channel = features_desc->dims[3];
const int width = features_desc->dims[2];
const int height = features_desc->dims[1];
const int batch = features_desc->dims[0];
const int rois_nums = rois_desc->dims[0];
VLOG(5) << "pool_height: " << pooled_height << ",pool_width: " << pooled_width
<< ",channel: " << channel << ",roi nums: " << rois_nums << ".";
VLOG(5) << "batch: " << batch << ",height: " << height << ",width: " << width
<< ".";
<< ",sample_ratio: " << sample_ratio << ".";

if (MLUOP_GEN_CASE_ON_NEW) {
GEN_CASE_START("roi_align_rotated_forward", "ROI_ALIGN_ROTATED_FORWARD");
Expand All @@ -142,18 +140,36 @@ mluOpStatus_t MLUOP_WIN_API mluOpRoiAlignRotatedForward(
clockwise);
GEN_CASE_TEST_PARAM_NEW(true, true, false, 0.003, 0.003, 0);
}
mluOpRoiAlignRotatedParams roiAlignRotatedParams{pooled_height, pooled_width,
sample_ratio, spatial_scale,
aligned, clockwise};
mluOpRoiAlignRotatedParams roiAlignRotatedParams{
aligned, clockwise, pooled_height,
pooled_width, sample_ratio, spatial_scale};

cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFunc(handle, rois_nums * pooled_height * pooled_width, &k_dim, &k_type);
VLOG(5) << "[mluOpRoiAlignRotatedForward] launch kernel policyFunc["
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << "].";
CHECK_RETURN(API, KernelRoiAlignRotatedForward(
k_dim, k_type, handle->queue, features_desc->dtype,
features, rois, batch, height, width, channel,
rois_nums, roiAlignRotatedParams, output));

uint32_t sample_ratio_split = 3, channels_split = 1;
if (handle->arch >= MLUOP_MLU590) {
channels_split = 1024;
}

if ((sample_ratio >= sample_ratio_split || sample_ratio <= 0) &&
channel <= channels_split) {
VLOG(5) << "[mluOpRoiAlignRotatedForwardVector] launch kernel policyFunc["
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << "].";
CHECK_RETURN(API, KernelRoiAlignRotatedForwardVector(
k_dim, k_type, handle->queue, features_desc->dtype,
features, rois, batch, height, width, channel,
rois_nums, roiAlignRotatedParams, output));
} else {
VLOG(5) << "[mluOpRoiAlignRotatedForward] launch kernel policyFunc["
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << "].";
CHECK_RETURN(API, KernelRoiAlignRotatedForward(
k_dim, k_type, handle->queue, features_desc->dtype,
features, rois, batch, height, width, channel,
rois_nums, roiAlignRotatedParams, output));
}

VLOG(5) << "Kernel KernelRoiAlignRotatedForward.";
GEN_CASE_END();
return MLUOP_STATUS_SUCCESS;
Expand Down Expand Up @@ -259,9 +275,9 @@ mluOpStatus_t MLUOP_WIN_API mluOpRoiAlignRotatedBackward(
clockwise);
GEN_CASE_TEST_PARAM_NEW(true, true, false, 0.003, 0.003, 0);
}
mluOpRoiAlignRotatedParams roiAlignRotatedParams{pooled_height, pooled_width,
sample_ratio, spatial_scale,
aligned, clockwise};
mluOpRoiAlignRotatedParams roiAlignRotatedParams{
aligned, clockwise, pooled_height,
pooled_width, sample_ratio, spatial_scale};

cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
Expand Down
15 changes: 11 additions & 4 deletions kernels/roi_align_rotated/roi_align_rotated.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,29 @@
#include "mlu_op.h"

struct mluOpRoiAlignRotatedParams {
bool aligned;
bool clockwise;
int pooled_height;
int pooled_width;
int sample_ratio;
float spatial_scale;
bool aligned;
bool clockwise;
};

mluOpStatus_t MLUOP_WIN_API KernelRoiAlignRotatedForward(
mluOpStatus_t KernelRoiAlignRotatedForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
mluOpDataType_t d_type, const void *features, const void *rois,
const int batch, const int height, const int width, const int channel,
const int rois_num, const mluOpRoiAlignRotatedParams rroiAlignParams,
void *output);

mluOpStatus_t KernelRoiAlignRotatedForwardVector(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
mluOpDataType_t d_type, const void *features, const void *rois,
const int batch, const int height, const int width, const int channel,
const int rois_num, const mluOpRoiAlignRotatedParams rroiAlignParams,
void *output);

mluOpStatus_t MLUOP_WIN_API KernelRoiAlignRotatedBackward(
mluOpStatus_t KernelRoiAlignRotatedBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
mluOpDataType_t d_type, const void *top_grad, const void *rois,
const int batch, const int height, const int width, const int channel,
Expand Down
4 changes: 2 additions & 2 deletions kernels/roi_align_rotated/roi_align_rotated_block.mlu
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ __mlu_global__ void roiAlignRotatedBackward(
}
}

mluOpStatus_t MLUOP_WIN_API KernelRoiAlignRotatedForward(
mluOpStatus_t KernelRoiAlignRotatedForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
mluOpDataType_t d_type, const void *features, const void *rois,
const int batch, const int height, const int width, const int channel,
Expand All @@ -474,7 +474,7 @@ mluOpStatus_t MLUOP_WIN_API KernelRoiAlignRotatedForward(
return MLUOP_STATUS_SUCCESS;
}

mluOpStatus_t MLUOP_WIN_API KernelRoiAlignRotatedBackward(
mluOpStatus_t KernelRoiAlignRotatedBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
mluOpDataType_t d_type, const void *top_grad, const void *rois,
const int batch, const int height, const int width, const int channel,
Expand Down
Loading

0 comments on commit bad8afb

Please sign in to comment.